A few days ago, while training a new Transformer model, I found that no matter how I trained it, it wouldn’t converge. After some debugging, I discovered that I had forgotten to divide by $\sqrt{d}$ in Self Attention, which prompted me to revisit why dividing by $\sqrt{d}$ is so important. Of course, Google’s T5 does not divide by $\sqrt{d}$, yet it still converges normally, because it makes a corresponding adjustment in its initialization strategy, so this matter is also related to initialization.
Taking this opportunity, this article reviews model initialization, parameterization, and normalization, with the discussion centered mainly on the Transformer.
1
『Sampling Distribution』
Initialization is essentially random sampling, so let’s first introduce some commonly used sampling distributions. Generally, we initialize by sampling from a random distribution with a specified mean and variance. Three common choices are the normal distribution, the uniform distribution, and the truncated normal distribution.
Clearly, the normal and uniform distributions are both very common. The normal distribution is usually denoted $\mathcal{N}(\mu,\sigma^2)$, where $\mu$ is the mean and $\sigma^2$ is the variance; the uniform distribution on the interval $[a,b]$ is generally denoted $U[a,b]$, with mean $\frac{a+b}{2}$ and variance $\frac{(b-a)^2}{12}$, so if we specify a mean $\mu$ and variance $\sigma^2$, the corresponding uniform distribution is $U[\mu-\sqrt{3}\sigma,\,\mu+\sqrt{3}\sigma]$.
Generally speaking, samples from the normal distribution are more diverse, but the distribution is theoretically unbounded, and a sample with a very large absolute value may hinder optimization; conversely, the uniform distribution is bounded, but its samples tend to be less varied. Hence the “Truncated Normal Distribution”: in addition to the mean $\mu$ and variance $\sigma^2$, one also specifies an interval $[a,b]$; samples are drawn from $\mathcal{N}(\mu,\sigma^2)$, and a sample is kept only if it falls within $[a,b]$, otherwise it is redrawn until it does.
In TensorFlow’s built-in tf.random.truncated_normal, the interval is hardcoded to $[\mu-2\sigma,\,\mu+2\sigma]$. Based on the formulas for the truncated normal distribution, we can calculate that the actual mean of this function’s samples is still $\mu$, but the actual variance is $\lambda\sigma^2$, where
$$\lambda = 1 - \frac{4e^{-2}}{\sqrt{2\pi}\,\mathrm{erf}(\sqrt{2})} \approx 0.7737$$
If we want to obtain sampling results with variance $\sigma^2$, then the standard deviation passed into the function should be $\sigma/\sqrt{\lambda}\approx 1.1368\,\sigma$.
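As a quick numerical sanity check, here is a minimal NumPy sketch (my own illustration, not from the original post); it implements the resample-until-inside-the-interval definition directly and compares the empirical variance with $\lambda\sigma^2$:

import math
import numpy as np

np.random.seed(0)
mu, sigma = 0.0, 1.0

# Truncated normal via rejection: resample anything outside [mu - 2*sigma, mu + 2*sigma]
x = np.random.normal(mu, sigma, size=1_000_000)
mask = np.abs(x - mu) > 2 * sigma
while mask.any():
    x[mask] = np.random.normal(mu, sigma, size=mask.sum())
    mask = np.abs(x - mu) > 2 * sigma

lam = 1 - 4 * math.exp(-2) / (math.sqrt(2 * math.pi) * math.erf(math.sqrt(2)))
print(x.var(), lam * sigma**2)       # both ~0.774: the variance shrinks by lambda
print((x / math.sqrt(lam)).var())    # ~1.0: rescaling by 1/sqrt(lambda) restores sigma^2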
2
『Stable Second Moment』
In a previous article, “Understanding Model Parameter Initialization Strategies from a Geometric Perspective” (https://kexue.fm/archives/7180), I analyzed existing initialization methods from a geometric perspective. The general idea is that a properly scaled random matrix is approximately an orthogonal matrix, which ensures the stability of the model in the initial phase. However, although the geometric perspective is intuitive, it is usually hard to generalize and extend, so next we will understand initialization methods from an algebraic perspective.
In typical tutorials, initialization methods are derived by trying to make the input and output have the same mean and variance: we usually assume the input is a random vector with mean 0 and variance 1, and then try to make the output also have mean 0 and variance 1. However, I believe this is actually unnecessary, and for some non-negative activation functions a zero mean is fundamentally unattainable. In fact, as long as the second (raw) moment of each layer’s input and output stays roughly constant, the gradients of each layer will also stay within a reasonable range during backpropagation, neither exploding nor vanishing, so the model can essentially be trained stably.
Now, let’s consider a fully connected layer without an activation function (with $m$ input nodes and $n$ output nodes):
$$y_j = b_j + \sum_{i=1}^{m} x_i w_{i,j}$$
For simplicity, we usually initialize the bias term $b_j$ to zero and also set the mean of $w_{i,j}$ to 0; this simplifies the results below, but it is not strictly necessary, just a convenient choice. We now compute the second moment:
$$\mathbb{E}[y_j^2] = \mathbb{E}\left[\left(\sum_i x_i w_{i,j}\right)^2\right] = \mathbb{E}\left[\sum_{i_1,i_2} x_{i_1} x_{i_2} w_{i_1,j} w_{i_2,j}\right] = \sum_{i_1,i_2} \mathbb{E}[x_{i_1} x_{i_2}]\,\mathbb{E}[w_{i_1,j} w_{i_2,j}]$$
Note that $w_{i_1,j}$ and $w_{i_2,j}$ are independent and identically distributed with zero mean, so when $i_1\neq i_2$ we have $\mathbb{E}[w_{i_1,j}w_{i_2,j}] = \mathbb{E}[w_{i_1,j}]\,\mathbb{E}[w_{i_2,j}] = 0$, and we only need to consider the terms with $i_1 = i_2$. Assuming the second moment of the input is 1, then
$$\mathbb{E}[y_j^2] = \sum_i \mathbb{E}[x_i^2]\,\mathbb{E}[w_{i,j}^2] = m\,\mathbb{E}[w_{i,j}^2]$$
So to make this equal to 1, and recalling the zero-mean assumption, we arrive at the initialization strategy “sample each $w_{i,j}$ independently from a random distribution with mean 0 and variance $1/m$”, which is Xavier initialization. Note that during this derivation we made no assumption about the mean of the input, so the input can even be entirely non-negative.
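A minimal NumPy check of this claim (my own illustration, not from the original post; the shapes and sample sizes are arbitrary choices):

import numpy as np

np.random.seed(0)
m, n, batch = 512, 256, 10_000

# Input with second moment 1 (the mean does not have to be 0; here it is even all non-negative).
x = np.abs(np.random.normal(size=(batch, m)))
x /= np.sqrt((x ** 2).mean())

# Xavier-style init: mean 0, variance 1/m.
w = np.random.normal(0, np.sqrt(1.0 / m), size=(m, n))
y = x @ w

print((x ** 2).mean(), (y ** 2).mean())  # both close to 1: the second moment is preserved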
3
『Activation Function』
Of course, this is the case without an activation function; once activation functions are involved, each case needs its own analysis. For example, if the activation function is ReLU, we can assume that roughly half of the pre-activation values are zeroed out, so the second moment is estimated as half of the previous result:
$$\mathbb{E}[\mathrm{relu}(y_j)^2] \approx \frac{m}{2}\,\mathbb{E}[w_{i,j}^2]$$
Thus the initialization variance that keeps the second moment unchanged is $2/m$; this is He initialization, designed specifically for ReLU networks.
However, if the activation function is something like elu or gelu, then the analysis becomes more complicated; and if the activation function is sigmoid, then no initialization at all can make the second moment equal to 1. In this case, if we still want to keep the second moment unchanged, one possible solution is to “tweak the definition of the activation function”.
Taking sigmoid as an example, assume the input has mean 0 and variance 1, and we still use the “mean 0, variance $1/m$” initialization; then the pre-activation output also has mean 0 and variance 1, so we can estimate the second moment after activation:
$$\mathbb{E}[\mathrm{sigmoid}(y)^2] = \int_{-\infty}^{\infty} \frac{e^{-y^2/2}}{\sqrt{2\pi}}\,\mathrm{sigmoid}(y)^2\,dy \approx 0.2933$$
In other words, under this assumption, the second moment after activation is roughly $0.2933$. Therefore, if we want the second moment of the output to remain approximately unchanged, we can divide the output by $\sqrt{0.2933}$; in other words, the activation function is changed from $\mathrm{sigmoid}(x)$ to $\mathrm{sigmoid}(x)/\sqrt{0.2933}\approx 1.8465\,\mathrm{sigmoid}(x)$, which is the “tweaked” activation function. If you feel it is necessary, you can also subtract a constant so that the output mean becomes 0.
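Here is a small NumPy sketch of this estimate (my own illustration, not from the original post; the constant is re-derived empirically inside the code rather than assumed):

import numpy as np

np.random.seed(0)
sigmoid = lambda t: 1.0 / (1.0 + np.exp(-t))

z = np.random.normal(size=1_000_000)          # pre-activation: mean 0, variance 1
c = (sigmoid(z) ** 2).mean()                  # Monte Carlo estimate of E[sigmoid(z)^2], ~0.293
print(c)

tweaked = lambda t: sigmoid(t) / np.sqrt(c)   # "tweaked" sigmoid
print((tweaked(z) ** 2).mean())               # ~1.0: second moment preserved after activation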
Back in 2017, a “sensational” paper, “Self-Normalizing Neural Networks” (https://arxiv.org/abs/1706.02515), proposed an activation function called SELU based on the same “tweaking” idea, this time applied to ELU, with the form:
$$\mathrm{selu}(x) = \lambda\begin{cases} x, & x > 0 \\ \alpha\,(e^x - 1), & x \le 0 \end{cases}$$
It initially “made waves” because it claimed to achieve self-normalizing networks without techniques like Batch Normalization, and because its accompanying dozens of pages of mathematical derivation were quite “impressive”. From the perspective above, however, it simply introduces two parameters $\lambda$ and $\alpha$ chosen so that when the input is a standard normal distribution, the mean and variance of the activated output are 0 and 1 respectively; it therefore amounts to a relatively good initialization, which is why it could only make a “temporary splash”. The two parameters (roughly $\lambda\approx 1.0507$ and $\alpha\approx 1.6733$) can also be solved numerically with Mathematica:
f[x_] = Exp[-x^2/2]/Sqrt[2 Pi];
s[x_] = Piecewise[{{Lambda*x, x > 0}, {Lambda*Alpha*(Exp[x] - 1), x <= 0}}];
x1 = Integrate[f[x]*s[x], {x, -Infinity, Infinity}];
x2 = Integrate[f[x]*s[x]^2, {x, -Infinity, Infinity}];
N[Solve[{x1 == 0, x2 == 1}, {Lambda, Alpha}], 20]
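For readers without Mathematica, the same two conditions can also be checked numerically in Python; this is my own sketch, assuming SciPy is available, and is not part of the original post:

import numpy as np
from scipy.integrate import quad
from scipy.optimize import fsolve

phi = lambda x: np.exp(-x**2 / 2) / np.sqrt(2 * np.pi)   # standard normal density

def selu(x, lam, alpha):
    return lam * (x if x > 0 else alpha * (np.exp(x) - 1))

def conditions(p):
    lam, alpha = p
    f1 = lambda x: phi(x) * selu(x, lam, alpha)
    f2 = lambda x: phi(x) * selu(x, lam, alpha) ** 2
    m1 = quad(f1, -np.inf, 0)[0] + quad(f1, 0, np.inf)[0]   # mean of selu(z), z ~ N(0, 1)
    m2 = quad(f2, -np.inf, 0)[0] + quad(f2, 0, np.inf)[0]   # second moment of selu(z)
    return [m1, m2 - 1.0]                                   # want mean 0 and second moment 1

print(fsolve(conditions, [1.0, 1.5]))                       # ~[1.0507, 1.6733]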
4
『Direct Normalization』
Of course, compared with this kind of simple “tweaking”, a more direct approach is the various Normalization methods, such as Batch Normalization, Instance Normalization, Layer Normalization, etc. These methods standardize the output using the mean and variance computed directly from the current data, without estimating integrals in advance; we sometimes also call them “standardization”. The three methods are broadly similar, differing mainly in the dimensions over which the statistics are computed (Batch Normalization additionally keeps moving averages of the mean and variance for inference). For example, Layer Normalization, which is used most in NLP and especially in Transformer models, has the form
$$y = \frac{x - \mu}{\sigma}\,\gamma + \beta,\qquad \mu = \frac{1}{d}\sum_{i=1}^{d} x_i,\qquad \sigma = \sqrt{\frac{1}{d}\sum_{i=1}^{d}(x_i-\mu)^2}$$
I will not repeat further details here; readers interested in how these methods actually work can refer to my previous article “What Role Does BN Really Play? An Analysis from a Closed-Door Perspective” (https://kexue.fm/archives/6992).
Here I found an interesting phenomenon: Normalization generally includes two parts: mean subtraction (center) and division by standard deviation (scale), but some recent works have gradually attempted to remove the center step, and some results even show that removing the center step slightly improves performance.
For example, the 2019 paper “Root Mean Square Layer Normalization” (https://arxiv.org/abs/1910.07467) compared a Layer Normalization variant without the center step, which it calls RMS Norm, of the form
$$y = \frac{x}{\mathrm{RMS}(x)}\,\gamma,\qquad \mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2}$$
It can be seen that RMS Norm is just a simple variant of L2 normalization, yet the paper’s overall results show that RMS Norm is faster than Layer Normalization while performing essentially the same.
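A minimal NumPy sketch of the two operations side by side (my own illustration; the epsilon and shape conventions are arbitrary):

import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    # center (subtract mean) + scale (divide by std), over the last axis
    mu = x.mean(-1, keepdims=True)
    sigma = x.std(-1, keepdims=True)
    return (x - mu) / (sigma + eps) * gamma + beta

def rms_norm(x, gamma, eps=1e-6):
    # scale only: no mean subtraction, no beta
    rms = np.sqrt((x ** 2).mean(-1, keepdims=True))
    return x / (rms + eps) * gamma

x = np.random.normal(size=(2, 8))
print(layer_norm(x, np.ones(8), np.zeros(8)))
print(rms_norm(x, np.ones(8)))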
Besides this paper, RMS Norm was also used by Google in T5, and in another paper “Do Transformer Modifications Transfer Across Implementations and Applications?” (https://arxiv.org/abs/2102.11972), a thorough comparative experiment was conducted, demonstrating the superiority of RMS Norm. Thus, it seems likely that RMS Norm will replace Layer Normalization and become standard in Transformers.
Coincidentally, the 2019 paper “Analyzing and Improving the Image Quality of StyleGAN” (https://arxiv.org/abs/1912.04958) proposed an improved version of StyleGAN, StyleGAN2, which found that the use of Instance Normalization caused some generated images to exhibit “water droplets”. They ultimately removed Instance Normalization and replaced it with something called “Weight Demodulation”, but they also discovered that removing the center operation from Instance Normalization could alleviate this phenomenon. This also provides evidence that the center operation in normalization may bring negative effects.
An intuitive guess is that the center operation, much like the bias term in a fully connected layer, stores some prior information about the data distribution, and writing such prior information directly into the model may reduce the model’s transferability. Therefore, T5 not only removed the center operation from Layer Normalization but also removed the bias terms from every layer.
5
『NTK Parameterization』
Returning to the Xavier initialization of a fully connected layer: it says we should initialize with a “random distribution with mean 0 and variance $1/m$”. Besides using this initialization directly, there is another parameterization: initialize with a “random distribution with mean 0 and variance 1”, but divide the output by $\sqrt{m}$, i.e. the model becomes
$$y_j = b_j + \frac{1}{\sqrt{m}}\sum_i x_i w_{i,j}$$
This is known as “NTK parameterization” in Gaussian processes, with reference papers including “Neural Tangent Kernel: Convergence and Generalization in Neural Networks”, “On the Infinite Width Limit of Neural Networks with a Standard Parameterization”, etc. However, for me, the first time I encountered this operation was in the PGGAN paper “Progressive Growing of GANs for Improved Quality, Stability, and Variation”.
Neural Tangent Kernel: Convergence and Generalization in Neural Networks
https://arxiv.org/abs/1806.07572
On the Infinite Width Limit of Neural Networks with a Standard Parameterization
https://arxiv.org/abs/2001.07301
Progressive Growing of GANs for Improved Quality, Stability, and Variation
https://arxiv.org/abs/1710.10196
Clearly, using NTK parameterization, we can initialize all parameters with standard variance while still maintaining the second moment unchanged, and even the “tweaked activation function” mentioned earlier can be seen as a kind of NTK parameterization. A natural question arises: What benefits does NTK parameterization have compared to directly using Xavier initialization?
Theoretically, there is a slight advantage. With NTK parameterization, all parameters can be initialized from a variance-1 distribution, which means every parameter has roughly the same order of magnitude. We can therefore set a single, relatively large learning rate; and if an adaptive optimizer is used, the update magnitude is roughly of the same order as the learning rate, so we know roughly how much each step will adjust the parameters. In short, NTK parameterization lets us treat every parameter more equally and gives a clearer sense of the training update magnitude, which makes it easier to tune the model.
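A minimal sketch of the two equivalent parameterizations of a single dense layer (my own illustration; the shapes are arbitrary):

import numpy as np

np.random.seed(0)
m, n = 512, 256
x = np.random.normal(size=(1000, m))

# Xavier: variance 1/m baked into the weights
w_xavier = np.random.normal(0, np.sqrt(1.0 / m), size=(m, n))
y1 = x @ w_xavier

# NTK parameterization: unit-variance weights, 1/sqrt(m) factor in the forward pass
w_ntk = np.random.normal(0, 1.0, size=(m, n))
y2 = x @ w_ntk / np.sqrt(m)

print((y1 ** 2).mean(), (y2 ** 2).mean())  # both ~1: same output statistics, different weight scale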
At this point, we can return to the question at the beginning of this article: why is dividing by $\sqrt{d}$ so important in Attention? For two $d$-dimensional vectors $q,k$, assuming their components are independently sampled from a “mean 0, variance 1” distribution, the second moment of their inner product is
$$\mathbb{E}\big[(q\cdot k)^2\big] = \mathbb{E}\left[\sum_{i,j} q_i k_i q_j k_j\right] = \sum_{i,j}\mathbb{E}[q_i q_j]\,\mathbb{E}[k_i k_j] = \sum_i \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = d$$
That is, the second moment of the inner product is $d$; since its mean is also 0, this means its variance is $d$. Attention is an inner product followed by softmax, so the key operation is $e^{q\cdot k}$, and we can roughly take the values after the inner product and before the softmax to lie in the range $-3\sqrt{d}$ to $3\sqrt{d}$. Since $d$ is usually at least 64, $e^{3\sqrt{d}}$ is very large and $e^{-3\sqrt{d}}$ is very small, so after softmax the Attention distribution is very close to a one-hot distribution, which causes severe gradient vanishing and poor training.
Accordingly, there are two solutions: one is to divide the inner product by $\sqrt{d}$, making its variance 1, so that the logits are neither too large nor too small and the softmax does not collapse to one-hot and kill the gradients; this is the approach of conventional Transformers such as BERT’s Self Attention. The other is not to divide by $\sqrt{d}$, but to divide the initialization variance of the fully connected layers producing $q$ and $k$ by an extra factor of $\sqrt{d}$, so that the inner product still has unit variance at initialization; this is the approach adopted by T5.
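For concreteness, a minimal NumPy sketch of scaled dot-product attention with the $1/\sqrt{d}$ factor (my own illustration: single head, no masking):

import numpy as np

def softmax(z):
    z = z - z.max(-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(-1, keepdims=True)

def attention(q, k, v):
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)      # divide by sqrt(d) so the logits have variance ~1
    return softmax(scores) @ v

np.random.seed(0)
q = np.random.normal(size=(8, 64))
k = np.random.normal(size=(8, 64))
v = np.random.normal(size=(8, 64))

print((q @ k.T).var(), ((q @ k.T) / np.sqrt(64)).var())  # ~64 vs ~1
print(attention(q, k, v).shape)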
6
『Residual Connections』
Finally, we have to discuss the design of residual connections. It is easy to show that if the variance (and likewise the second moment) of $x$ is $\sigma_1^2$, the variance of $F(x)$ is $\sigma_2^2$, and the two are assumed independent, then the variance of $x + F(x)$ is $\sigma_1^2 + \sigma_2^2$. That is, the residual further amplifies the variance, so we also need a strategy to bring the variance back down.
A relatively naive solution is to add a normalization operation directly after the residual:
$$x_{t+1} = \mathrm{Norm}\big(x_t + F_t(x_t)\big)$$
This can be referred to as the Post Norm structure, which is also the design used in the original Transformer and BERT. However, although this approach stabilizes the forward propagation variance, it has severely weakened the residual itself, thus losing the advantage of the residual being “easy to train”, often requiring warmup and setting a sufficiently small learning rate to ensure convergence.
How should we understand this? Suppose that at initialization both $x_t$ and $F_t(x_t)$ have variance 1; then $x_t + F_t(x_t)$ has variance 2, and the normalization is responsible for bringing the variance back to 1, so in the initial phase Post Norm is equivalent to
$$x_{t+1} = \frac{x_t + F_t(x_t)}{\sqrt{2}}$$
Expanding this recursively, we get
$$x_l = \frac{x_0}{2^{l/2}} + \frac{F_0(x_0)}{2^{l/2}} + \frac{F_1(x_1)}{2^{(l-1)/2}} + \frac{F_2(x_2)}{2^{(l-2)/2}} + \cdots + \frac{F_{l-1}(x_{l-1})}{2^{1/2}}$$
Do you see the problem? Originally, the purpose of the residual was to create a “green channel” for the previous layers, allowing gradients to be more directly backpropagated, but in Post Norm, this “green channel” is severely weakened, meaning that the closer it is to the front, the smaller its weight becomes, making the residual “exist in name only”, thus still difficult to train. Relevant analysis can also be found in the paper “On Layer Normalization in the Transformer Architecture” (https://arxiv.org/abs/2002.04745).
A targeted improvement is Pre Norm, whose idea is “apply Norm only where it is needed”, with the form
$$x_{t+1} = x_t + F_t\big(\mathrm{Norm}(x_t)\big)$$
Similarly, after iterative expansion, we can consider the initial phase to satisfy
$$x_l = x_0 + F_0(x_0) + F_1\!\left(\frac{x_1}{\sqrt{2}}\right) + F_2\!\left(\frac{x_2}{\sqrt{3}}\right) + \cdots + F_{l-1}\!\left(\frac{x_{l-1}}{\sqrt{l}}\right)$$
This way, at least every residual channel is equally weighted, making the effect of the residual more significant than in Post Norm, thus optimizing better. Of course, the final variance will be large, so normalization should also be added before the prediction layer.
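To see the difference numerically, here is a small sketch that simulates the initial-phase variance of Post Norm vs Pre Norm stacks (my own illustration; the random dense branch is just a stand-in for $F_t$, and the depth is arbitrary):

import numpy as np

np.random.seed(0)
d, depth = 256, 24

def branch(x):
    # stand-in for F_t: a random dense layer with variance-1/d init (output variance ~ input variance)
    w = np.random.normal(0, np.sqrt(1.0 / d), size=(d, d))
    return x @ w

def norm(x):
    return (x - x.mean(-1, keepdims=True)) / x.std(-1, keepdims=True)

x_post = x_pre = np.random.normal(size=(100, d))
for _ in range(depth):
    x_post = norm(x_post + branch(x_post))   # Post Norm: variance reset to ~1 at every layer
    x_pre = x_pre + branch(norm(x_pre))      # Pre Norm: variance keeps growing (~depth + 1)

print(x_post.var(), x_pre.var())             # ~1 vs ~depth, hence the final Norm before prediction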
In my opinion, neither Post Norm nor Pre Norm is perfect, since neither keeps the network an identity mapping in the initial phase. To me, the most elegant method is to introduce a scalar parameter $\alpha_t$ initialized to 0, so that
$$x_{t+1} = x_t + \alpha_t F_t(x_t)$$
and then gradually update $\alpha_t$. This way the model is exactly the identity in the initial phase, which avoids the variance problem altogether. This trick later appeared in two papers: in “Batch Normalization Biases Residual Blocks Towards the Identity Function in Deep Networks” it is called SkipInit, and in “ReZero is All You Need: Fast Convergence at Large Depth” it is called ReZero. The two papers came out less than a month apart, and both found that this approach can essentially replace the normalization operation in the residual branch.
Batch Normalization Biases Residual Blocks Towards the Identity Function in Deep Networks
https://arxiv.org/abs/2002.10444
ReZero is All You Need: Fast Convergence at Large Depth
https://arxiv.org/abs/2003.04887
As for how to update $\alpha_t$: both SkipInit and ReZero treat it as a model parameter, updated together with all the other parameters. I initially thought so too, but later realized that $\alpha_t$ is not on the same footing as the other parameters and should not be treated identically; for example, with the NTK parameterization introduced earlier we can use a fairly large learning rate for the other parameters, but clearly we should not use a large learning rate for $\alpha_t$. Moreover, we know that when training succeeds, both Post Norm and Pre Norm work well (corresponding to $\alpha_t = 1$), so the choice of residual mode is purely an initialization issue rather than a question of fitting capacity. Considering all this, I later simply let $\alpha_t$ increase slowly with a fixed, very small step size until it reaches a fixed value, which gave the best results in my experiments.
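A minimal sketch of this kind of block and schedule (my own illustration; the step size, depth, and the toy tanh branch are all arbitrary stand-ins, not the post’s actual settings):

import numpy as np

np.random.seed(0)
d, depth = 64, 6
alpha = np.zeros(depth)          # one scalar per residual block, initialized to 0
alpha_step = 1e-3                # fixed, very small increment per training step (hypothetical value)

def rezero_block(x, w, a):
    # x_{t+1} = x_t + alpha_t * F_t(x_t); F_t is a toy dense-plus-tanh branch here
    return x + a * np.tanh(x @ w)

ws = [np.random.normal(0, np.sqrt(1.0 / d), size=(d, d)) for _ in range(depth)]

x = np.random.normal(size=(8, d))
out = x
for w, a in zip(ws, alpha):
    out = rezero_block(out, w, a)
print(np.allclose(out, x))       # True: with alpha = 0 the whole stack is the identity

# schedule: instead of learning alpha, increase it by a fixed step each training step, capped at 1
for step in range(2000):
    alpha = np.minimum(alpha + alpha_step, 1.0)
print(alpha)                     # all 1.0 after enough steps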