Detailed Explanation of HuggingFace BERT Source Code



Reprinted from | PaperWeekly

©PaperWeekly Original · Author | Li Luoqiu

School | Master’s Student at Zhejiang University

Research Direction | Natural Language Processing, Knowledge Graphs

This article records my understanding of the code in the HuggingFace open-source Transformers project.
As we all know, the BERT model has dominated various benchmarks since its inception in 2018, initiating the paradigm of pre-training + fine-tuning in the NLP field. Now, numerous derivative models of BERT (XL-Net, RoBERTa, ALBERT, ELECTRA, ERNIE, etc.) have emerged, and understanding them can start with the original BERT model.


HuggingFace is a chatbot startup based in New York that recognized the potential of BERT early on and began implementing a PyTorch version of the model. The project was initially named pytorch-pretrained-bert; besides reproducing the original results, it provided easy-to-use interfaces that made it convenient to run all kinds of experiments and research on top of this powerful model.
As the number of users increased, this project evolved into a larger open-source community, merging various pre-trained language models and adding TensorFlow implementations, and in the second half of 2019, it was renamed Transformers. As of the time of writing this article (March 30, 2021), this project has over 43k stars, making Transformers a de facto basic tool for NLP.
See more: transformers [1]

This article analyzes the code related to BERT in the PyTorch version of the Transformers project version 4.4.2 (released on March 19, 2021) from the perspectives of code structure, specific implementation and principles, as well as usage, including the following content:

1. BERT Tokenization Model (BertTokenizer)

2. BERT Model (BertModel)

2.1 BertEmbeddings

2.2 BertEncoder
2.2.1 BertLayer
2.2.1.1 BertAttention
2.2.1.1.1 BertSelfAttention
2.2.1.1.2 BertSelfOutput
2.2.1.2 BertIntermediate
2.2.1.3 BertOutput
2.2.3 BertPooler
3. BERT-based Models (See next article)
3.1 BertForPreTraining
3.2 BertForSequenceClassification
3.3 BertForMultipleChoice
3.4 BertForTokenClassification
3.5 BertForQuestionAnswering
4. BERT Training and Optimization (See next article)
4.1 Pre-Training
4.2 Fine-Tuning
4.2.1 AdamW
4.2.2 Warmup


Tokenization (BertTokenizer)
The BERT-related tokenizer code lives mainly in /models/bert/tokenization_bert.py and /models/bert/tokenization_bert_fast.py.

These two files correspond to the basic BertTokenizer and to BertTokenizerFast, the faster implementation backed by HuggingFace's Rust tokenizers library; this article focuses mainly on the former.

class BertTokenizer(PreTrainedTokenizer):
    """
    Construct a BERT tokenizer. Based on WordPiece.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.
    ...
    """
BertTokenizer is a tokenizer based on BasicTokenizer and WordPieceTokenizer:
  • BasicTokenizer is responsible for the first step: splitting sentences by punctuation, spaces, etc., optionally lowercasing, and cleaning up illegal characters.
    • For Chinese characters, it splits by individual characters during preprocessing (by inserting spaces);

    • Certain words can be protected from splitting via never_split;

    • This step is optional (executed by default).
  • WordPieceTokenizer further decomposes words into subwords.
    • Subwords sit between characters and words: they preserve part of a word's meaning while addressing the vocabulary explosion and OOV (Out-Of-Vocabulary) problems caused by English plurals and tenses; by separating roots from affixes, the vocabulary shrinks and training becomes easier;

    • For example, the word "tokenizer" can be broken down into "token" and "##izer", where "##" indicates that this piece continues the previous one.

BertTokenizer has the following commonly used methods:

  • from_pretrained: Initializes a tokenizer from a directory containing a vocabulary file (vocab.txt);
  • tokenize: Decomposes text (a word or sentence) into a list of subwords;
  • convert_tokens_to_ids: Converts a list of subwords into a list of the corresponding vocabulary indices;
  • convert_ids_to_tokens: The inverse of the previous method;
  • convert_tokens_to_string: Joins a subword list back into a word or sentence according to the "##" markers;
  • encode: For a single-sentence input, splits it into subwords, adds special tokens to form the structure "[CLS], x, [SEP]", and converts it into a list of vocabulary indices; for a two-sentence input (only the first two sentences are used if more are given), forms "[CLS], x1, [SEP], x2, [SEP]" and converts it into a list of indices;
  • decode: Converts the output of the encode method back into a complete sentence.

In addition, an instance of the class is itself callable, performing tokenization and index conversion in one step:

>>> from transformers import BertTokenizer
>>> bt = BertTokenizer.from_pretrained('./bert-base-uncased/')
>>> bt('I like natural language progressing!')
{'input_ids': [101, 1045, 2066, 3019, 2653, 27673, 999, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
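
To make the methods listed above concrete, here is a small illustrative sketch. It is a hedged example: it assumes the same local './bert-base-uncased/' vocabulary directory as above, and the exact subword outputs depend on that vocabulary.

from transformers import BertTokenizer

bt = BertTokenizer.from_pretrained('./bert-base-uncased/')

tokens = bt.tokenize('I like natural language processing!')   # split into subwords, some prefixed with "##"
ids = bt.convert_tokens_to_ids(tokens)                        # subwords -> vocabulary indices
assert bt.convert_ids_to_tokens(ids) == tokens                # the inverse mapping
print(bt.convert_tokens_to_string(tokens))                    # joins "##" pieces back into words

pair_ids = bt.encode('How are you?', 'I am fine.')            # "[CLS] x1 [SEP] x2 [SEP]" as indices
print(bt.decode(pair_ids))                                    # back to a readable string (special tokens included)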

Model (BertModel)

The code related to the BERT model is mainly written in /models/bert/modeling_bert.py, which contains over a thousand lines, including the basic structure of the BERT model and fine-tuning models based on it.
Next, let’s analyze the BERT model itself:
class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
    set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """ 
BertModel mainly comprises the transformer encoder structure, which includes three parts:
  1. embeddings, an instance of the BertEmbeddings class, corresponding to the word embeddings;
  2. encoder, an instance of the BertEncoder class;
  3. pooler, an instance of the BertPooler class; this part is optional.

Note: BertModel can also be configured as a Decoder, but this part is not discussed here.
Next, we will introduce the meanings of various parameters and return values in the forward propagation process of BertModel:
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ): ...
  • input_ids: The list of vocabulary indices corresponding to the subwords after tokenization;
  • attention_mask: Used during self-attention to distinguish real subwords from padding; padding positions are set to 0;
  • token_type_ids: Marks which sentence the current subword belongs to (first sentence / second sentence / padding);
  • position_ids: Marks the position index of the current word in the sentence;
  • head_mask: Used to disable the attention computation of certain heads in certain layers;
  • inputs_embeds: If provided, input_ids is not needed; these embeddings enter the encoder computation directly;
  • encoder_hidden_states: Only effective when BertModel is configured as a decoder, in which case cross-attention is performed instead of self-attention;
  • encoder_attention_mask: Similar to the above, used to mark the padding of the encoder input during cross-attention;
  • past_key_values: This parameter seems to pass in pre-computed key/value tensors to reduce redundant computation during decoding (since this part would otherwise be recomputed each step);
  • use_cache: Whether to save the key/value states above and return them, to speed up decoding;
  • output_attentions: Whether to return the attention output of every intermediate layer;
  • output_hidden_states: Whether to return the output of every intermediate layer;
  • return_dict: Whether to return the output as key-value pairs (a ModelOutput instance, which can also be treated as a tuple); defaults to True.

Note: Disabling heads via head_mask is different from the attention-head pruning discussed later; it simply multiplies the output of the affected heads' attention computations by this coefficient.
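
As a quick usage sketch of the forward signature above (hedged: it assumes a local 'bert-base-uncased' checkpoint is available and only exercises a few of the arguments):

import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

batch = tokenizer(['I like natural language processing!', 'A shorter sentence.'],
                  padding=True, return_tensors='pt')   # builds input_ids / token_type_ids / attention_mask

with torch.no_grad():
    outputs = model(**batch, output_hidden_states=True, return_dict=True)

print(outputs.last_hidden_state.shape)   # (batch_size, sequence_length, hidden_size)
print(outputs.pooler_output.shape)       # (batch_size, hidden_size)
print(len(outputs.hidden_states))        # embedding output + one tensor per layer (13 for bert-base)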

The return part is as follows:

        # BertModel's forward propagation return part
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
It can be seen that the return value includes not only the outputs of the encoder and pooler but also the other requested outputs (hidden_states, attentions, etc., contained in encoder_outputs[1:]), which is convenient to use:
        # BertEncoder's forward propagation return part, i.e., the above encoder_outputs
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

In addition, BertModel also has the following methods, allowing BERT users to perform various operations:

  1. get_input_embeddings: Extracts the word_embeddings from the embedding layer;
  2. set_input_embeddings: Assigns new values to the word_embeddings of the embedding layer;
  3. _prune_heads: Provides a function to prune attention heads; its input is a dictionary of the form {layer_num: list of heads to prune in this layer}, allowing certain attention heads of specified layers to be pruned.

Note: Pruning is a fairly involved operation. It requires copying the weights of the retained attention heads' Wq, Wk, Wv and of the fully connected layer after concatenation into new, smaller weight matrices (gradients must be disabled before copying), and keeping track of the already-pruned heads so that indices do not go wrong. For details, see the prune_heads method of the BertAttention part.
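
For illustration, pruning is normally triggered through the public prune_heads method inherited from PreTrainedModel, which dispatches to _prune_heads. A hedged sketch (assuming a local 'bert-base-uncased' checkpoint):

from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
# prune heads 0 and 1 in layer 0, and head 11 in layer 5
model.prune_heads({0: [0, 1], 5: [11]})
# the query/key/value projections of layer 0 now have 10 * 64 = 640 output features
print(model.encoder.layer[0].attention.self.query.out_features)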
2.1 BertEmbeddings
BertEmbeddings is the sum of three parts:


  1. word_embeddings, the embeddings corresponding to the subwords mentioned above.
  2. token_type_embeddings, used to represent the sentence where the current word is located, assisting in distinguishing between sentences and padding, and differences between sentence pairs.
  3. position_embeddings, the position embeddings for each word in the sentence, used to distinguish the order of words. Unlike the design in the transformer paper, this part is trained rather than calculated through a sinusoidal function to obtain fixed embeddings. It is generally believed that this implementation is not conducive to scalability (difficult to transfer directly to longer sentences).

The three embeddings are summed directly (without weighting) and passed through a LayerNorm + dropout layer before being output, with shape (batch_size, sequence_length, hidden_size).

Note: Why use LayerNorm + Dropout here, and why LayerNorm rather than BatchNorm? A good answer for reference: Why does the transformer use layer normalization instead of other normalization methods? [2]
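
A minimal sketch of the computation described above (a toy re-implementation for illustration only, not the library class; sizes follow bert-base and the names are chosen for clarity):

import torch
import torch.nn as nn

class ToyBertEmbeddings(nn.Module):
    def __init__(self, vocab_size=30522, hidden_size=768, max_position=512, type_vocab_size=2):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position, hidden_size)      # learned, not sinusoidal
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, token_type_ids):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)
        # unweighted sum of the three embeddings, then LayerNorm + dropout
        embeddings = (self.word_embeddings(input_ids)
                      + self.position_embeddings(position_ids)
                      + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.LayerNorm(embeddings))      # (batch_size, seq_len, hidden_size)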
2.2 BertEncoder

It consists of multiple stacked layers of BertLayer. There is nothing particularly tricky here, but one detail is worth noting:

Utilizes gradient checkpointing technology to reduce GPU memory usage during training.

Note: Gradient checkpointing saves fewer computation-graph nodes to reduce memory usage, at the cost of recomputing the unsaved values during the backward pass; see the paper "Training Deep Nets with Sublinear Memory Cost" (the original post illustrates the process with a figure here).
In BertEncoder, gradient checkpointing is implemented through torch.utils.checkpoint.checkpoint, which is fairly easy to use; see the documentation: torch.utils.checkpoint – PyTorch 1.8.1 documentation [3]

This mechanism’s specific implementation is complex (not fully understood), and will not be elaborated here.
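
Independently of BERT, the basic use of torch.utils.checkpoint looks roughly like this (a toy sketch, not the BertEncoder code itself):

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(768, 768)                    # stand-in for a BertLayer
hidden_states = torch.randn(2, 16, 768, requires_grad=True)

# The activations inside `layer` are not stored; they are recomputed during backward.
output = checkpoint(layer, hidden_states)
output.sum().backward()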

Deeper down, we enter a specific layer of the Encoder:
2.2.1 BertLayer

This layer wraps BertAttention and BertIntermediate + BertOutput (the attention part and the FFN part), as well as the cross-attention part that is ignored here (involving BERT as a Decoder).

Theoretically, calling the three sub-modules in order is sufficient, and there is nothing particularly noteworthy here.

However, another detail appears here:
        # This is part of forward
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # ... (cross-attention and decoder cache handling omitted)

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs
        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
Did you notice apply_chunking_to_forward and feed_forward_chunk there? (Why make it so complicated? Wouldn't calling the feed-forward part directly be simpler?)
So what is apply_chunking_to_forward? Let's take a closer look:
def apply_chunking_to_forward(
    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
) -> torch.Tensor:
    """
    This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
    dimension :obj:`chunk_dim`. It then applies a layer :obj:`forward_fn` to each chunk independently to save memory.

    If the :obj:`forward_fn` is independent across the :obj:`chunk_dim` this function will yield the same result as
    directly applying :obj:`forward_fn` to :obj:`input_tensors`.
    ...
    """
It turns out to be a memory-saving technique: it wraps a forward function and applies it to the input in smaller chunks. Here, chunk_size is the size of each chunk and chunk_dim is the dimension along which the input is split for each computation; the chunk outputs are finally concatenated and returned.

However, in the default configuration these two values are not specially set (they default to 0 and 1 respectively in the source code), so this degenerates into the ordinary forward pass.
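
A hedged illustration of this behavior (import path as in transformers 4.4.2; a position-wise toy FFN stands in for feed_forward_chunk):

import torch
from transformers.modeling_utils import apply_chunking_to_forward

dense = torch.nn.Linear(768, 3072)

def feed_forward_chunk(hidden_states):
    # position-wise, so chunking over the sequence dimension changes nothing but peak memory
    return torch.nn.functional.gelu(dense(hidden_states))

hidden_states = torch.randn(2, 128, 768)
out_direct = feed_forward_chunk(hidden_states)
out_chunked = apply_chunking_to_forward(feed_forward_chunk, 32, 1, hidden_states)  # chunk_size=32, chunk_dim=1
print(torch.allclose(out_direct, out_chunked, atol=1e-6))    # True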

Continuing deeper, we reach the core of the Transformer: the BertAttention part, followed by the FFN part.

2.2.1.1 BertAttention

I thought the attention implementation was here, but it turns out to be another layer down… Here, the self member is the implementation of multi-head attention, while the output member implements a series of operations including fully connected + dropout + residual + LayerNorm after attention.

class BertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

First, let’s return to this layer. Here, the pruning operation mentioned earlier appears, namely the <span>prune_heads</span> method:

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)
The specific implementation of this pruning can be summarized as follows:
  • find_pruneable_heads_and_indices locates the heads to be pruned and the indices of the dimensions to be retained;

  • prune_linear_layer copies the retained (un-pruned) dimensions of the Wq/Wk/Wv weight matrices (together with their biases) into new, smaller Linear layers.
Next comes the highlight—the specific implementation of Self-Attention.

2.2.1.1.1 BertSelfAttention

Warning: This part can be considered the core area of the model and is the only place involving formulas, so a lot of code will be provided.

Initialization part:

class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder
  • In addition to the familiar query, key, value weights and a dropout, there is also a mysterious position_embedding_type, and a decoder flag (of course, I do not intend to introduce the cross-attention part);
  • Note that hidden_size and all_head_size are equal at first. The seemingly redundant extra variable exists because of the pruning function above: after some attention heads are pruned, all_head_size naturally becomes smaller;

  • hidden_size must be a multiple of num_attention_heads; for example, in bert-base each layer has 12 attention heads and hidden_size is 768, so the size of each head is attention_head_size = 768/12 = 64;

  • What is position_embedding_type? Keep reading to find out…

Next is the key point, which is the forward propagation process.

First, let’s review the basic formulas for multi-head self-attention:
$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^O$$

$$\mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\; KW_i^K,\; VW_i^V), \qquad \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $h$ denotes the number of attention heads, $\mathrm{Concat}$ denotes vector concatenation, and $W_i^Q$, $W_i^K$, $W_i^V$ and $W^O$ are learned projection matrices.

These attention heads, as we all know, are computed in parallel, so the query, key and value weights above are each a single matrix: it is not that all heads share weights; rather, the per-head weights are "concatenated" into one large projection.

Note: The reason for multi-head in the original paper is that Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this. Another reliable analysis is: Why does the Transformer need to perform Multi-head Attention?[4]

Let’s look at the forward method:

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is cross-attention, key/value come from encoder_hidden_states;
        # the past_key_value cache branches are omitted here.
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # ...

Here, transpose_for_scores reshapes the hidden_size dimension into (num_attention_heads, attention_head_size) and swaps the middle two dimensions, so that the subsequent matrix multiplications are performed per head;

Here, the shapes of key_layer / value_layer / query_layer are (batch_size, num_attention_heads, sequence_length, attention_head_size);

The shape of attention_scores is (batch_size, num_attention_heads, sequence_length, sequence_length), which conforms to the shape of an attention map computed independently by each head.
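
A quick shape check of this reshaping (toy tensors with bert-base sizes, 12 heads of size 64; not the library code itself):

import torch

batch_size, seq_len, num_heads, head_size = 2, 16, 12, 64
hidden_states = torch.randn(batch_size, seq_len, num_heads * head_size)

def transpose_for_scores(x):
    # (batch, seq, hidden) -> (batch, heads, seq, head_size)
    return x.view(batch_size, seq_len, num_heads, head_size).permute(0, 2, 1, 3)

query_layer = transpose_for_scores(hidden_states)
key_layer = transpose_for_scores(hidden_states)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
print(query_layer.shape)        # torch.Size([2, 12, 16, 64])
print(attention_scores.shape)   # torch.Size([2, 12, 16, 16])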

At this point, the multiplication of K and Q to obtain the raw attention scores has been implemented. According to the formula, the next step should be dividing by $\sqrt{d_k}$ and applying softmax. However:

What appears first is a strange positional_embedding, along with a bunch of Einstein summations:
        # ...
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
        # ...
Note: For the Einstein summation convention, refer to the following documentation: torch.einsum – PyTorch 1.8.1 documentation [5]

Note: The positional_embedding here introduces position embeddings into the attention map. Why is this done? I currently don't understand…

For the different values of position_embedding_type, there are three behaviors:

  • absolute: The default; nothing extra is done here;
  • relative_key: The relative positional embedding acts as an extra key; it is multiplied with query_layer (the first einsum above) and the result is added to the attention scores;
  • relative_key_query: The positional embedding is combined with both the query and the key, and both resulting score terms are added to the attention scores.
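
As a small shape check of the einsum above ("bhld,lrd->bhlr" contracts the head dimension d between the query and a table of relative-position embeddings; toy sizes, not library code):

import torch

b, h, l, d = 2, 12, 16, 64
query_layer = torch.randn(b, h, l, d)
positional_embedding = torch.randn(l, l, d)    # one d-dim embedding per (query position, key position) pair
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
print(relative_position_scores.shape)          # torch.Size([2, 12, 16, 16]): same shape as the attention map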

For now, let’s skip this confusing part and return to the normal attention flow:

        # ...
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask  # Why is this + instead of *?

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        # ...
        return outputs
Major Question: What is the purpose of attention_scores = attention_scores + attention_mask? Shouldn’t it be multiplied by the mask?

Because the attention_mask here has already been "preprocessed": positions that were originally 1 become 0, while positions that were originally 0 (i.e., padding) become a large negative number, so adding it pushes the scores at padding positions to large negative values:

Why a "large negative number"? Because after the softmax operation, these entries become tiny values close to 0.
(Pdb) attention_mask
tensor([[[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],
        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],
        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],
        ...,
        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],
        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]],
        [[[    -0.,     -0.,     -0.,  ..., -10000., -10000., -10000.]]]],
       device='cuda:0')
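
A tiny numeric check of this trick (a toy example, not library code): adding -10000 to the masked positions before softmax drives their attention weights to essentially zero.

import torch

scores = torch.tensor([[2.0, 1.0, 0.5, 0.3]])                    # raw scores for one query
additive_mask = torch.tensor([[0.0, 0.0, -10000.0, -10000.0]])   # last two positions are padding
probs = torch.softmax(scores + additive_mask, dim=-1)
print(probs)   # roughly [0.73, 0.27, 0.00, 0.00]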
So, where is this done?
I couldn't find the answer in modeling_bert.py, but in modeling_utils.py I found a special class, ModuleUtilsMixin, whose get_extended_attention_mask method reveals the clue:
    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
        """
        # ...

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
So when is this function called, and how is it related to BertModel?

OK, this involves the inheritance details of BertModel: BertModel inherits from BertPreTrainedModel, which inherits from PreTrainedModel, and PreTrainedModel in turn inherits from three base classes: nn.Module, ModuleUtilsMixin and GenerationMixin. A complex set of wrappers!

This means that somewhere in its forward pass, BertModel must call get_extended_attention_mask on the original attention_mask, turning its values from {1, 0} into {0, -1e4}.
Finally, I found this call in the forward propagation process of BertModel (line 944):
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
This method not only changes the value of the mask but also broadcasts it to a shape that can be directly added to the attention map.
HuggingFace, you never disappoint.
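
The transformation can also be observed directly, since get_extended_attention_mask is an ordinary (inherited) method on the model. A hedged sketch, assuming a local 'bert-base-uncased' checkpoint:

import torch
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])
extended = model.get_extended_attention_mask(attention_mask, attention_mask.shape, attention_mask.device)
print(extended.shape)   # torch.Size([1, 1, 1, 5]): broadcastable over heads and query positions
print(extended)         # 0.0 for real tokens, -10000.0 for padding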


In addition, noteworthy details include:

  • Scaling is performed by each head's dimension: for bert-base this is the square root of 64, i.e. 8;
  • attention_probs not only goes through softmax but also through dropout; this may seem unusual, but it is taken from the original Transformer paper;

  • head_mask is the per-head mask mentioned earlier; if not set, it defaults to all 1's and therefore has no effect here;

  • context_layer is the product of the attention matrix and the value matrix, with original shape (batch_size, num_attention_heads, sequence_length, attention_head_size);

  • After the transpose and view operations, the shape of context_layer returns to (batch_size, sequence_length, hidden_size).
OK, that’s all for attention.
2.2.1.1.2 BertSelfOutput
This block involves a few operations, but nothing complicated; here is the code:
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

Note: This part also features the combination of LayerNorm and Dropout, but here, dropout occurs first, followed by residual connection and then LayerNorm. The reason for this residual connection is primarily to reduce the training difficulty caused by the depth of the network, making it more sensitive to the original input.

2.2.1.2 BertIntermediate

After BertAttention, there is a fully connected + activation operation:
class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
  • The fully connected layer performs an expansion; in the case of bert-base, the expansion dimension is 3072, which is four times the original dimension of 768;

Note: Why go through an FFN at all? I don't know… A recent Google paper seems to indicate that models with only attention layers don't work well:

Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth (arxiv.org)

  • The activation function here defaults to gelu (Gaussian Error Linear Unit): $\mathrm{GELU}(x) = x\,\Phi(x)$, where $\Phi(x)$ is the standard normal CDF. It has no simple closed form and is usually approximated by an expression involving tanh (omitted here).
For reference, the original post includes a plot of the GELU curve here (image source from the web).

As for why this activation function is used in the transformer…

Note: After reviewing some studies, it seems that GeLU performs better than ReLU and others, leading to its adoption in subsequent language models.
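
A hedged check of the tanh approximation mentioned above against PyTorch's exact GELU:

import math
import torch

x = torch.linspace(-3, 3, steps=7)
exact = torch.nn.functional.gelu(x)     # x * Phi(x), computed via erf
approx = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))
print((exact - approx).abs().max())     # the difference is tiny (on the order of 1e-3 or smaller)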
2.2.1.3 BertOutput
This part features a fully connected + dropout + LayerNorm, along with a residual connection:
class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
The operations here are identical to those in BertSelfOutput… it is easy to confuse these two components.
The BERT-based application models, together with the optimizers and training usage of BERT (sections 3 and 4 of the outline), will be detailed in the next article.
2.2.3 BertPooler
This layer simply takes the vector corresponding to the first token of the sentence ([CLS]) and passes it through a fully connected layer plus an activation function to produce the output (this part is optional, since pooling can be done in many different ways):
class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

Takeaways·Summary

In the HuggingFace implementation of the BERT model, various memory-saving techniques are employed:

  • Gradient checkpointing, which does not retain forward-propagation activations and recomputes them only when needed;
  • apply_chunking_to_forward, which computes the FFN part in small chunks (small batch, low dimension) at a time.

BertModel contains complex encapsulation and numerous components. Taking bert-base as an example, the main components are counted as follows:
  • Dropout appears 1 + (1+1+1) x 12 = 37 times in total;
  • LayerNorm appears 1 + (1+1) x 12 = 25 times in total;
  • dense layers appear (1+1+1) x 12 + 1 = 37 times in total, and not every dense layer is followed by an activation function…
BertModel has a massive number of parameters. For bert-base, the parameter count is 109M; the specific calculation process can be referenced: HiroLin: Beginner’s Guide to Bert Series – Parameter Calculation [6].
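
The parameter count mentioned above is easy to verify directly (a hedged sketch; assumes the checkpoint is available locally or from the hub):

from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
total = sum(p.numel() for p in model.parameters())
print(f"{total / 1e6:.1f}M parameters")   # roughly 109M for bert-base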

References


[1] https://github.com/huggingface/transformers

[2] https://www.zhihu.com/question/395811291/answer/1260290120

[3] https://pytorch.org/docs/stable/checkpoint.html

[4] https://www.zhihu.com/question/341222779/answer/814111138

[5] https://pytorch.org/docs/stable/generated/torch.einsum.html

[6] https://zhuanlan.zhihu.com/p/144582114



