Understanding the Attention Mechanism in Deep Learning – Part 2

[GiantPandaCV Guide] In recent years, Attention-based methods have gained popularity in both academia and industry thanks to their interpretability and effectiveness. However, the network structures proposed in papers are usually embedded in larger code frameworks for classification, detection, segmentation, and so on, which adds a lot of redundant code. For beginners like me, it can be hard to locate the core code of a network, which makes the ideas in papers and blog posts more difficult to understand. I have therefore organized and reproduced the core code of recent papers on Attention, MLP, and Re-parameterization for readers' convenience. This article briefly introduces the Attention part of the project. The project will keep being updated with the latest papers, and everyone is welcome to follow and star it. If you run into any problems with the reproductions or the organization of the project, please raise them in the issues section and I will respond promptly.

Project Address

https://github.com/xmu-xiaoma666/External-Attention-pytorch

11. Shuffle Attention

11.1. Citation

SA-NET: Shuffle Attention For Deep Convolutional Neural Networks[1]

Paper Address: https://arxiv.org/pdf/2102.00240.pdf

11.2. Model Structure

[Figure: overall structure of the Shuffle Attention (SA) module]

11.3. Introduction

This is a paper published by Nanjing University at ICASSP 2021, which captures two types of attention: channel attention and spatial attention. The ShuffleAttention proposed in this paper mainly consists of three steps:

1. First, the input feature map is divided into groups along the channel dimension, and the features of each group are split into two branches that compute channel attention and spatial attention, respectively. Both kinds of attention are computed with trainable parameters followed by a sigmoid (from the structure diagram I assumed an FC layer was used here, but the source code shows that a learnable per-channel scale and bias are created instead).

2. Next, the outputs of the two branches are concatenated, and all groups are merged back together to obtain a feature map with the same shape as the input.

3. Finally, a shuffle layer performs channel shuffle across groups (similar to ShuffleNet[2]); a small sketch of this operation is shown below.
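To make the channel shuffle in step 3 concrete, here is a tiny, self-contained sketch (with made-up shapes) of the reshape-permute-flatten trick; the project's own implementation is the channel_shuffle method in the core code below.

import torch

# toy example: 1 sample, 6 channels arranged as 2 groups of 3, 1x1 spatial size
x = torch.arange(6.).view(1, 6, 1, 1)          # channel order: [0, 1, 2, 3, 4, 5]

b, c, h, w = x.shape
groups = 2
x = x.reshape(b, groups, c // groups, h, w)    # (1, 2, 3, 1, 1)
x = x.permute(0, 2, 1, 3, 4)                   # swap the group axis and the per-group channel axis
x = x.reshape(b, c, h, w)                      # flatten back to (1, 6, 1, 1)

print(x.view(-1))                              # tensor([0., 3., 1., 4., 2., 5.])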

The authors conducted experiments on the classification dataset ImageNet-1K and the object detection dataset MS COCO, as well as instance segmentation tasks, showing that the performance of SA surpasses the current SOTA methods, achieving higher accuracy with lower model complexity.

11.4. Core Code

import torch
from torch import nn
from torch.nn.parameter import Parameter


class ShuffleAttention(nn.Module):

    def __init__(self, channel=512,reduction=16,G=8):
        super().__init__()
        self.G=G
        self.channel=channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid=nn.Sigmoid()


    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)

        # flatten
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        b, c, h, w = x.size()
        #group into subfeatures
        x=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w

        #channel_split
        x_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w

        #channel attention
        x_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1
        x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1
        x_channel=x_0*self.sigmoid(x_channel)

        #spatial attention
        x_spatial=self.gn(x_1) #bs*G,c//(2*G),h,w
        x_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,w
        x_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w

        # concatenate along channel axis
        out=torch.cat([x_channel,x_spatial],dim=1)  #bs*G,c//G,h,w
        out=out.contiguous().view(b,-1,h,w)

        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out

11.5. Usage

from attention.ShuffleAttention import ShuffleAttention
import torch
from torch import nn
from torch.nn import functional as F


input=torch.randn(50,512,7,7)
se = ShuffleAttention(channel=512,G=8)
output=se(input)
print(output.shape)

12. MUSE Attention

12.1. Citation

MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning[3]

Paper Address: https://arxiv.org/abs/1911.09483

12.2. Model Structure

[Figure: overall structure of the MUSE attention module]

12.3. Introduction

This is a 2019 arXiv paper from a Peking University team. It mainly addresses the drawback that Self-Attention (SA) only has global capture capability: as shown in the figure below, when the sentence length increases, SA's global capture capability weakens and the final model performance declines. The authors therefore introduce 1D convolutions with several different receptive fields to capture multi-scale local attention, compensating for SA's weakness in modeling long sentences.

[Figure: model performance versus sentence length, showing Self-Attention degrading on long sentences]

As the model structure shows, the implementation adds the outputs of SA and of several convolutions, giving the model both global and local perception (the motivation is quite similar to recent works such as VOLO[4] and CoAtNet[5]). By introducing multi-scale local perception, the model's performance on translation tasks is ultimately improved. A minimal sketch of this parallel fusion is given below.
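As a rough illustration of the parallel fusion, the following sketch combines a self-attention branch with three softmax-weighted 1D convolution branches over the same sequence. It uses nn.MultiheadAttention as a stand-in for the SA branch and made-up sizes, so it only sketches the idea rather than reproducing the paper's model (the faithful reproduction is the core code below).

import torch
from torch import nn

bs, n, dim = 2, 49, 64
v = torch.randn(bs, n, dim)

attn_branch = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)             # global branch (stand-in for SA)
convs = nn.ModuleList([nn.Conv1d(dim, dim, k, padding=k // 2) for k in (1, 3, 5)])  # multi-scale local branches
weights = torch.softmax(torch.ones(3), dim=0)    # learnable branch weights in the real model

global_out, _ = attn_branch(v, v, v)             # (bs, n, dim)
local_out = sum(w * conv(v.transpose(1, 2)) for w, conv in zip(weights, convs))
local_out = local_out.transpose(1, 2)            # back to (bs, n, dim)

out = global_out + local_out                     # global + multi-scale local perception
print(out.shape)                                 # torch.Size([2, 49, 64])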

12.4. Core Code

import numpy as np
import torch
from torch import nn


class Depth_Pointwise_Conv1d(nn.Module):
    def __init__(self,in_ch,out_ch,k):
        super().__init__()
        if(k==1):
            self.depth_conv=nn.Identity()
        else:
            self.depth_conv=nn.Conv1d(
                in_channels=in_ch,
                out_channels=in_ch,
                kernel_size=k,
                groups=in_ch,
                padding=k//2
                )
        self.pointwise_conv=nn.Conv1d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=1,
            groups=1
        )
    def forward(self,x):
        out=self.pointwise_conv(self.depth_conv(x))
        return out
    

class MUSEAttention(nn.Module):

    def __init__(self, d_model, d_k, d_v, h,dropout=.1):


        super(MUSEAttention, self).__init__()
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout=nn.Dropout(dropout)

        self.conv1=Depth_Pointwise_Conv1d(h * d_v, d_model,1)
        self.conv3=Depth_Pointwise_Conv1d(h * d_v, d_model,3)
        self.conv5=Depth_Pointwise_Conv1d(h * d_v, d_model,5)
        self.dy_paras=nn.Parameter(torch.ones(3))
        self.softmax=nn.Softmax(-1)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h


    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):

        #Self Attention
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        att=self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)

        #multi-scale local attention over the values
        v2=v.permute(0,1,3,2).contiguous().view(b_s,-1,nk) #bs,h*d_v,nk
        branch_weights=self.softmax(self.dy_paras) # softmax over the three branch weights (local tensor; do not re-wrap as a Parameter in forward)
        out2=branch_weights[0]*self.conv1(v2)+branch_weights[1]*self.conv3(v2)+branch_weights[2]*self.conv5(v2)
        out2=out2.permute(0,2,1) #bs,nk,d_model

        out=out+out2
        return out

12.5. Usage

from attention.MUSEAttention import MUSEAttention
import torch
from torch import nn
from torch.nn import functional as F


input=torch.randn(50,49,512)
sa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)
output=sa(input,input,input)
print(output.shape)

13. SGE Attention

13.1. Citation

Spatial Group-wise Enhance: Improving Semantic Feature Learning in Convolutional Networks[6]

Paper Address: https://arxiv.org/pdf/1905.09646.pdf

13.2. Model Structure

[Figure: structure of the Spatial Group-wise Enhance (SGE) module]

13.3. Introduction

This paper is a lightweight attention work released on arXiv in 2019 by the authors of SKNet[7]. As the core code below shows, very few parameters are introduced: self.weight and self.bias are on the order of the number of groups, i.e., almost constant.

The core idea of this paper is to use the similarity of local and global information to guide the enhancement of semantic features. The overall operation can be divided into the following steps:

1) Group the features along the channel dimension, and within each group take the dot product between the feature at every position and the group's globally pooled feature to obtain an initial attention mask (a similarity map).

2) Normalize the attention mask by subtracting its mean and dividing by its standard deviation, while learning a scale and an offset parameter for each group so that the normalization operation remains reversible.

3) Finally, apply a sigmoid to obtain the final attention mask and use it to rescale the features at every position in the original feature group. A minimal sketch of these three steps follows below.
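The following sketch walks through the three steps for a single group; the tensor shape and the unit scale/zero offset are made up for illustration, whereas the real module learns a scale and an offset per group.

import torch

g_feat = torch.randn(1, 64, 7, 7)                     # features of one group
global_vec = g_feat.mean(dim=(2, 3), keepdim=True)    # global average pooling
sim = (g_feat * global_vec).sum(dim=1, keepdim=True)  # step 1: dot-product similarity map, (1, 1, 7, 7)

sim = (sim - sim.mean()) / (sim.std() + 1e-5)         # step 2: normalize the mask
weight, bias = torch.ones(1), torch.zeros(1)          # learnable per-group scale and offset
mask = torch.sigmoid(sim * weight + bias)             # step 3: final spatial attention mask

enhanced = g_feat * mask                              # rescale every position of the group
print(enhanced.shape)                                 # torch.Size([1, 64, 7, 7])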

In the experimental section, the authors also conducted experiments on classification tasks (ImageNet) and detection tasks (COCO), achieving better performance with fewer parameters and computational load compared to networks like SK[7], CBAM[8], and BAM[9], demonstrating the efficiency of the proposed method.

13.4. Core Code


import torch
from torch import nn


class SpatialGroupEnhance(nn.Module):

    def __init__(self, groups):
        super().__init__()
        self.groups=groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.weight=nn.Parameter(torch.zeros(1,groups,1,1))
        self.bias=nn.Parameter(torch.zeros(1,groups,1,1))
        self.sig=nn.Sigmoid()


    def forward(self, x):
        b, c, h,w=x.shape
        x=x.view(b*self.groups,-1,h,w) #bs*g,dim//g,h,w
        xn=x*self.avg_pool(x) #bs*g,dim//g,h,w
        xn=xn.sum(dim=1,keepdim=True) #bs*g,1,h,w
        t=xn.view(b*self.groups,-1) #bs*g,h*w

        t=t-t.mean(dim=1,keepdim=True) #bs*g,h*w
        std=t.std(dim=1,keepdim=True)+1e-5
        t=t/std #bs*g,h*w
        t=t.view(b,self.groups,h,w) #bs,g,h,w

        t=t*self.weight+self.bias #bs,g,h,w
        t=t.view(b*self.groups,1,h,w) #bs*g,1,h,w
        x=x*self.sig(t)
        x=x.view(b,c,h,w)

        return x 

13.5. Usage

from attention.SGE import SpatialGroupEnhance
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(50,512,7,7)
sge = SpatialGroupEnhance(groups=8)
output=sge(input)
print(output.shape)

14. A2 Attention

14.1. Citation

A2-Nets: Double Attention Networks[10]

Paper Address: https://arxiv.org/pdf/1810.11579.pdf

14.2. Model Structure

[Figure: structure of the Double Attention (A2) block]

14.3. Introduction

This is a paper presented at NeurIPS 2018, which mainly focuses on spatial attention. The method in this paper is quite similar to self-attention, but the packaging is more elaborate.

The input is transformed into A, B, and V with 1×1 convolutions (analogous to the Q, K, and V of self-attention). The method then proceeds in two steps. In the first step (feature gathering), A and B are multiplied to obtain a set of global descriptors G that aggregate global information. In the second step (feature distribution), G is multiplied with V to distribute the descriptors back to every position, yielding a second-order attention. (Personally, I find this somewhat similar to Attention on Attention (AoA)[11] from ICCV 2019.) A shape walk-through of the two steps is sketched below.
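The following shape walk-through traces the two steps with made-up sizes (c_m = c_n = 4 on a 7×7 map, so h·w = 49); it only follows the tensor shapes, not the full module.

import torch
import torch.nn.functional as F

b, c_m, c_n, hw = 1, 4, 4, 49
A = torch.randn(b, c_m, hw)
B = F.softmax(torch.randn(b, c_n, hw), dim=-1)   # attention over the h*w spatial positions
V = F.softmax(torch.randn(b, c_n, hw), dim=1)    # attention over the c_n descriptors

G = torch.bmm(A, B.transpose(1, 2))              # step 1: feature gathering    -> (b, c_m, c_n)
Z = torch.bmm(G, V)                              # step 2: feature distribution -> (b, c_m, hw)
print(G.shape, Z.shape)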

According to the experimental results, this structure performs quite well, with the authors achieving excellent results in classification (ImageNet) and action recognition (Kinetics, UCF-101) tasks, showing significant improvements compared to models like Non-Local[12] and SENet[13].

14.4. Core Code


import torch
from torch import nn
from torch.nn import functional as F


class DoubleAttention(nn.Module):

    def __init__(self, in_channels,c_m,c_n,reconstruct = True):
        super().__init__()
        self.in_channels=in_channels
        self.reconstruct = reconstruct
        self.c_m=c_m
        self.c_n=c_n
        self.convA=nn.Conv2d(in_channels,c_m,1)
        self.convB=nn.Conv2d(in_channels,c_n,1)
        self.convV=nn.Conv2d(in_channels,c_n,1)
        if self.reconstruct:
            self.conv_reconstruct = nn.Conv2d(c_m, in_channels, kernel_size = 1)


    def forward(self, x):
        b, c, h,w=x.shape
        assert c==self.in_channels
        A=self.convA(x) #b,c_m,h,w
        B=self.convB(x) #b,c_n,h,w
        V=self.convV(x) #b,c_n,h,w
        tmpA=A.view(b,self.c_m,-1)
        attention_maps=F.softmax(B.view(b,self.c_n,-1),dim=-1) # softmax over the h*w spatial positions
        attention_vectors=F.softmax(V.view(b,self.c_n,-1),dim=1) # softmax over the c_n descriptors
        # step 1: feature gathering
        global_descriptors=torch.bmm(tmpA,attention_maps.permute(0,2,1)) #b,c_m,c_n
        # step 2: feature distribution
        tmpZ = global_descriptors.matmul(attention_vectors) #b,c_m,h*w
        tmpZ=tmpZ.view(b,self.c_m,h,w) #b,c_m,h,w
        if self.reconstruct:
            tmpZ=self.conv_reconstruct(tmpZ)

        return tmpZ 

14.5. Usage

from attention.A2Atttention import DoubleAttention
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(50,512,7,7)
a2 = DoubleAttention(512,128,128,True)
output=a2(input)
print(output.shape)

15. AFT Attention

15.1. Citation

An Attention Free Transformer[14]

Paper Address: https://arxiv.org/pdf/2105.14103v1.pdf

15.2. Model Structure

[Figure: structure of the Attention Free Transformer (AFT)]

15.3. Introduction

This is a work released by an Apple team on arXiv in 2021, which mainly focuses on simplifying Self-Attention.

In recent years, Transformers have been used in a wide variety of tasks, but because the time and space complexity of Self-Attention grows quadratically with the input length, it cannot be applied to very long inputs. Many works have been proposed to reduce the complexity of SA: sparse attention, locality-sensitive hashing, low-rank decomposition, and so on.
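For a rough sense of scale: with 1,024 tokens the attention map already has 1,024 × 1,024 ≈ 1.05 million entries per head, and doubling the length to 2,048 tokens quadruples that to roughly 4.2 million entries.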

This paper proposes the Attention Free Transformer (AFT), which also has Q, K, and V parts, but unlike traditional methods it does not take the dot product of Q and K. Instead, K and V are first fused (together with learned position biases) so that every position interacts with every other position, and Q is then multiplied element-wise with the fused features at the corresponding positions, which reduces the computational cost.

Overall, the principle is similar to Self-Attention, but dot products are replaced with element-wise multiplications, which significantly reduces the computational load. A compact sketch of the operation is given below.
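The following compact sketch (for a single batch element, with made-up sizes and zero position biases) mirrors the AFT-full formula: the keys plus position biases are exponentiated and used to take a weighted average of the values, and the result is gated element-wise by sigmoid(Q).

import torch

n, d = 49, 64
Q, K, V = torch.randn(3, n, d).unbind(0)
w = torch.zeros(n, n)                                   # pairwise position biases (learned in AFT-full)

weights = torch.exp(K.unsqueeze(0) + w.unsqueeze(-1))   # (n_query, n_key, d)
num = (weights * V.unsqueeze(0)).sum(dim=1)             # weighted sum of the values per query position
den = weights.sum(dim=1)
out = torch.sigmoid(Q) * (num / den)                    # element-wise gating by the query
print(out.shape)                                        # torch.Size([49, 64])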

15.4. Core Code


import torch
from torch import nn


class AFT_FULL(nn.Module):

    def __init__(self, d_model,n=49,simple=False):

        super(AFT_FULL, self).__init__()
        self.fc_q = nn.Linear(d_model, d_model)
        self.fc_k = nn.Linear(d_model, d_model)
        self.fc_v = nn.Linear(d_model,d_model)
        if(simple):
            # AFT-simple: no learned pairwise biases; register as a buffer so it follows the module's device
            self.register_buffer('position_biases', torch.zeros((n, n)))
        else:
            self.position_biases=nn.Parameter(torch.ones((n,n)))
        self.d_model = d_model
        self.n=n
        self.sigmoid=nn.Sigmoid()

    def forward(self, input):

        bs, n,dim = input.shape

        q = self.fc_q(input) #bs,n,dim
        k = self.fc_k(input).view(1,bs,n,dim) #1,bs,n,dim
        v = self.fc_v(input).view(1,bs,n,dim) #1,bs,n,dim
        
        numerator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1))*v,dim=2) #n,bs,dim
        denominator=torch.sum(torch.exp(k+self.position_biases.view(n,1,-1,1)),dim=2) #n,bs,dim

        out=(numerator/denominator) #n,bs,dim
        out=self.sigmoid(q)*(out.permute(1,0,2)) #bs,n,dim

        return out

15.5. Usage

from attention.AFT import AFT_FULL
import torch
from torch import nn
from torch.nn import functional as F

input=torch.randn(50,49,512)
aft_full = AFT_FULL(d_model=512, n=49)
output=aft_full(input)
print(output.shape)

[Final Note]

Currently, the Attention works organized by this project are indeed not comprehensive enough. As the readership increases, this project will be continuously improved. Everyone is welcome to star and support. If there are any inappropriate expressions in the article or errors in the code implementation, please feel free to point them out~

[References]

[1]. Zhang, Qing-Long, and Yu-Bin Yang. “SA-NET: Shuffle Attention for Deep Convolutional Neural Networks.” ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2021.

[2]. Zhang, Xiangyu, et al. “ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices.” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.

[3]. Zhao, Guangxiang, et al. “MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning.” arXiv Preprint arXiv:1911.09483 (2019).

[4]. Yuan, Li, et al. “VOLO: Vision Outlooker for Visual Recognition.” arXiv Preprint arXiv:2106.13112 (2021).

[5]. Dai, Zihang, et al. “CoAtNet: Marrying Convolution and Attention for All Data Sizes.” arXiv Preprint arXiv:2106.04803 (2021).

[6]. Li, Xiang, Xiaolin Hu, and Jian Yang. “Spatial Group-Wise Enhance: Improving Semantic Feature Learning in Convolutional Networks.” arXiv Preprint arXiv:1905.09646 (2019).

[7]. Li, Xiang, et al. “Selective Kernel Networks.” Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.

[8]. Woo, Sanghyun, et al. “CBAM: Convolutional Block Attention Module.” Proceedings of the European Conference on Computer Vision (ECCV). 2018.

[9]. Park, Jongchan, et al. “BAM: Bottleneck Attention Module.” arXiv Preprint arXiv:1807.06514 (2018).

[10]. Chen, Yunpeng, et al. “A2-Nets: Double Attention Networks.” Advances in Neural Information Processing Systems. 2018.

[11]. Huang, Lun, et al. “Attention on Attention for Image Captioning.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019.

[12]. Wang, Xiaolong, et al. “Non-Local Neural Networks.” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.

[13]. Hu, Jie, Li Shen, and Gang Sun. “Squeeze-and-Excitation Networks.” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.

[14]. Zhai, Shuangfei, et al. “An Attention Free Transformer.” arXiv Preprint arXiv:2105.14103 (2021).

If you have any questions about the article, please feel free to leave a comment or add the author’s WeChat: xmu_xiaoma

