Reprinted from | PaperWeekly
©PaperWeekly Original · Author|Li Luoqiu
School|Zhejiang University Master’s Student
Research Direction|Natural Language Processing, Knowledge Graph
Continuing from the previous article, I will record my understanding of the HuggingFace open-source Transformers project code.
This article is based on Transformers version 4.4.2 (released on March 19, 2021) and analyzes the PyTorch implementation of the BERT-related code from the perspectives of code structure, concrete implementation and principles, and usage, covering the following content:
1. BERT Tokenization Model (BertTokenizer)
2. BERT Model (BertModel)
   2.1 BertEmbeddings
   2.2 BertEncoder
      2.2.1 BertLayer (BertAttention, BertIntermediate, BertOutput)
3. BERT-based Models
   3.1 BertForPreTraining
   3.2 BertForSequenceClassification
   3.3 BertForMultipleChoice
   3.4 BertForTokenClassification
   3.5 BertForQuestionAnswering
4. BERT Training and Optimization
   4.1 Pre-Training
   4.2 Fine-Tuning
      4.2.1 AdamW
      4.2.2 Warmup

BERT-based Models
The BERT-based models are all written in /models/bert/modeling_bert.py, including the BERT pre-training models and the BERT classification models. The UML diagram is as follows:
▲ Drawing Tool: Pyreverse
First of all, all of the following models are based on BertPreTrainedModel, which in turn is based on the larger base class PreTrainedModel. Here we focus on what BertPreTrainedModel does:
- It is used to initialize model weights, and it maintains some marker attributes inherited from PreTrainedModel as well as class variables used when loading the model.
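For intuition, here is a paraphrased sketch of the weight-initialization logic that BertPreTrainedModel provides (the standalone function name init_bert_weights is illustrative; in the library this is the _init_weights method of the class):

```python
import torch.nn as nn

def init_bert_weights(module: nn.Module, initializer_range: float = 0.02) -> None:
    """Initialize Linear/Embedding weights from N(0, initializer_range^2); LayerNorm with ones/zeros."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
        # the std comes from config.initializer_range (0.02 for BERT)
        module.weight.data.normal_(mean=0.0, std=initializer_range)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
```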
3.1 BertForPreTraining
As is well known, BERT has two pre-training tasks:
- Masked Language Model (MLM): Randomly replace some words in the sentence with [MASK], pass the sentence into BERT to encode every word, and finally use the encoding at the [MASK] positions to predict the correct word at those positions. This task trains the model to understand the meaning of a word from its context;
- Next Sentence Prediction (NSP): Feed a sentence pair A and B into BERT and use the encoding of [CLS] to predict whether B is the sentence that follows A. This task trains the model to understand the relationship between sentences.
In the code, the model that integrates these two tasks is BertForPreTraining, which contains two components: BertModel and BertPreTrainingHeads.
BertModel has been covered in detail in the previous article (note the default setting add_pooling_layer=True, which extracts the output corresponding to [CLS] for the NSP task), while BertPreTrainingHeads is responsible for the predictions of the two tasks: it wraps a BertLMPredictionHead together with a linear layer for the NSP task. The reason the NSP task is not wrapped in yet another BertXXXPredictionHead is probably that it is too simple to need one…
Supplement: there actually is a class that wraps it, called BertOnlyNSPHead, but it is not used here…
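As a concrete reference, here is an abridged sketch of this composition, paraphrased from the description above (the class name carries a Sketch suffix to mark it as an illustration; the real BertForPreTraining in modeling_bert.py additionally computes the losses and returns a ModelOutput):

```python
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainingHeads

class BertForPreTrainingSketch(nn.Module):
    """Illustrative composition only."""
    def __init__(self, config):
        super().__init__()
        self.bert = BertModel(config)            # add_pooling_layer=True by default: pooled [CLS] output for NSP
        self.cls = BertPreTrainingHeads(config)  # MLM prediction head + NSP linear layer

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output, pooled_output = outputs[0], outputs[1]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
        # prediction_scores: [batch_size, seq_length, vocab_size]; seq_relationship_score: [batch_size, 2]
        return prediction_scores, seq_relationship_score
```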
BertLMPredictionHead: this class takes the hidden states (including those at the [MASK] positions) and produces, for every position, a classification output over the whole vocabulary. Note:
- This class reinitializes a zero vector as the bias of the prediction weights;
- Its output shape is [batch_size, seq_length, vocab_size], i.e., a score for every vocabulary word at every position of every sentence (note that no softmax is applied here);
- It wraps another class, BertPredictionHeadTransform, which performs a linear transformation:
Supplement: I feel this layer could be removed, since the output shape does not change? My personal understanding is that it mirrors the pooling layer, passing through a dense layer before connecting to the classifier…
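For clarity, here is a paraphrased sketch of these two heads (class names carry a Sketch suffix; plain constructor arguments replace the config object for brevity):

```python
import torch
import torch.nn as nn

class BertPredictionHeadTransformSketch(nn.Module):
    """dense -> activation -> LayerNorm; the hidden size, and hence the output shape, is unchanged."""
    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.act = nn.GELU()  # BERT uses config.hidden_act = "gelu"
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=eps)

    def forward(self, hidden_states):
        return self.LayerNorm(self.act(self.dense(hidden_states)))

class BertLMPredictionHeadSketch(nn.Module):
    """Maps [batch, seq_len, hidden_size] to [batch, seq_len, vocab_size]; no softmax is applied."""
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.transform = BertPredictionHeadTransformSketch(hidden_size)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))  # the reinitialized zero bias mentioned above
        self.decoder.bias = self.bias                      # attach it to the decoder

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
```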
Back to BertForPreTraining: let's see how the two losses are handled. Its forward propagation differs from that of BertModel in that it additionally takes labels and next_sentence_label as input:
- labels: of shape [batch_size, seq_length], the labels for the MLM task. Note that the positions that are not masked are set to -100 and only the masked positions carry the id of the original word, which is the reverse of the task setup (in the input, it is exactly the masked positions that are unknown);
  - For example, if the input sentence is I want to [MASK] an apple, where the word eat has been masked before feeding the model, then the corresponding label is [-100, -100, -100, (id of eat), -100, -100];
  - Why -100 and not some other value? Because torch.nn.CrossEntropyLoss defaults to ignore_index=-100, so positions whose label is -100 do not contribute to the loss (see the sketch after this list);
- next_sentence_label: this input is simple, just the binary classification labels 0 and 1.
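A small self-contained sketch of how the MLM loss ignores the -100 positions (the vocabulary id 3940 used for eat is hypothetical, and the sequence includes [CLS]/[SEP]):

```python
import torch
import torch.nn as nn

vocab_size = 30522
batch_size, seq_len = 1, 8
prediction_scores = torch.randn(batch_size, seq_len, vocab_size)  # stand-in for the MLM head output

# label ids: -100 everywhere except the [MASK] position, which holds the id of "eat"
labels = torch.full((batch_size, seq_len), -100)
labels[0, 4] = 3940  # hypothetical id for "eat"

loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100
masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size), labels.view(-1))
print(masked_lm_loss)  # only the masked position contributes to the loss
```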
In addition to BertForPreTraining, which combines both tasks, there are models that pre-train on only one of them:
- BertForMaskedLM: pre-training on the MLM task only;
  - based on BertOnlyMLMHead, which is yet another wrapper around BertLMPredictionHead;
- BertLMHeadModel: differs from the previous one in that this model runs as a decoder;
  - also based on BertOnlyMLMHead;
- BertForNextSentencePrediction: pre-training on the NSP task only;
  - based on BertOnlyNSPHead, which is just a linear layer…
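A minimal usage sketch of the MLM-only model (assuming the bert-base-uncased checkpoint is available; the predicted word depends on the checkpoint):

```python
import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

inputs = tokenizer("I want to [MASK] an apple.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # [batch_size, seq_length, vocab_size]

# locate the [MASK] position and take the highest-scoring vocabulary id there
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_id = logits[0, mask_pos].argmax(dim=-1)
print(tokenizer.decode(predicted_id))  # likely a verb such as "eat"
```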

3.2 BertForSequenceClassification
- The input for sentence classification is a sentence (pair), and the output is a single classification label.
The structure is a BertModel (with pooling) followed by dropout and a linear layer that outputs the classification. In forward propagation, similar to the pre-training model above, labels must be passed as input:
- If initialized with num_labels=1, the task is treated as regression by default and uses MSELoss;
- Otherwise, it is treated as a classification task.
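A minimal usage sketch for sentence classification (the 3-label setup and the sentences are made up for illustration; the classification head is freshly initialized, so the loss is only meaningful after fine-tuning):

```python
import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)

batch = tokenizer(["a great movie", "a terrible movie"], padding=True, return_tensors="pt")
labels = torch.tensor([2, 0])              # one label per sentence
outputs = model(**batch, labels=labels)
print(outputs.loss, outputs.logits.shape)  # CrossEntropyLoss value, torch.Size([2, 3])
```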
3.3 BertForMultipleChoice
- The input for a multiple-choice task is a group of sentences per example, entered in batches, and the output is a single label selecting one of the sentences.
The structure is similar to sentence classification, except that the linear layer has output dimension 1, i.e., the outputs for each example's multiple candidate sentences are concatenated to form that example's prediction scores (see the sketch after this list);
- In practice, each batch contains several sentences per example, so the model processes [batch_size, num_choices] sentences at a time, which requires more GPU memory than a sentence classification task, so caution is needed during training.
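A shape-oriented sketch of the multiple-choice input described above (the prompt and choices are made up; the checkpoint name is assumed):

```python
import torch
from transformers import BertTokenizer, BertForMultipleChoice

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMultipleChoice.from_pretrained("bert-base-uncased")

prompt = "I want to eat"
choices = ["an apple.", "a stone.", "a cloud."]
enc = tokenizer([prompt] * len(choices), choices, padding=True, return_tensors="pt")

# input_ids: [num_choices, seq_length] -> add a batch dimension: [batch_size, num_choices, seq_length]
inputs = {k: v.unsqueeze(0) for k, v in enc.items()}
labels = torch.tensor([0])   # index of the correct choice for this example

outputs = model(**inputs, labels=labels)
print(outputs.logits.shape)  # torch.Size([1, 3]): one score per choice
```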
3.4 BertForTokenClassification
- The input for a sequence labeling task is a single sentence of text, and the output is a category label for each token.
The contained BertModel does not include the pooling layer (add_pooling_layer=False);
- At the same time, the class variable _keys_to_ignore_on_load_unexpected is set to [r"pooler"], so that no error is raised for the unneeded pooler weights when loading the model.
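A minimal sketch for token classification (the 5-label tagging scheme is hypothetical; the token-classification head is freshly initialized):

```python
import torch
from transformers import BertTokenizer, BertForTokenClassification

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForTokenClassification.from_pretrained("bert-base-uncased", num_labels=5)

inputs = tokenizer("John lives in Hangzhou", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # [batch_size, seq_length, num_labels]
print(logits.argmax(dim=-1))         # one predicted label id per token (including [CLS]/[SEP])
```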
3.5 BertForQuestionAnswering
This model is used for question-answering tasks, such as SQuAD tasks.
- The input for a question-answering task is a sentence pair made of the question and the passage that contains the answer (for BERT, a single passage), and the output is a start position and an end position that locate the answer span in the text.
This requires two outputs, one predicting the start position and one the end position; both have the same length as the sentence, and the index of the maximum predicted value is taken as the predicted position;
- Labels that are illegal because they exceed the sentence length are clamped (torch.clamp_) to a reasonable range.
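A minimal extractive-QA sketch (the checkpoint name is assumed, and an un-fine-tuned base model will give arbitrary spans; this only illustrates the input/output format):

```python
import torch
from transformers import BertTokenizer, BertForQuestionAnswering

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")

question = "Where does John live?"
context = "John lives in Hangzhou and works at a university."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    out = model(**inputs)
start = out.start_logits.argmax(dim=-1).item()  # predicted start index
end = out.end_logits.argmax(dim=-1).item()      # predicted end index
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))
```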
As a late supplement, let me briefly introduce the ModelOutput class. It serves as the base class for the outputs of the models above, supports both dictionary-style access and index-based access, and inherits from Python's native OrderedDict class.
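A small illustration of the two access styles, using BertModel's output as an example (checkpoint name assumed):

```python
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

out = model(**tokenizer("hello world", return_tensors="pt"))
# attribute access, dictionary-style access and index access all return the same tensor
assert out.last_hidden_state is out["last_hidden_state"] is out[0]
# dictionary-style iteration also works, since ModelOutput inherits from OrderedDict
print(list(out.keys()))  # e.g. ['last_hidden_state', 'pooler_output']
```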
BERT Training and Optimization
4.1 Pre-Training
During the pre-training phase, in addition to the well-known masking ratios (15% of tokens are selected, of which 80% are replaced by [MASK]), one noteworthy point is parameter sharing: the decoder weights of the MLM prediction head are shared with the word-embedding weights.
As for why, it should be because the word_embedding and prediction weight matrices are very large; in bert-base, for example, each has size (30522, 768), and sharing them reduces the training difficulty.
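A quick way to verify this sharing on a loaded checkpoint (assuming bert-base-uncased):

```python
from transformers import BertForPreTraining

model = BertForPreTraining.from_pretrained("bert-base-uncased")
emb = model.bert.embeddings.word_embeddings.weight
dec = model.cls.predictions.decoder.weight

print(emb.shape)                         # torch.Size([30522, 768])
print(emb.data_ptr() == dec.data_ptr())  # True: the two layers share the same storage
```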
4.2 Fine-Tuning
4.2.1 AdamW
First, let’s introduce BERT’s optimizer: AdamW (Adam Weight Decay Optimizer).
This optimizer comes from the paper "Fixing Weight Decay Regularization in Adam" (published at ICLR 2019 under the title "Decoupled Weight Decay Regularization"), which proposes a method to fix the incorrect weight decay in Adam. The paper points out that L2 regularization and weight decay are not equivalent in most cases; they are equivalent only under SGD. Most frameworks, however, implement Adam's weight decay as Adam + L2 regularization, and the two should not be confused.
For analyses of AdamW, you can refer to:
- AdamW and Super-convergence is now the fastest way to train neural nets [1]
- paperplanet: It's 9102 already, stop using Adam + L2 regularization [2]
- What highlights are worth paying attention to at ICLR 2018? [3]
By the way, this paper “STABLE WEIGHT DECAY REGULARIZATION” seems to complain that AdamW’s Weight Decay implementation still has issues… Need to sort out optimizer-related content when time allows.
Supplement: as for why bias and LayerNorm weights are usually excluded from weight decay in BERT fine-tuning (see the sketch below), I have not found a convincing answer, but here is a related discussion: https://forums.fast.ai/t/is-weight-decay-applied-to-the-bias-term/73212/4
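The parameter grouping this refers to typically looks like the following sketch (the checkpoint name and hyperparameters are illustrative):

```python
from transformers import AdamW, BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# bias and LayerNorm weights are conventionally kept out of weight decay
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)
```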
4.2.2 Warmup
Another characteristic of BERT training is warmup, which means:
- In the early stage of training, use a small learning rate (starting from 0) and raise it to the normal value (e.g., the 2e-5 above) within a certain number of steps (e.g., 1,000 steps), so that the model does not fall into a bad local optimum or overfit too early;
- In the late stage of training, gradually decay the learning rate to 0, to avoid large parameter updates at the end of training.
In Huggingface’s implementation, various warmup strategies can be used:
Specifically:
- CONSTANT: keep the learning rate fixed;
- CONSTANT_WITH_WARMUP: raise the learning rate linearly during the warmup steps, then keep it fixed;
- LINEAR: the two-stage schedule described above;
- COSINE: similar to the two-stage schedule, but the decay follows a cosine curve;
- COSINE_WITH_RESTARTS: repeat the COSINE schedule n times during training;
- POLYNOMIAL: decay in two stages following a polynomial curve.
These schedules are defined in transformers/optimization.py:
- The most commonly used is get_linear_schedule_with_warmup, i.e., the linear two-stage learning-rate schedule described above (a usage sketch follows)…
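A usage sketch combining the AdamW optimizer built above with the linear warmup schedule (the step counts are illustrative):

```python
from transformers import get_linear_schedule_with_warmup

# continues from the AdamW `optimizer` built in the earlier sketch
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1000,     # learning rate rises linearly from 0 to 2e-5 over the first 1,000 steps
    num_training_steps=10000,  # then decays linearly back to 0 by step 10,000
)

# inside the training loop, step the scheduler right after the optimizer
optimizer.step()
scheduler.step()
optimizer.zero_grad()
```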
The above is an analysis of the concrete implementation of the BERT-related code in the Transformers library (version 4.4.2). Readers are welcome to discuss and exchange ideas.
References