Detailed Explanation of ViT Model and PyTorch Implementation

Introduction

This article implements the ViT (Vision Transformer) model from scratch in PyTorch, with the goal of training it on the CIFAR-10 dataset for image classification.

Architecture of ViT

The architecture of ViT is inspired by BERT, an encoder-only transformer model typically used for supervised NLP tasks such as text classification or named entity recognition. The main idea behind ViT is that an image can be viewed as a sequence of patches, analogous to the tokens in NLP tasks.

The input image is divided into small patches and then flattened into a sequence of vectors. These vectors are then processed by a transformer encoder, allowing the model to learn interactions between patches through self-attention mechanisms. The output of the transformer encoder is then fed into a classification layer, which outputs the predicted class of the input image.
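As a rough sanity check on the patch step, the sketch below computes how many patches a square image yields and how many raw values each flattened patch holds. The function name vit_patch_shape and the example values (a 32x32 RGB image, as in CIFAR-10, with an assumed patch size of 4) are illustrative and not taken from the original code.

```python
def vit_patch_shape(image_size, patch_size, num_channels):
    """Return (num_patches, values_per_patch) for a square image."""
    # Raising here is a choice made for this sketch, not behavior of the article's code.
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    # Each flattened patch holds patch_size * patch_size pixels per channel.
    values_per_patch = num_channels * patch_size * patch_size
    return num_patches, values_per_patch


print(vit_patch_shape(32, 4, 3))  # (64, 48)
```

With these assumed values, the transformer encoder would see 64 patch tokens per image, each derived from 48 raw pixel values.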

Code Implementation

Below is the PyTorch code implementation for various components of the model.

01

Image Embedding Transformation

To feed the input image into the Transformer model, we need to convert the image into a sequence of vectors. This is done by dividing the image into a grid of non-overlapping patches and then linearly projecting each patch to obtain a fixed-size embedding vector. For this, we can use the PyTorch layer nn.Conv2d:
# Imports needed by the code in this article
import math

import torch
from torch import nn


class PatchEmbeddings(nn.Module):
    """
    Convert the image into patches and then project them into a vector space.
    """

    def __init__(self, config):
        super().__init__()
        self.image_size = config["image_size"]
        self.patch_size = config["patch_size"]
        self.num_channels = config["num_channels"]
        self.hidden_size = config["hidden_size"]
        # Calculate the number of patches from the image size and patch size
        self.num_patches = (self.image_size // self.patch_size) ** 2
        # Create a projection layer to convert the image into patches
        # The layer projects each patch into a vector of size hidden_size
        self.projection = nn.Conv2d(self.num_channels, self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x):
        # (batch_size, num_channels, image_size, image_size) -> (batch_size, num_patches, hidden_size)
        x = self.projection(x)
        x = x.flatten(2).transpose(1, 2)
        return x
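The tiling performed by a convolution whose kernel size and stride both equal the patch size can be checked by hand. The following is a minimal sketch in plain Python, assuming a single-channel image stored as a list of rows; the helper extract_patches is hypothetical and only reproduces the tiling, not the learned linear projection that nn.Conv2d applies to each patch.

```python
def extract_patches(image, patch_size):
    """Split a 2D grid (list of rows) into flattened, non-overlapping patches."""
    height, width = len(image), len(image[0])
    patches = []
    # Stepping by patch_size mirrors stride=patch_size: windows never overlap.
    for top in range(0, height - patch_size + 1, patch_size):
        for left in range(0, width - patch_size + 1, patch_size):
            # A patch_size x patch_size window mirrors kernel_size=patch_size.
            patch = [
                image[top + dy][left + dx]
                for dy in range(patch_size)
                for dx in range(patch_size)
            ]
            patches.append(patch)
    return patches


# A 4x4 single-channel "image" whose pixel values are 0..15.
image = [[row * 4 + col for col in range(4)] for row in range(4)]
patches = extract_patches(image, patch_size=2)
print(patches)
# [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
```

Every pixel lands in exactly one patch, and the patches come out in row-major order, which is the order that flatten(2) produces in the forward method.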
kernel_size=self.patch_size and stride=self.patch_size together ensure that the layer’s filters are applied to non-overlapping patches.
After the patches are converted into an embedding sequence, a [CLS] token is added to the beginning of the sequence; it will later be used by the classification layer. The embedding of the [CLS] token is learned during training.
Since the contributions of patches from different positions to the final prediction may vary, we also need a way to encode the patch positions into the sequence. We will use learnable position embedding vectors to add positional information to the embedding vectors. This is similar to how position embeddings are used in Transformer models for NLP tasks.
class Embeddings(nn.Module):
    """
    Combine the patch embeddings with the [CLS] token and the position embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.patch_embeddings = PatchEmbeddings(config)
        # Create a learnable [CLS] token
        # Similar to BERT, the [CLS] token is added to the beginning of the input sequence
        # and is used to classify the entire sequence
        self.cls_token = nn.Parameter(torch.randn(1, 1, config["hidden_size"]))
        # Create position embeddings for the [CLS] token and the patch embeddings
        # Add 1 to the sequence length for the [CLS] token
        self.position_embeddings = nn.Parameter(
            torch.randn(1, self.patch_embeddings.num_patches + 1, config["hidden_size"])
        )
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        x = self.patch_embeddings(x)
        batch_size, _, _ = x.size()
        # Expand the [CLS] token to the batch size
        # (1, 1, hidden_size) -> (batch_size, 1, hidden_size)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        # Concatenate the [CLS] token to the beginning of the input sequence
        # This results in a sequence length of (num_patches + 1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.position_embeddings
        x = self.dropout(x)
        return x
In this step, the input image is converted into an embedding sequence with positional information and is prepared to be fed into the transformer layers.
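What this step does to a single image can be traced with plain lists. This is a minimal sketch, assuming two patches and a hidden size of 2; the helper add_cls_and_positions and all numbers are made up for illustration, and dropout and the batch dimension are left out.

```python
def add_cls_and_positions(patch_embeddings, cls_token, position_embeddings):
    """Prepend cls_token, then add the position embeddings element-wise."""
    sequence = [cls_token] + patch_embeddings
    if len(sequence) != len(position_embeddings):
        raise ValueError("need one position embedding per token, including [CLS]")
    return [
        [value + offset for value, offset in zip(token, position)]
        for token, position in zip(sequence, position_embeddings)
    ]


patch_embeddings = [[1.0, 2.0], [3.0, 4.0]]
cls_token = [0.5, 0.5]
# num_patches + 1 rows: the first row belongs to the [CLS] token.
position_embeddings = [[0.0, 0.0], [10.0, 10.0], [20.0, 20.0]]

sequence = add_cls_and_positions(patch_embeddings, cls_token, position_embeddings)
print(sequence)  # [[0.5, 0.5], [11.0, 12.0], [23.0, 24.0]]
```

The sequence grows from 2 to 3 tokens, which is why the position embeddings in the class above are created with num_patches + 1 rows.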

02

Multi-Head Attention

Before introducing the transformer encoder, we first explore the multi-head attention module, which is its core component. Multi-head attention is used to compute the interactions between different patches of the input image. It consists of multiple attention heads, each of which is a separate attention layer.
Let’s implement a single head of the multi-head attention module. This module takes a sequence of embedding vectors as input and computes query, key, and value vectors for each of them. The attention weights for each token are then calculated from the query and key vectors, and the new embedding is the weighted sum of the value vectors under those weights. We can think of this mechanism as a soft version of a database lookup: the query vector searches for the most relevant key vectors in the database, and the corresponding value vectors are retrieved to compute the output.
class AttentionHead(nn.Module):
    """
    A single attention head.
    This module is used in the MultiHeadAttention module.
    """
    def __init__(self, hidden_size, attention_head_size, dropout, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.attention_head_size = attention_head_size
        # Create the query, key, and value projection layers
        self.query = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.key = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.value = nn.Linear(hidden_size, attention_head_size, bias=bias)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Project the input into query, key, and value
        # The same input is used to generate the query, key, and value,
        # so it's usually called self-attention.
        # (batch_size, sequence_length, hidden_size) -> (batch_size, sequence_length, attention_head_size)
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        # Calculate the attention scores
        # softmax(Q*K.T/sqrt(head_size))*V
        attention_scores = torch.matmul(query, key.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        # Calculate the attention output
        attention_output = torch.matmul(attention_probs, value)
        return (attention_output, attention_probs)
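The computation in forward, softmax(Q*K.T/sqrt(head_size))*V, can also be traced by hand. The sketch below uses plain Python lists and assumes the query, key, and value vectors are given directly, whereas the class above produces them with linear layers; dropout is omitted, and the helper names and numbers are illustrative.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries, keys, values):
    """Compute softmax(Q K^T / sqrt(d)) V for lists of equal-length vectors."""
    head_size = len(keys[0])
    outputs, all_probs = [], []
    for q in queries:
        # Dot product of this query with every key, scaled by sqrt(head_size).
        scores = [
            sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(head_size)
            for k in keys
        ]
        probs = softmax(scores)
        # Weighted sum of the value vectors under the attention weights.
        output = [
            sum(p * v[i] for p, v in zip(probs, values))
            for i in range(len(values[0]))
        ]
        outputs.append(output)
        all_probs.append(probs)
    return outputs, all_probs


# Two tokens with head size 2.
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
outputs, probs = attention(q, k, v)
print(probs[0])    # the first token puts more weight on itself than on the second
print(outputs[0])  # a blend of the two value vectors, closer to the first
```

Each row of probs sums to 1, so every output vector is a convex combination of the value vectors.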
Then, the outputs of all attention heads are concatenated and linearly projected to obtain the final output of the multi-head attention module.
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention module.
    This module is used in the TransformerEncoder module.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_attention_heads = config["num_attention_heads"]
        # The attention head size is the hidden size divided by the number of attention heads
        self.attention_head_size = self.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Whether or not to use bias in the query, key, and value projection layers
        self.qkv_bias = config["qkv_bias"]
        # Create a list of attention heads
        self.heads = nn.ModuleList([])
        for _ in range(self.num_attention_heads):
            head = AttentionHead(
                self.hidden_size,
                self.attention_head_size,
                config["attention_probs_dropout_prob"],
                self.qkv_bias
            )
            self.heads.append(head)
        # Create a linear layer to project the attention output back to the hidden size
        # In most cases, all_head_size and hidden_size are the same
        self.output_projection = nn.Linear(self.all_head_size, self.hidden_size)
        self.output_dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x, output_attentions=False):
        # Calculate the attention output for each attention head
        attention_outputs = [head(x) for head in self.heads]
        # Concatenate the attention outputs from each attention head
        attention_output = torch.cat([attention_output for attention_output, _ in attention_outputs], dim=-1)
        # Project the concatenated attention output back to the hidden size
        attention_output = self.output_projection(attention_output)
        attention_output = self.output_dropout(attention_output)
        # Return the attention output and the attention probabilities (optional)
        if not output_attentions:
            return (attention_output, None)
        else:
            attention_probs = torch.stack([attention_probs for _, attention_probs in attention_outputs], dim=1)
            return (attention_output, attention_probs)
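The relationship between hidden_size, attention_head_size, and all_head_size is plain integer arithmetic. The sketch below repeats the two lines from __init__ in a hypothetical helper named head_layout; the sizes 48, 50, and 4 are assumed values chosen for illustration.

```python
def head_layout(hidden_size, num_attention_heads):
    """Return (attention_head_size, all_head_size) using the same integer division."""
    attention_head_size = hidden_size // num_attention_heads
    all_head_size = num_attention_heads * attention_head_size
    return attention_head_size, all_head_size


print(head_layout(48, 4))  # (12, 48): the concatenated heads match hidden_size
print(head_layout(50, 4))  # (12, 48): integer division leaves the heads 2 dimensions short
```

When hidden_size is not divisible by the number of heads, the concatenated output is narrower than hidden_size, which is why the output projection maps from all_head_size back to hidden_size.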

03

Encoder

The encoder is a stack of transformer layers. Each layer is primarily composed of the multi-head attention (MHA) module we just implemented and a feed-forward network (MLP). To better scale the model and stabilize training, two layer normalization layers and skip connections are added to each transformer layer.
Let’s implement a transformer layer (referred to as Block in the code), which is a building block of the transformer encoder. We will start with the feed-forward network, which is a simple two-layer MLP with GELU activation in between.
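class NewGELUActivation(nn.Module):
    """
    Tanh approximation of the GELU activation.
    Assumed definition: the MLP below uses this class without defining it, so this
    follows the class of the same name in Hugging Face Transformers.
    """

    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


class MLP(nn.Module):
    """
    A multi-layer perceptron module.
    """

    def __init__(self, config):
        super().__init__()
        self.dense_1 = nn.Linear(config["hidden_size"], config["intermediate_size"])
        self.activation = NewGELUActivation()
        self.dense_2 = nn.Linear(config["intermediate_size"], config["hidden_size"])
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dense_2(x)
        x = self.dropout(x)
        return x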
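Assuming NewGELUActivation is the tanh approximation of GELU (as the class of that name in Hugging Face Transformers is), the sketch below compares that approximation with the exact definition based on the Gaussian CDF. It uses only the standard library and scalar inputs; the function names are illustrative.

```python
import math


def gelu_exact(x):
    """GELU defined through the Gaussian CDF: x * Phi(x)."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    """Tanh approximation of GELU."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"{x:+.1f}  exact={gelu_exact(x):+.5f}  tanh={gelu_tanh(x):+.5f}")
```

At these sample points the two versions agree closely, so either should behave similarly as the activation between the two dense layers.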
Having implemented multi-head attention and the MLP, we can combine them into a transformer layer. Layer normalization is applied to the input of each sub-layer, and a skip connection adds each sub-layer's output back to its input.
class Block(nn.Module):
    """
    A single transformer block.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.layernorm_1 = nn.LayerNorm(config["hidden_size"])
        self.mlp = MLP(config)
        self.layernorm_2 = nn.LayerNorm(config["hidden_size"])

    def forward(self, x, output_attentions=False):
        # Self-attention
        attention_output, attention_probs = self.attention(
            self.layernorm_1(x), output_attentions=output_attentions
        )
        # Skip connection
        x = x + attention_output
        # Feed-forward network
        mlp_output = self.mlp(self.layernorm_2(x))
        # Skip connection
        x = x + mlp_output
        # Return the transformer block's output and the attention probabilities (optional)
        if not output_attentions:
            return (x, None)
        else:
            return (x, attention_probs)
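The pattern that the block applies twice, x + sublayer(layernorm(x)), can be sketched on a single token vector with plain Python. This is a simplified illustration: the layer_norm helper below has no learned scale or shift, and the halve function is a stand-in for the real sub-layers (attention and the MLP).

```python
import math


def layer_norm(vector, eps=1e-5):
    """Normalize one token vector to zero mean and (almost exactly) unit variance."""
    mean = sum(vector) / len(vector)
    variance = sum((v - mean) ** 2 for v in vector) / len(vector)
    return [(v - mean) / math.sqrt(variance + eps) for v in vector]


def pre_norm_residual(vector, sublayer):
    """Compute x + sublayer(layer_norm(x))."""
    update = sublayer(layer_norm(vector))
    return [x + u for x, u in zip(vector, update)]


def halve(vector):
    """Stand-in sub-layer used only for this illustration."""
    return [0.5 * v for v in vector]


token = [1.0, 2.0, 3.0, 4.0]
print(layer_norm(token))
print(pre_norm_residual(token, halve))
```

Because the normalization sits inside the residual branch, the skip path carries x through unchanged; a sub-layer that outputs zeros would leave the token exactly as it was.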
The transformer encoder stacks multiple transformer layers sequentially:
class Encoder(nn.Module):
    """
    The transformer encoder module.
    """

    def __init__(self, config):
        super().__init__()
        # Create a list of transformer blocks
        self.blocks = nn.ModuleList([])
        for _ in range(config["num_hidden_layers"]):
            block = Block(config)
            self.blocks.append(block)

    def forward(self, x, output_attentions=False):
        # Calculate the transformer block's output for each block
        all_attentions = []
        for block in self.blocks:
            x, attention_probs = block(x, output_attentions=output_attentions)
            if output_attentions:
                all_attentions.append(attention_probs)
        # Return the encoder's output and the attention probabilities (optional)
        if not output_attentions:
            return (x, None)
        else:
            return (x, all_attentions)
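To get a feel for the size of the encoder, the learnable parameters of one block can be counted with simple arithmetic. The sketch below follows the layer definitions in the code above (per-head query, key, and value projections, the output projection, the two MLP layers, and two layer normalizations with a weight and a bias each). The helper name and the sizes used (hidden size 48, 4 heads, intermediate size 192, 4 blocks) are assumptions for illustration.

```python
def block_parameter_count(hidden_size, num_attention_heads, intermediate_size, qkv_bias=True):
    """Count the learnable parameters of one transformer block."""
    head_size = hidden_size // num_attention_heads
    all_head_size = head_size * num_attention_heads
    # Query, key, and value projections for every head.
    per_projection = hidden_size * head_size + (head_size if qkv_bias else 0)
    attention = 3 * num_attention_heads * per_projection
    # Output projection back to hidden_size, with its bias.
    attention += all_head_size * hidden_size + hidden_size
    # Two dense layers of the MLP, each with a bias.
    mlp = hidden_size * intermediate_size + intermediate_size
    mlp += intermediate_size * hidden_size + hidden_size
    # Two layer normalizations, each with a weight vector and a bias vector.
    layer_norms = 2 * 2 * hidden_size
    return attention + mlp + layer_norms


per_block = block_parameter_count(48, 4, 192)
print(per_block)      # 28272
print(4 * per_block)  # 113088 for an assumed stack of 4 blocks
```

With these sizes the MLP accounts for roughly two thirds of each block's parameters.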

04

Building the ViT Model

After inputting the image into the embedding layer and transformer encoder, we obtain new embeddings for image patches and the [CLS] token. At this point, the embeddings should contain useful signals for classification after being processed by the transformer encoder. Similar to BERT, we will only use the embedding of the [CLS] token to pass to the classification layer.
The classification layer is a fully connected layer that takes the [CLS] embedding as input and outputs the logits for each class. The following code implements the ViT model for image classification:
class ViTForClassfication(nn.Module):
    """
    The ViT model for classification.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config["image_size"]
        self.hidden_size = config["hidden_size"]
        self.num_classes = config["num_classes"]
        # Create the embedding module
        self.embedding = Embeddings(config)
        # Create the transformer encoder module
        self.encoder = Encoder(config)
        # Create a linear layer to project the encoder's output to the number of classes
        self.classifier = nn.Linear(self.hidden_size, self.num_classes)
        # Initialize the weights
        self.apply(self._init_weights)

    def forward(self, x, output_attentions=False):
        # Calculate the embedding output
        embedding_output = self.embedding(x)
        # Calculate the encoder's output
        encoder_output, all_attentions = self.encoder(embedding_output, output_attentions=output_attentions)
        # Calculate the logits, take the [CLS] token's output as features for classification
        logits = self.classifier(encoder_output[:, 0])
        # Return the logits and the attention probabilities (optional)
        if not output_attentions:
            return (logits, None)
        else:
            return (logits, all_attentions)

    def _init_weights(self, module):
        # Assumed initialization: the article calls this method without showing it.
        # Small normal weights and zero biases; "initializer_range" is an assumed
        # optional config key with a default of 0.02.
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.get("initializer_range", 0.02))
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
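All of the modules above read their settings from a config dictionary that the article never shows. The sketch below lists the keys the code looks up and checks two size constraints before a model is built. The validate_config helper and every value in the example are assumptions chosen for 32x32 RGB images with 10 classes (as in CIFAR-10), not settings taken from the original.

```python
REQUIRED_KEYS = (
    "image_size", "patch_size", "num_channels", "hidden_size",
    "num_attention_heads", "num_hidden_layers", "intermediate_size",
    "hidden_dropout_prob", "attention_probs_dropout_prob",
    "qkv_bias", "num_classes",
)


def validate_config(config):
    """Raise ValueError if the config is missing keys or has inconsistent sizes."""
    missing = [key for key in REQUIRED_KEYS if key not in config]
    if missing:
        raise ValueError(f"missing config keys: {missing}")
    if config["image_size"] % config["patch_size"] != 0:
        raise ValueError("image_size should be divisible by patch_size")
    if config["hidden_size"] % config["num_attention_heads"] != 0:
        raise ValueError("hidden_size should be divisible by num_attention_heads")


# Illustrative values only.
config = {
    "image_size": 32,
    "patch_size": 4,
    "num_channels": 3,
    "hidden_size": 48,
    "num_attention_heads": 4,
    "num_hidden_layers": 4,
    "intermediate_size": 4 * 48,
    "hidden_dropout_prob": 0.0,
    "attention_probs_dropout_prob": 0.0,
    "qkv_bias": True,
    "num_classes": 10,
}
validate_config(config)
print("config is consistent")
```

A dictionary like this could then be passed to ViTForClassfication(config); calling the model on a batch of shape (batch_size, 3, 32, 32) would be expected to return logits of shape (batch_size, 10).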

References

The code is adapted and translated from the sources below. It is easy to follow, and anyone with a basic knowledge of PyTorch should be able to understand and learn from it. If interested, you can check them out here:

https://github.com/lukemelas/PyTorch-Pretrained-ViT/blob/master/pytorch_pretrained_vit/transformer.py
https://tintn.github.io/Implementing-Vision-Transformer-from-Scratch/

Editor / Zhang Zhihong

Reviewer / Fan Ruiqiang

Reviewed by / Fan Ruiqiang
