Source | Zhihu
Address | https://zhuanlan.zhihu.com/p/97378498
Author | Si Jie’s Portable Mattress
Modules and functions needed:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from torch.nn.utils.rnn import pack_padded_sequence
I have three sentences that I want to input as a batch into the RNN for training.
a = [[1], [2], [3]]
b = [[4], [5]]
c = [[6]]
I define an RNN (for simplicity, assume the input word vector dimension is 1 and the hidden/output dimension is 3; to match my novice intuition, I put the batch as the first dimension, i.e. set batch_first=True).
rnn = nn.RNN(1, 3, batch_first=True)
If I only input one sample:
input = torch.FloatTensor([a]) # Note: this must be a float tensor; building it from these integer lists with torch.tensor would infer a LongTensor and cause a dtype mismatch with the RNN's float weights
rnn(input)
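As a sanity check (a minimal sketch of my own; the variable names output and h_n are mine), the RNN returns the per-step outputs and the final hidden state:
output, h_n = rnn(input)
print(output.shape)  # torch.Size([1, 3, 3]): (batch, seq_len, hidden)
print(h_n.shape)     # torch.Size([1, 1, 3]): (num_layers, batch, hidden)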
However, training only one sample at a time is too slow and cannot be parallelized, so I need to encapsulate a, b, and c into a batch and pass it to the RNN.
input = torch.FloatTensor([a, b, c]) # This line will cause an error during runtime
rnn(input)
Error message: ValueError: expected sequence of length 3 at dim 1 (got 2). What does this mean? It means: expecting a sequence of length 3, but you provided a sequence of length 2, referring to sentence b!
Error reason: a Tensor must be rectangular, i.e. every row must have the same length. The matrix should look like this: correct_matrix = [[[1], [2], [3]], [[4], [5], [0]], [[6], [0], [0]]] instead of looking like this: error_matrix = [[[1], [2], [3]], [[4], [5]], [[6]]]. In simple terms, every row of the matrix must have the same length, but the rows of our input (i.e. our sentences) do not!
Solution: We can pad the sequences to make all lengths the same! That is, by filling with zeros, we can transform error_matrix -> correct_matrix using: torch.nn.utils.rnn.pad_sequence
padded_sequence = pad_sequence([torch.FloatTensor(a), torch.FloatTensor(b), torch.FloatTensor(c)], batch_first=True)
print(padded_sequence)
rnn(padded_sequence) # Done! The whole batch is now computed in parallel
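For reference, the print above should show the zero-padded batch, of shape (3, 3, 1):
# tensor([[[1.], [2.], [3.]],
#         [[4.], [5.], [0.]],
#         [[6.], [0.], [0.]]])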
However, there is still a problem: padding with zeros wastes memory, turns the training of a length-1 sample into the training of a length-3 one (the RNN also steps over the padding), wastes computation, and can even affect the training result. Is there a way around this? Since a Tensor requires every row to have the same length, and our sentences have different lengths, we can simply avoid passing a Tensor at all. What do we pass instead? A PackedSequence!
packed_sequence = pack_sequence([torch.FloatTensor(i) for i in [a, b, c]]) # packed_sequence is an instance of PackedSequence
print(packed_sequence)
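The print above should show something like the following (newer PyTorch versions also display sorted_indices/unsorted_indices). The data field concatenates time step 0 of all three sentences, then time step 1 of the two sentences still running, then time step 2 of the last one; batch_sizes tells the RNN how many sequences are still active at each step, so no computation is wasted on padding:
# PackedSequence(data=tensor([[1.], [4.], [6.], [2.], [5.], [3.]]), batch_sizes=tensor([3, 2, 1]))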
However, the packed_sequence returned by pack_sequence stores its data time-major rather than batch-first, which makes me unhappy. (Strictly speaking, when the input is a PackedSequence the RNN ignores its batch_first flag, so the network above would accept it as-is; batch_first only matters for padded-tensor inputs and outputs.) To keep the layouts consistent, let's redefine the network with batch_first=False.
rnn = nn.RNN(1, 3, batch_first=False)
print(rnn(packed_sequence)) # Done!
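Note that when the input is a PackedSequence, the first return value is also a PackedSequence, while the final hidden state is a plain tensor:
packed_out, h_n = rnn(packed_sequence)
print(type(packed_out))  # <class 'torch.nn.utils.rnn.PackedSequence'>
print(h_n.shape)         # torch.Size([1, 3, 3]): (num_layers, batch, hidden)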
Finally, we can convert the results obtained from pad_sequence and pack_sequence into each other.
packed_padded_sequence = pack_padded_sequence(padded_sequence, [3, 2, 1], batch_first=True) # You need to pass a list with the length of each sentence; batch_first=True because padded_sequence was built batch-first
print(packed_padded_sequence)
padded_packed_sequence = pad_packed_sequence(packed_sequence) # returns a (padded tensor, lengths) tuple; pass batch_first=True if you want the padded tensor batch-first
print(padded_packed_sequence)
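Putting it all together, here is a minimal end-to-end sketch (my own summary of the steps above, assuming batch_first=True throughout; by default the lengths must be sorted in descending order, or you can pass enforce_sorted=False in PyTorch 1.1+):
sentences = [torch.FloatTensor(s) for s in (a, b, c)]
lengths = [len(s) for s in sentences]  # [3, 2, 1], already descending
rnn = nn.RNN(1, 3, batch_first=True)
padded = pad_sequence(sentences, batch_first=True)               # (3, 3, 1)
packed = pack_padded_sequence(padded, lengths, batch_first=True)
packed_out, h_n = rnn(packed)                                    # packed_out is a PackedSequence
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
print(out.shape)    # torch.Size([3, 3, 3]): (batch, max_seq_len, hidden)
print(out_lengths)  # tensor([3, 2, 1])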
My expression ability is poor, so I recommend my senior Li Bo's tutorial (although he doesn't know me, haha): Yi Zhen: How to Handle Variable Length Sequence Padding in Pytorch (zhuanlan.zhihu.com)
I rarely write blogs because I always feel I am too inexperienced, and what I write is not presentable.
Keep working hard!!!