Understanding Attention Mechanism and Its Implementation in PyTorch

Reprinted from Zhihu. Author: Lucas

Address: https://www.zhihu.com/people/lucas_zhang


Biomimetic Brain Attention Model -> Resource Allocation

The attention mechanism in deep learning is a biomimetic of the human visual attention mechanism and is, in essence, a resource allocation mechanism. The physiological principle is that human visual attention can receive high-resolution information from one area of an image while perceiving the surrounding areas at lower resolution, and the point of focus can shift over time. In other words, the human eye quickly scans the whole image to find the target region that deserves attention, then allocates more attention to that region to obtain more detailed information while suppressing other, useless information, thereby improving the efficiency of the representation. For example, in the image below, my attention goes mainly to the icon in the middle and the word ATTENTION, while I pay far less attention to the stripes along the border; glancing at them is even a little dizzying.


Encoder-Decoder Framework == Sequence to Sequence Conditional Generation Framework

The Encoder-Decoder framework, also known as the sequence to sequence conditional generation framework[1], is a research model in the field of text processing. In the conventional encoder-decoder method, the first step is to encode the input sentence sequence X into a fixed-length context vector C through a neural network, which is the semantic representation of the text; the second step is for another neural network acting as a decoder to predict the target word sequence based on the current predicted word and the encoded context vector C. The encoder and decoder’s RNNs are jointly trained, but supervision information only appears on the decoder RNN side, and gradients are backpropagated to the encoder RNN side. Using LSTM for text modeling is currently a popular and effective method[2].
The most typical application of the attention mechanism is statistical machine translation. Given a translation task, the words "Echt", "Dicke" and "Kiste" are fed into the encoder, and an RNN compresses the text into a fixed-length vector h_3. The problem is that when the decoder generates y_1, it relies only on the last hidden state h_3, i.e. the sentence embedding, so h_3 must encode all the information in the input sentence. In practice, the traditional Encoder-Decoder model cannot achieve this. Isn't LSTM[3] supposed to solve the long-term dependency problem? In fact, long short-term memory networks still have issues. An RNN has to pass sequentially through all previous units before long-range information can reach the current processing unit, which makes it prone to vanishing gradients; LSTM was then introduced, using gating to partially alleviate this. Indeed, LSTM, GRU and their variants can learn a great deal of long-range information, but they can only remember moderately long dependencies, not arbitrarily long ones.

Using RNN for text representation and generation
So, let us summarize the general paradigm of the traditional encoder-decoder and its problems, using the task of translating the Chinese sentence "我/爱/赛尔" into English. The traditional encoder-decoder first reads in the entire sentence, finishing once the last word "赛尔" has been encoded, and uses the RNN to produce a single representation vector C for the whole sentence. During conditional generation, when the decoder produces the second target word, it looks back at the already computed hidden state h_1 and the context representation C, and then decodes the output.
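To make this bottleneck concrete, here is a minimal sketch (my own illustration, not code from the original author): a GRU encoder reads the whole source sentence and its final hidden state plays the role of the single fixed-length vector C that every decoding step must then rely on. The vocabulary and layer sizes are arbitrary.
import torch
import torch.nn as nn

vocab_size, embedding_dim, hidden_size = 1000, 64, 128   # illustrative sizes

embed = nn.Embedding(vocab_size, embedding_dim)
encoder = nn.GRU(embedding_dim, hidden_size, batch_first=True)

src = torch.randint(0, vocab_size, (1, 3))   # e.g. the three tokens of "我/爱/赛尔"
outputs, h_last = encoder(embed(src))        # outputs: one hidden state per source word
C = h_last.squeeze(0)                        # C : [1, hidden_size], the whole sentence squeezed into one vector
print(C.shape)                               # torch.Size([1, 128])
# Every target word y_i must be predicted from (previous decoder state, C); C never changes.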

From Equal Attention to Focused Attention

In the traditional Encoder-Decoder framework, the decoder predicts the target word sequence from the current predicted word and the encoded context vector C. This means that no matter which word is being generated, the sentence encoding C we use is the same. In other words, every word in the source sentence has the same influence on generating any target word P_yi: attention is distributed equally. Clearly, this is counterintuitive. Intuitively, when I translate a certain part, I should focus on the portion of the source text currently being translated; when translating the first word, I should pay more attention to what the first word of the source text means. See the pseudocode and the diagram below:
P_y1 = F(E<start>, C), P_y2 = F(E<the>, C), P_y3 = F(E<black>, C)

RNN performing text translation under the traditional Encoder-Decoder framework, always using the same c
Next, observe the difference between the upper and lower diagrams: the same context representation C will be replaced by a context representation Ci that changes according to the currently generated word.

RNN model performing text translation with attention mechanism generates different c at each moment
The text translation process becomes:
P_y1 = F(E<start>, C_0), P_y2 = F(E<the>, C_1), P_y3 = F(E<black>, C_2)

Encoder-Decoder Framework Code Implementation[4]
import torch.nn as nn

class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many
    other models.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask), src_mask,
                           tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
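As a hedged usage sketch of the skeleton above (the ToyEncoder/ToyDecoder modules, sizes, and masks below are my own illustrative stand-ins, not the components used in The Annotated Transformer), the pieces can be wired together like this:
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    def __init__(self, d_model):
        super(ToyEncoder, self).__init__()
        self.rnn = nn.GRU(d_model, d_model, batch_first=True)

    def forward(self, x, src_mask):           # src_mask is ignored in this toy version
        out, _ = self.rnn(x)
        return out                            # "memory": one vector per source position

class ToyDecoder(nn.Module):
    def __init__(self, d_model):
        super(ToyDecoder, self).__init__()
        self.rnn = nn.GRU(d_model, d_model, batch_first=True)

    def forward(self, y, memory, src_mask, tgt_mask):
        h0 = memory[:, -1:, :].transpose(0, 1).contiguous()   # last encoder state as initial decoder state
        out, _ = self.rnn(y, h0)
        return out

d_model, src_vocab, tgt_vocab = 32, 100, 100
model = EncoderDecoder(
    encoder=ToyEncoder(d_model),
    decoder=ToyDecoder(d_model),
    src_embed=nn.Embedding(src_vocab, d_model),
    tgt_embed=nn.Embedding(tgt_vocab, d_model),
    generator=nn.Linear(d_model, tgt_vocab),   # maps decoder states to vocabulary logits
)
src = torch.randint(0, src_vocab, (2, 7))      # batch of 2 source sentences, length 7
tgt = torch.randint(0, tgt_vocab, (2, 5))      # batch of 2 target prefixes, length 5
out = model(src, tgt, src_mask=None, tgt_mask=None)
print(out.shape)                               # torch.Size([2, 5, 32])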
Considering Interpretability
The traditional encoder-decoder without an attention model has poor interpretability: we have no clear picture of what information the encoding vector contains, how that information is used, or why the decoder behaves the way it does. Architectures that include an attention mechanism give us a relatively straightforward way to understand the decoder's reasoning process and what the model is learning and has learned. It is only a weak form of interpretability, but it is already meaningful.
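For instance, one common way to use this weak interpretability is to plot the attention weights as a heatmap, so we can see which source words the decoder attended to when producing each target word. The sketch below uses randomly generated weights and made-up tokens purely for illustration:
import numpy as np
import matplotlib.pyplot as plt

# Random, row-normalized weights standing in for real alpha_ij values (illustration only).
alpha = np.random.rand(3, 4)
alpha = alpha / alpha.sum(axis=1, keepdims=True)   # each row sums to 1, like a softmax output

src_tokens = ["我", "爱", "赛尔", "<eos>"]          # hypothetical source tokens
tgt_tokens = ["I", "love", "SCIR"]                  # hypothetical generated tokens

fig, ax = plt.subplots()
ax.imshow(alpha, cmap="viridis")
ax.set_xticks(range(len(src_tokens)))
ax.set_xticklabels(src_tokens)
ax.set_yticks(range(len(tgt_tokens)))
ax.set_yticklabels(tgt_tokens)
ax.set_xlabel("source position j")
ax.set_ylabel("target position i")
plt.show()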

Facing the Core Formula of Attention

c_i = Σ_j α_ij · h_j,  where α_ij = exp(e_ij) / Σ_k exp(e_ik) and e_ij = a(s_{i-1}, h_j)

When predicting the i-th word in the target language, the weight of the j-th word in the source language is α_ij; the size of this weight can be seen as a form of soft alignment information between the source language and the target language.
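A minimal sketch of this formula in PyTorch (assuming an additive, Bahdanau-style scoring function e_ij = v^T tanh(W_s s_{i-1} + W_h h_j); the parameters W_s, W_h, v and all sizes below are illustrative, not from the original article):
import torch
import torch.nn.functional as F

def attention_context(s_prev, h, W_s, W_h, v):
    # s_prev : [batch, d]     previous decoder state s_{i-1}
    # h      : [batch, T, d]  encoder states h_1 .. h_T
    e = torch.tanh(s_prev.unsqueeze(1) @ W_s + h @ W_h) @ v   # alignment scores e_ij : [batch, T]
    alpha = F.softmax(e, dim=1)                               # alpha_ij, sums to 1 over source positions j
    c = (alpha.unsqueeze(2) * h).sum(dim=1)                   # c_i = sum_j alpha_ij * h_j : [batch, d]
    return c, alpha

batch, T, d, d_a = 2, 5, 8, 16
W_s, W_h, v = torch.randn(d, d_a), torch.randn(d, d_a), torch.randn(d_a)
c_i, alpha = attention_context(torch.randn(batch, d), torch.randn(batch, T, d), W_s, W_h, v)
print(c_i.shape, alpha.shape)   # torch.Size([2, 8]) torch.Size([2, 5]); each row of alpha sums to 1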

Summary

Using attention essentially means that, when predicting a target word y_i, the model automatically retrieves semantic information from different positions in the source sentence and assigns a weight to the information at each position (this is the "soft" alignment information), and then combines the weighted information into the source-sentence representation c_i used for the current word y_i.
PyTorch Implementation of Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# vocab_size, embedding_dim, n_hidden, num_classes are assumed to be defined at module level.
class BiLSTM_Attention(nn.Module):
    def __init__(self):
        super(BiLSTM_Attention, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)
        self.out = nn.Linear(n_hidden * 2, num_classes)

    # lstm_output : [batch_size, n_step, n_hidden * num_directions(=2)], F matrix
    def attention_net(self, lstm_output, final_state):
        hidden = final_state.view(-1, n_hidden * 2, 1)   # hidden : [batch_size, n_hidden * num_directions(=2), 1(=n_layer)]
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)  # attn_weights : [batch_size, n_step]
        soft_attn_weights = F.softmax(attn_weights, 1)
        # [batch_size, n_hidden * num_directions(=2), n_step] * [batch_size, n_step, 1] = [batch_size, n_hidden * num_directions(=2), 1]
        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
        return context, soft_attn_weights.data.numpy()  # context : [batch_size, n_hidden * num_directions(=2)]

    def forward(self, X):
        input = self.embedding(X)  # input : [batch_size, len_seq, embedding_dim]
        input = input.permute(1, 0, 2)  # input : [len_seq, batch_size, embedding_dim]
        hidden_state = Variable(torch.zeros(1*2, len(X), n_hidden))  # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
        cell_state = Variable(torch.zeros(1*2, len(X), n_hidden))  # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
        # final_hidden_state, final_cell_state : [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
        output, (final_hidden_state, final_cell_state) = self.lstm(input, (hidden_state, cell_state))
        output = output.permute(1, 0, 2)  # output : [batch_size, len_seq, n_hidden * num_directions(=2)]
        attn_output, attention = self.attention_net(output, final_hidden_state)
        return self.out(attn_output), attention  # model : [batch_size, num_classes], attention : [batch_size, n_step]
Welcome to star!
https://github.com/zy1996code/nlp_basic_model/blob/master/lstm_attention.py
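A hedged usage sketch, continuing from the code above (the hyperparameter values and dummy batch are my own, chosen only for illustration): define the globals the class expects, run a forward pass, and inspect the returned attention weights, one per input token.
vocab_size, embedding_dim, n_hidden, num_classes = 1000, 64, 32, 2   # must be defined before instantiation

model = BiLSTM_Attention()
dummy_batch = torch.randint(0, vocab_size, (4, 10))   # 4 sentences, 10 token ids each
logits, attention = model(dummy_batch)
print(logits.shape)       # torch.Size([4, 2])  -> one class score vector per sentence
print(attention.shape)    # (4, 10)             -> one weight per token; each row sums to 1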

References

1.“Neural Network Methods in Natural Language Processing”
2.“Sequence to Sequence Learning with Neural Networks” https://arxiv.org/pdf/1409.3215.pdf
3.“LSTM Demystified: Understanding Long Short-Term Memory Networks and Their PyTorch Implementation” https://zhuanlan.zhihu.com/p/86876988
4.“The Annotated Transformer” https://nlp.seas.harvard.edu/2018/04/03/attention.html
