Is BERT’s LayerNorm What You Think It Is?

© Author | Wang Kunze

Affiliation | The University of Sydney

Research Direction | NLP

The comparison between Batch Norm and Layer Norm has become a cliché in the field of algorithms. The question of why BERT uses layer norm instead of batch norm has been asked to death, and a casual search on Zhihu reveals many explanations of the differences between BN and LN. Generally, people will provide this diagram:

▲ BN vs LN
People will say that for CV and NLP problems, the three dimensions in this diagram represent different information (for NLP, N is the batch_size, C is the seq_length, and the remaining axis is the feature dimension).

If we only look at the NLP case, suppose our input tensor has shape (2, 3, 4), meaning batch_size = 2, seq_length = 3, dim = 4. Assume the first sentence is w1 w2 w3 and the second sentence is w4 w5 w6, where token wi has features wi1 wi2 wi3 wi4; then this tensor can be written as:
[[[w11, w12, w13, w14], [w21, w22, w23, w24], [w31, w32, w33, w34]],
 [[w41, w42, w43, w44], [w51, w52, w53, w54], [w61, w62, w63, w64]]]
We find that BN averages tokens at corresponding positions across the batch. In other words, (w11+w12+w13+w14+w41+w42+w43+w44)/8 is one of the means, and a total of 3 means are computed, corresponding to the C (seq_length) axis in the diagram above.
LN, by contrast, appears to average all features of each sample, i.e., (w11+w12+w13+w14+w21+w22+w23+w24+w31+w32+w33+w34)/12 is one mean, giving a total of 2 means, corresponding to the N (batch_size) axis in the diagram.
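To make this arithmetic concrete, here is a minimal sketch (my own toy example, not from the article's sources, assuming the diagram's mapping of C to seq_length) that computes both kinds of means on a (2, 3, 4) tensor:
import torch

# Toy tensor with shape (batch_size, seq_length, dim) = (2, 3, 4)
x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# BN-style: average over batch and dim for each sequence position,
# e.g. (w11+w12+w13+w14+w41+w42+w43+w44)/8 -> 3 means in total
bn_means = x.mean(dim=(0, 2))
print(bn_means.shape)  # torch.Size([3])

# LN as in the left diagram below: average over seq and dim for each sample,
# e.g. (w11+...+w34)/12 -> 2 means in total
ln_means = x.mean(dim=(1, 2))
print(ln_means.shape)  # torch.Size([2])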
I had always firmly believed that this per-sample computation is also what BERT implements, until one day I saw @猛猿’s answer to the Zhihu question “Why does Transformer use LayerNorm?” [1]. The answer provides two diagrams:

▲ Both are Layer Norm but different
The left diagram matches the understanding of LN above, and is what I had always assumed LN to be; the right diagram, however, averages over each individual token. Returning to our example, for a (2, 3, 4) tensor, (w11+w12+w13+w14)/4 is one mean, giving a total of 2*3 = 6 means.
So which does BERT compute: batch_size means (the left diagram’s method) or batch_size*seq_length means (the right diagram’s method)? We need to look at the source code.
The two best-known PyTorch implementations of the transformer encoder / BERT are the one built into torch and the one written by Hugging Face. Let’s examine them one by one.
# torch.nn.TransformerEncoderLayer
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py
# Line 412
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

# huggingface bert_model
# https://github.com/huggingface/transformers/blob/3223d49354e41dfa44649a9829c7b09013ad096e/src/transformers/models/bert/modeling_bert.py#L378
# Line 382
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
It can be seen that both PyTorch’s built-in transformer encoder and Hugging Face’s BERT layer use PyTorch’s own nn.LayerNorm, and the parameter they pass in is just the hidden dimension (e.g., 768), which the transformer calls d_model and BERT calls hidden_size.
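A quick sanity check of what passing a single dimension implies (my own snippet, not from either codebase, using 768 as the hidden size):
import torch

# nn.LayerNorm stores the shape it normalizes over; passing a single int means
# statistics are taken over the last dimension only, and the affine
# weight/bias have that same shape.
ln = torch.nn.LayerNorm(768)
print(ln.normalized_shape)  # (768,)
print(ln.weight.shape)      # torch.Size([768])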
Let’s see what nn.LayerNorm(dim) actually does. The following code is adapted from the Stack Overflow question “Understanding torch.nn.LayerNorm in NLP” [2]:
import torch

batch_size, seq_size, dim = 2, 3, 4
embedding = torch.randn(batch_size, seq_size, dim)

layer_norm = torch.nn.LayerNorm(dim, elementwise_affine = False)
print("y: ", layer_norm(embedding))

eps: float = 0.00001
mean = torch.mean(embedding[:, :, :], dim=(-1), keepdim=True)
var = torch.square(embedding[:, :, :] - mean).mean(dim=(-1), keepdim=True)

print("mean: ", mean.shape)
print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))
In the code above, I first generate a random embedding, then compute the result of applying nn.LayerNorm(dim) to it, and separately compute a mean over only the last dimension by hand (so my mean tensor has shape (2, 3, 1), i.e., 6 means in total). If the manual result matches the output of nn.LayerNorm(dim), that would show nn.LayerNorm(dim) computes batch_size*seq_length means, i.e., the right diagram’s method. The results are as follows:
y:  tensor([[[-0.2500,  1.0848,  0.6808, -1.5156],
         [-1.1630, -0.7052,  1.3840,  0.4843],
         [-1.3510,  0.4520, -0.4354,  1.3345]],

        [[ 0.4372, -0.4610,  1.3527, -1.3290],
         [ 0.2282,  1.3853, -0.2037, -1.4097],
         [-0.9960, -0.6184, -0.0059,  1.6203]]])
mean:  torch.Size([2, 3, 1])
y_custom:  tensor([[[-0.2500,  1.0848,  0.6808, -1.5156],
         [-1.1630, -0.7052,  1.3840,  0.4843],
         [-1.3510,  0.4520, -0.4354,  1.3345]],

        [[ 0.4372, -0.4610,  1.3527, -1.3290],
         [ 0.2282,  1.3853, -0.2037, -1.4097],
         [-0.9960, -0.6184, -0.0059,  1.6203]]])
Indeed, they are consistent: at least in torch’s built-in encoder layer and Hugging Face’s BERT, layer norm actually computes a separate mean over each token’s features.
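Rather than comparing printouts by eye, the same check can be done programmatically; the following is my own self-contained verification of the claim, not part of either codebase:
import torch

# Verify that nn.LayerNorm(dim) equals manual per-token normalization
# (mean/variance over the last dimension only).
torch.manual_seed(0)
batch_size, seq_size, dim = 2, 3, 4
embedding = torch.randn(batch_size, seq_size, dim)

layer_norm = torch.nn.LayerNorm(dim, elementwise_affine=False)

eps = 1e-5  # default eps of nn.LayerNorm
mean = embedding.mean(dim=-1, keepdim=True)
var = embedding.var(dim=-1, unbiased=False, keepdim=True)  # biased variance, as LayerNorm uses
manual = (embedding - mean) / torch.sqrt(var + eps)

print(torch.allclose(layer_norm(embedding), manual, atol=1e-6))  # True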
If we want to compute batch_size means as in the left diagram, we just need to change the argument of nn.LayerNorm to nn.LayerNorm([seq_size, dim]). The code is as follows; you can run it and confirm that it matches computing batch_size means:
import torch

batch_size, seq_size, dim = 2, 3, 4
embedding = torch.randn(batch_size, seq_size, dim)

layer_norm = torch.nn.LayerNorm([seq_size,dim], elementwise_affine = False)
print("y: ", layer_norm(embedding))

eps: float = 0.00001
mean = torch.mean(embedding[:, :, :], dim=(-2,-1), keepdim=True)
var = torch.square(embedding[:, :, :] - mean).mean(dim=(-2,-1), keepdim=True)

print("mean: ", mean.shape)
print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))
One last question: if the computation is the one shown in the right diagram, isn’t it the same as InstanceNorm? I ran another code experiment:
import torch
from torch.nn import InstanceNorm2d

instance_norm = InstanceNorm2d(3, affine=False)
x = torch.randn(2, 3, 4)
# InstanceNorm2d expects input of shape (N, C, H, W), so add a trailing dimension
output = instance_norm(x.reshape(2, 3, 4, 1))
print(output.reshape(2, 3, 4))

layer_norm = torch.nn.LayerNorm(4, elementwise_affine = False)
print(layer_norm(x))
You can run it to find that they are indeed consistent.
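For completeness, here is a programmatic version of that check (my own addition), using torch.allclose instead of comparing printouts:
import torch
from torch.nn import InstanceNorm2d

# Check that InstanceNorm2d applied per token matches nn.LayerNorm(dim).
torch.manual_seed(0)
x = torch.randn(2, 3, 4)

instance_norm = InstanceNorm2d(3, affine=False)
inst_out = instance_norm(x.reshape(2, 3, 4, 1)).reshape(2, 3, 4)

layer_norm = torch.nn.LayerNorm(4, elementwise_affine=False)
print(torch.allclose(inst_out, layer_norm(x), atol=1e-6))  # True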
Conclusion: in both torch’s built-in transformer encoder and Hugging Face’s BERT implementation, what is called layer norm normalizes each token independently, which is effectively InstanceNorm.
So what did Vaswani et al. originally intend in “Attention Is All You Need”? Vaswani is also one of the authors of tf.tensor2tensor, so I believe its code should reflect the original design. Digging through the source (there are countless files and layers of function wrappers; feel free to try it yourself), I found that the layer norm in the authors’ own code also normalizes over the last dimension only. In other words, the original authors effectively used InstanceNorm as well.
Lastly, I would like to ask: does that make InstanceNorm a special case of LayerNorm? Why have I never seen this stated anywhere?

References

[1] https://www.zhihu.com/question/487766088/answer/2309239401
[2] https://stackoverflow.com/questions/70065235/understanding-torch-nn-layernorm-in-nlp
