Understanding Attention Mechanism and Its Implementation in PyTorch

Biomimetic Brain Attention Model -> Resource Allocation

The attention mechanism in deep learning is modeled on human visual attention and is, in essence, a resource allocation mechanism. Physiologically, human vision takes in high-resolution information from one specific region of an image while perceiving the surrounding regions at lower resolution, and the point of focus can shift over time. In other words, the eye quickly scans the whole image to find the target region that deserves attention, then allocates more attention to that region to obtain finer detail while suppressing irrelevant information, which makes the representation more efficient. For example, in the image below, my main focus is on the icon in the middle and the word ATTENTION, while I pay little attention to the stripes on the border, which make me feel a bit dizzy.

[Image: an icon and the word ATTENTION in the middle, framed by a striped border]

Encoder-Decoder Framework == Sequence to Sequence Conditional Generation Framework

The Encoder-Decoder framework, also known as the sequence-to-sequence conditional generation framework, is a research paradigm in the field of text processing. In the conventional encoder-decoder approach, the first step encodes the input sentence sequence X into a fixed-length context vector C with a neural network; C represents the semantic meaning of the text. In the second step, another neural network acts as a decoder and predicts the target word sequence from the context vector C and the words already predicted. The encoder and decoder RNNs are trained jointly, but the supervision signal appears only on the decoder side, and gradients are backpropagated from there into the encoder RNN. Using LSTM for text modeling is currently a popular and effective method [2].
The most typical application of the attention mechanism is machine translation. Suppose the words “Echt”, “Dicke”, and “Kiste” are fed into the encoder, and an RNN represents the text as a fixed-length vector h3. The problem is that when the decoder generates y1, it relies solely on the last hidden state h3, that is, the sentence embedding. This means h3 must encode all the information in the input sentence, which traditional Encoder-Decoder models cannot do in practice. But isn’t LSTM [3] meant to solve the long-term dependency problem? It helps, but problems remain. An RNN must pass sequentially through all previous units before the current unit can access long-range information, which makes it prone to vanishing gradients. LSTM was introduced to mitigate this. LSTM, GRU, and their variants can indeed learn a great deal of long-term information, but they retain it only over moderately long spans, not over arbitrarily long sequences.

Using RNN for text representation and generation
Therefore, let’s summarize the general paradigm of the traditional encoder-decoder and its issues. The task is to translate the Chinese “我/爱/赛尔” into English. The traditional encoder-decoder first reads in the entire sentence; only after the last word, “赛尔”, has been encoded does the RNN produce a representation vector C for the whole sentence. During conditional generation, when translating a later word such as “赛尔”, the decoder has to step back to the previously computed state h_1 and the context representation C before it can decode the output.
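
The bottleneck described above can be illustrated with a small sketch. It is not the article's model: the elementwise recurrence, the weights `w_h` and `w_x`, and the toy 4-dimensional "word vectors" are all assumptions chosen only to show that the final state used as C has the same size no matter how long the input is.

```python
import math

def rnn_step(h, x, w_h=0.5, w_x=1.0):
    # Toy elementwise recurrence (an assumption, not a real RNN cell):
    # h_t = tanh(w_h * h_{t-1} + w_x * x_t)
    return [math.tanh(w_h * hi + w_x * xi) for hi, xi in zip(h, x)]

def encode(sequence, hidden_size=4):
    # Fold the whole sequence into one hidden state, keeping every
    # intermediate state as well.
    h = [0.0] * hidden_size
    states = []
    for x in sequence:
        h = rnn_step(h, x)
        states.append(h)
    return states, h  # all hidden states, and the final one used as C

# Hypothetical word vectors: a 3-word sentence and a 30-word sentence.
short_sentence = [[0.1, 0.2, 0.3, 0.4]] * 3
long_sentence = [[0.1, 0.2, 0.3, 0.4]] * 30

short_states, c_short = encode(short_sentence)
long_states, c_long = encode(long_sentence)
print(len(c_short), len(c_long))            # 4 4
print(len(short_states), len(long_states))  # 3 30
```

Both sentences are squeezed into a vector of the same size, while the per-word states (which attention later draws on) grow with the sentence.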

From Equal Attention to Focused Attention

In the traditional Encoder-Decoder framework, the decoder predicts the target word sequence from the already encoded context vector C. This means that no matter which word is being generated, the sentence encoding C is the same. In other words, every word in the source sentence has the same influence on each generated target word P_yi, so attention is uniform. This clearly contradicts intuition: when I translate a certain part, the corresponding part of the original text should receive focused attention. When translating the first word, I should pay more attention to what the first word in the original text means. See the pseudocode and the diagram below:
P_y1 = F(E<start>, C)
P_y2 = F(E<the>, C)
P_y3 = F(E<black>, C)

RNN performing text translation under the traditional Encoder-Decoder framework, continuously using the same c
Next, compare the diagram above with the one below: the single shared context representation C is replaced by a context C_i that changes with the word currently being generated.

The RNN model with the attention mechanism for text translation generates different c at each moment
The text translation process becomes:
P_y1 = F(E<start>, C_0)
P_y2 = F(E<the>, C_1)
P_y3 = F(E<black>, C_2)

Encoder-Decoder framework code implementation [4]
import torch.nn as nn


class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many
    other models.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask), src_mask,
                           tgt, tgt_mask)

    def encode(self, src, src_mask):
        # Embed the source tokens, then run the encoder over them.
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        # memory is the encoder output that the decoder conditions on.
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
Considering Interpretability
A traditional encoder-decoder without attention has poor interpretability: we have no clear picture of what information the encoding vector holds, how that information is used, or why the decoder behaves as it does. Architectures with attention offer a relatively simple way to follow the decoder’s reasoning and see what the model is actually learning. The interpretability is weak, but it is intuitive.
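
One simple way to read such weights is sketched below. It assumes the weights are available as one row per target word and one column per source word. The numbers are made up for illustration and do not come from a trained model, and `y1`, `y2`, `y3` are placeholder target words.

```python
def soft_alignment(weights, source_words, target_words):
    # For each target word, report the source word that received the
    # largest attention weight, together with that weight.
    alignment = {}
    for tgt, row in zip(target_words, weights):
        best = max(range(len(row)), key=row.__getitem__)
        alignment[tgt] = (source_words[best], row[best])
    return alignment

source_words = ["我", "爱", "赛尔"]
target_words = ["y1", "y2", "y3"]  # placeholder target words

# Hypothetical attention weights; each row sums to 1.
weights = [
    [0.80, 0.15, 0.05],
    [0.10, 0.75, 0.15],
    [0.05, 0.10, 0.85],
]

alignment = soft_alignment(weights, source_words, target_words)
for tgt, (src, w) in alignment.items():
    print(tgt, "attends most to", src, "with weight", w)
```

Taking the largest weight per row is only a rough summary. The full row is the "soft" information, and a heat map of the whole matrix is the more usual way to inspect it.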

Confronting the Core Formula of Attention

When predicting the i-th word in the target language, the j-th word in the source language receives a weight, written here as α_ij. The magnitude of this weight can be read as soft alignment information between the source and target languages.
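
A common way to write this weight is sketched below. It assumes the additive-attention formulation of Bahdanau et al. rather than anything stated explicitly in this article. The symbols are assumptions as well: h_j is the encoder hidden state for source position j, s_{i-1} is the previous decoder state, a is a learned scoring function, and T_x is the source length.

```latex
\begin{align}
e_{ij}      &= a(s_{i-1}, h_j) \\
\alpha_{ij} &= \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})} \\
c_i         &= \sum_{j=1}^{T_x} \alpha_{ij} \, h_j
\end{align}
```

The softmax in the second line makes the weights for each target position i positive and sum to 1, so c_i is a weighted average of the encoder states.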

Summary

In essence, attention means that while predicting a target word y_i, the model automatically gathers semantic information from different positions of the source sentence and assigns a weight to the information from each position. These weights are the “soft” alignment information, and the weighted information is combined to compute the context vector c_i for the current word y_i.
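
The summary can be made concrete with a minimal pure-Python sketch. It assumes dot-product scoring, which is one of several possible scoring functions. The 2-dimensional encoder states and the decoder queries are toy values invented for illustration.

```python
import math

def softmax(scores):
    # Subtract the maximum for numerical stability before exponentiating.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def attention_context(query, encoder_states):
    # Score every source position against the query, normalize the scores
    # into weights, then take the weighted sum of the encoder states.
    weights = softmax([dot(query, h) for h in encoder_states])
    size = len(encoder_states[0])
    context = [sum(w * h[d] for w, h in zip(weights, encoder_states))
               for d in range(size)]
    return context, weights

# Toy encoder states for three source words, and decoder queries for
# two target positions (all hypothetical values).
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
queries = [[2.0, 0.0], [0.0, 2.0]]

results = [attention_context(q, encoder_states) for q in queries]
for i, (c_i, w_i) in enumerate(results):
    print("step", i, "weights", [round(w, 3) for w in w_i],
          "context", [round(c, 3) for c in c_i])
```

Each decoding step gets its own weights and therefore its own context vector c_i, in contrast to the single shared C of the plain encoder-decoder.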
PyTorch Implementation of Attention
import torch
import torch.nn as nn
import torch.nn.functional as F


class BiLSTM_Attention(nn.Module):
    # vocab_size, embedding_dim, n_hidden and num_classes were free
    # (module-level) names in the original listing; they are passed in
    # explicitly here so the class is self-contained.
    def __init__(self, vocab_size, embedding_dim, n_hidden, num_classes):
        super(BiLSTM_Attention, self).__init__()
        self.n_hidden = n_hidden
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, n_hidden, bidirectional=True)
        self.out = nn.Linear(n_hidden * 2, num_classes)

    # lstm_output : [batch_size, n_step, n_hidden * num_directions(=2)], F matrix
    def attention_net(self, lstm_output, final_state):
        # final_state : [num_directions(=2), batch_size, n_hidden]
        # Concatenate the forward and backward final states per example.
        # A plain .view(-1, n_hidden * 2, 1) would mix different examples
        # together whenever batch_size > 1.
        hidden = torch.cat((final_state[0], final_state[1]), dim=1).unsqueeze(2)  # [batch_size, n_hidden * 2, 1]
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)  # attn_weights : [batch_size, n_step]
        soft_attn_weights = F.softmax(attn_weights, 1)
        # [batch_size, n_hidden * 2, n_step] * [batch_size, n_step, 1] = [batch_size, n_hidden * 2, 1]
        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
        return context, soft_attn_weights.detach().cpu().numpy()  # context : [batch_size, n_hidden * 2]

    def forward(self, X):
        input = self.embedding(X)  # input : [batch_size, len_seq, embedding_dim]
        input = input.permute(1, 0, 2)  # input : [len_seq, batch_size, embedding_dim]
        # Initial states : [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
        hidden_state = torch.zeros(1 * 2, len(X), self.n_hidden, device=X.device)
        cell_state = torch.zeros(1 * 2, len(X), self.n_hidden, device=X.device)
        # final_hidden_state, final_cell_state : [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
        output, (final_hidden_state, final_cell_state) = self.lstm(input, (hidden_state, cell_state))
        output = output.permute(1, 0, 2)  # output : [batch_size, len_seq, n_hidden * 2]
        attn_output, attention = self.attention_net(output, final_hidden_state)
        # model : [batch_size, num_classes], attention : [batch_size, n_step]
        return self.out(attn_output), attention
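
The attention step can also be exercised on its own to check the tensor shapes. The sketch below is an illustration and is not taken from the linked repository. The helper names `final_state_as_query` and `dot_attention` are hypothetical, the inputs are random tensors standing in for real LSTM outputs, and the query is built by concatenating the forward and backward final states of each example.

```python
import torch

def final_state_as_query(final_state):
    # final_state: [num_directions(=2), batch_size, n_hidden]
    # returns:     [batch_size, 2 * n_hidden, 1], one query per example
    return torch.cat((final_state[0], final_state[1]), dim=1).unsqueeze(2)

def dot_attention(lstm_output, query):
    # lstm_output: [batch_size, n_step, 2 * n_hidden]
    scores = torch.bmm(lstm_output, query).squeeze(2)   # [batch_size, n_step]
    weights = torch.softmax(scores, dim=1)              # each row sums to 1
    context = torch.bmm(lstm_output.transpose(1, 2),
                        weights.unsqueeze(2)).squeeze(2)  # [batch_size, 2 * n_hidden]
    return context, weights

torch.manual_seed(0)
batch_size, n_step, n_hidden = 4, 5, 3

# Random stand-ins for the BiLSTM outputs and final hidden states.
lstm_output = torch.randn(batch_size, n_step, 2 * n_hidden)
final_state = torch.randn(2, batch_size, n_hidden)

query = final_state_as_query(final_state)
context, weights = dot_attention(lstm_output, query)
print(context.shape, weights.shape)
print(weights.sum(dim=1))
```

The context has one vector of size 2 * n_hidden per example, and the weights form one distribution over the n_step positions per example.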
The full code is on GitHub (stars are welcome):
https://github.com/zy1996code/nlp_basic_model/blob/master/lstm_attention.py

References

1. “Neural Network Methods in Natural Language Processing”
2. “Sequence to Sequence Learning with Neural Networks” https://arxiv.org/pdf/1409.3215.pdf
3. “LSTM Primer: Understanding Long Short-Term Memory Networks and Their PyTorch Implementation” https://zhuanlan.zhihu.com/p/86876988
4. “The Annotated Transformer” https://nlp.seas.harvard.edu/2018/04/03/attention.html

Author: Lucas

Address: https://www.zhihu.com/people/lucas_zhang
