Practical Implementation of PyTorch FlexAttention: Causal Attention and Variable-Length Sequence Processing Based on BlockMask

Source: DeepHub IMBA

This article is about 2,000 words and takes roughly 5 minutes to read.
It introduces how to use the FlexAttention and BlockMask features introduced in PyTorch 2.5 and later to implement a causal attention mechanism and handle padded inputs.

Since complete code examples and technical discussions of handling padded input sequences with FlexAttention are still scarce online, this article walks through one implementation approach, which also covers the causal attention mechanism.

This article will not discuss the theoretical foundations of FlexAttention in detail; for more technical details, please refer to the official PyTorch blog.

Environment Configuration

    git clone https://github.com/pytorch-labs/attention-gym.git
    cd attention-gym
    pip install .
    cd ../

We install via the attention-gym repository to ensure compatibility between components while gaining access to its visualization tools.
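For reference, the code in the rest of this article assumes imports along the following lines. This is a minimal sketch; in particular, the import path of the attention-gym visualization helper should be checked against the installed version:

    from pathlib import Path

    import torch
    import torch.nn as nn
    from torch.nn.attention.flex_attention import (
        flex_attention,
        create_block_mask,
        and_masks,
    )

    # Visualization helper from the attention-gym repository (import path assumed)
    from attn_gym import visualize_attention_scores

    # Assumption: run on GPU when available; flex_attention is primarily optimized for CUDA
    device = "cuda" if torch.cuda.is_available() else "cpu"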

MultiheadFlexAttention Implementation

To use flex_attention effectively within a transformer architecture, it needs to be wrapped in a multi-head attention module.

    class MultiheadFlexAttention(nn.Module):
        def __init__(self, d_in, d_out, n_heads, bias=False):
            """
            Description: PyTorch module implementing multi-head self-attention based on flex_attention
            Parameters:
                d_in: int, input tensor dimension
                d_out: int, output tensor dimension
                n_heads: int, number of attention heads
                bias: bool, whether to use bias in the query, key, and value projections
            """
            super().__init__()
            assert d_out % n_heads == 0, "d_out must be divisible by n_heads"

            self.n_heads = n_heads
            self.d_head = d_out // n_heads
            self.d_out = d_out

            self.in_proj = nn.Linear(d_in, 3 * d_out, bias=bias)
            self.out_proj = nn.Linear(d_out, d_out)

This defines the core parameters of the model, including input and output dimensions and linear transformation layers.

        def forward(self, x, block_mask):
            """
            Description: Forward pass of the multi-head self-attention module
            Parameters:
                x: torch.Tensor, input tensor of shape (batch_size, max_seq_len, d_in)
                block_mask: BlockMask used by flex_attention
            """
            batch_size, max_seq_len, d_in = x.shape

            # Generate query, key, and value representations through a single linear projection
            qkv = self.in_proj(x)

            # Split and reshape qkv into multi-head format
            qkv = qkv.view(batch_size, max_seq_len, 3, self.n_heads, self.d_head)

            # Permute dimensions to match flex_attention's input layout
            qkv = qkv.permute(2, 0, 3, 1, 4)

            # Unpack into query, key, and value tensors
            queries, keys, values = qkv

            # Compute attention using flex_attention
            attn = flex_attention(queries, keys, values, block_mask=block_mask)

            # Merge the multi-head attention outputs
            attn = attn.transpose(1, 2).contiguous().view(batch_size, max_seq_len, self.d_out)

            # Apply the output projection
            attn = self.out_proj(attn)

            return attn, queries, keys

The implementation of this forward function is similar to the standard PyTorch MultiheadAttention class, with the main difference being the introduction of the block_mask parameter and the use of the flex_attention function for attention calculation.

mask_mod Function Implementation

The core advantage of FlexAttention lies in its ability to efficiently implement and use custom attention masks without writing dedicated CUDA kernels.

To use this feature, the mask is defined as a mask_mod function that returns boolean values. First, we implement a causal mask, the basic example provided by the FlexAttention developers in their official blog.

Causal Mask

    def causal(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

Parameter explanations:

  • b: batch index

  • h: attention head index

  • q_idx: query position index

  • kv_idx: key/value position index

For example, for an input sequence length of 5, q_idx is represented as torch.Tensor([0,1,2,3,4]).

q_idx >= kv_idx returns a causal boolean mask, ensuring that attention calculations only consider the current position and previous tokens.
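As a quick sanity check (an illustration added here, not from the original article), the same predicate can be evaluated densely for a length-5 sequence via broadcasting, which reveals the familiar lower-triangular pattern:

    q_idx = torch.arange(5).unsqueeze(1)   # query positions, shape (5, 1)
    kv_idx = torch.arange(5).unsqueeze(0)  # key/value positions, shape (1, 5)
    print(q_idx >= kv_idx)
    # tensor([[ True, False, False, False, False],
    #         [ True,  True, False, False, False],
    #         [ True,  True,  True, False, False],
    #         [ True,  True,  True,  True, False],
    #         [ True,  True,  True,  True,  True]])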

Next, we will implement a padding mask to handle the padding part of variable-length sequences.

Padding Mask Implementation

The main difference between the padding mask and the causal mask is that it is batch-dependent: the mask values depend on where the padding tokens sit in each individual sequence. To implement it, we need a padding token table that records which positions in each sequence should be ignored.

    def create_padding_mask(pads):
        def padding(b, h, q_idx, kv_idx):
            return ~pads[b, q_idx] & ~pads[b, kv_idx]
        return padding

pads is a boolean tensor with the shape (batch_size, max_seq_len), where padding positions are marked as True and valid token positions are marked as False. This padding mask_mod function generates a padding mask that only allows attention calculations when both query and key/value positions are non-padding tokens.
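To illustrate (a toy example added here, not from the original article), consider a batch with a single length-5 sequence whose last two positions are padding; evaluating the returned mask_mod densely shows that every row and column touching a padded position is masked out:

    pads_example = torch.tensor([[False, False, False, True, True]])  # (batch_size=1, max_seq_len=5)
    padding_example = create_padding_mask(pads_example)

    q_idx = torch.arange(5).unsqueeze(1)
    kv_idx = torch.arange(5).unsqueeze(0)
    print(padding_example(0, 0, q_idx, kv_idx))
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True,  True, False, False],
    #         [ True,  True,  True, False, False],
    #         [False, False, False, False, False],
    #         [False, False, False, False, False]])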

Experimental Setup and Data Preparation

Before combining masks and applying them to MultiheadFlexAttention, relevant parameters need to be set and experimental data prepared.

    # Multi-head attention parameter configuration
    d_in = 64
    d_out = 64
    n_heads = 8

    # Initialize multi-head attention module
    mhfa = MultiheadFlexAttention(d_in, d_out, n_heads).to(device)

    # Data dimension settings
    batch_size = 1  # any batch size is supported
    max_seq_len = 10

    # Generate random input data
    input_data = torch.randn(batch_size, max_seq_len, d_in).to(device)

Next, modify input_data to add random trailing zero padding.

    # Add random zero padding
    pad = torch.zeros(1, d_in).to(device)
    pad_idxs = [(b, range(torch.randint(max_seq_len // 2, max_seq_len + 1, (1,)).item(), max_seq_len)) for b in range(batch_size)]
    for b, idxs in pad_idxs:
        input_data[b, idxs] = pad

Now, we need to construct the padding token table for the padding mask_mod function.

    # Build padding token mask
    collapsed_input = input_data[:, :, 0]  # (batch_size, max_seq_len)
    pads = torch.eq(collapsed_input, 0).to(device)

Note that the mask_mod function does not need to consider the embedding dimension of input_data, so the dimension can be compressed when creating the padding token table (pads).
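If there is any concern that a real token's first feature could be exactly zero, a slightly stricter variant (an alternative suggested here, not part of the original code) marks a position as padding only when its entire embedding vector is zero:

    # Stricter padding detection: require the whole embedding to be zero, not just its first component
    pads = torch.eq(input_data, 0).all(dim=-1).to(device)  # (batch_size, max_seq_len)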

Combining Causal Mask and Padding Mask

At this point, we have all the components needed to create a comprehensive attention mask.

    # Build combined mask
    causal_mask = causal
    padding_mask = create_padding_mask(pads)
    masks = [causal, padding_mask]
    combined_mask = and_masks(*masks)
    causal_padding_mask = create_block_mask(
        combined_mask, B=batch_size, H=None, Q_LEN=max_seq_len, KV_LEN=max_seq_len, _compile=True
    )

Here, we combine the causal and padding mask_mod functions using the and_masks helper from torch.nn.attention.flex_attention to generate a single BlockMask.

Note: the FlexAttention developers recommend enabling the _compile argument, which significantly speeds up BlockMask generation; this is especially important for batch-dependent masks.
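Optionally, the generated BlockMask can be inspected before running attention. The sketch below assumes the BlockMask utilities shipped with PyTorch 2.5+, where printing the object summarizes its block-sparse layout and sparsity() reports the fraction of blocks that can be skipped:

    # Inspect the block-sparse structure of the combined mask (optional)
    print(causal_padding_mask)             # repr summarizing the block layout
    print(causal_padding_mask.sparsity())  # percentage of blocks that are fully masked out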

Now we can use the MultiheadFlexAttention class to perform attention calculations on input_data while applying the compiled custom attention mask.

    # Perform forward computation
    attn_output, query, key = mhfa(input_data, causal_padding_mask)

Use the visualization tools provided by attention-gym to analyze the attention distribution.

    # Visualize attention distribution for the first sequence
    visualize_attention_scores(
        query,
        key,
        mask_mod=combined_mask,
        device=device,
        name="causal_padding_mask",
        path=Path("./causal_padding_mask.png"),
    )

[Figure: attention score visualization generated by visualize_attention_scores, saved as causal_padding_mask.png]

The above image shows the causal attention distribution after masking for a sequence containing three padding tokens.

From the visualization results, it can be observed that both the attention weights for padding tokens and future tokens are effectively masked, validating the correctness of the implementation.

Editor: Yu Tengkai
Proofreader: Liang Jincheng
