New PyTorch API: Implementing Different Attention Variants with Just a Few Lines of Code!

Reprinted from: Machine Heart | Edited by: Chen Chen

Try a new attention pattern with FlexAttention.

In theory, attention is all you need. In practice, however, we also need optimized implementations of the attention mechanism, such as FlashAttention.
Although these fused attention mechanisms greatly improve performance and support long contexts, this efficiency gain comes with a loss of flexibility. For machine learning researchers, this is like a “software lottery”—if your attention variant is not compatible with existing optimized kernels, you will face slow performance and CUDA out-of-memory issues.
Attention variants include causal attention, relative position embeddings, ALiBi, sliding window attention, PrefixLM, document masking, jagged tensors, PagedAttention, and more. Worse still, people often want to combine these variants, for example sliding window attention + document masking + causal attention + context parallelism, or PagedAttention + sliding window.
Today, some combinations of masks, biases, and settings already have kernel implementations. However, every added option multiplies the number of configurations, so the space grows exponentially. Even worse, this approach cannot support new attention variants at all.
To solve this combinatorial explosion (the "hypercube" of configurations) once and for all, the PyTorch team introduced FlexAttention, a new PyTorch API.
  1. FlexAttention is a flexible API that lets users implement many attention variants in a few lines of idiomatic PyTorch code.
  2. Through torch.compile, it is lowered into a fused FlashAttention kernel that materializes no extra memory and performs comparably to handwritten kernels.
  3. The backward pass is generated automatically using PyTorch's automatic differentiation.
  4. Finally, it can exploit sparsity in the attention mask, giving significant improvements over standard attention implementations.

Tri Dao, an author of FlashAttention versions 1 through 3, retweeted the work and commented that it brings many techniques together.
FlexAttention
The classic attention equation is as follows:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V$$
In code form:
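As a minimal sketch, the classic equation can be written in a few lines of PyTorch. The function name `attention` and the tensor sizes are chosen here for illustration and do not come from the original post; the layout `[batch, heads, seq_len, head_dim]` is assumed.

```python
import math
import torch

def attention(q, k, v):
    """Plain softmax attention; q, k, v are [batch, heads, seq_len, head_dim]."""
    # Scaled dot-product scores, shape [batch, heads, seq_len, seq_len].
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    # Normalize over the key/value axis, then take the weighted sum of values.
    probs = torch.softmax(scores, dim=-1)
    return probs @ v

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
out = attention(q, k, v)
print(out.shape)  # same shape as q
```

Note that this version materializes the full score matrix, which is exactly what fused kernels such as FlashAttention avoid.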
FlexAttention takes the following form, resolving the aforementioned issues by accepting a user-defined function score_mod.
$$\mathrm{FlexAttention}(Q, K, V) = \mathrm{softmax}\left(\mathrm{score\_mod}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)\right) V$$
In code form:
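The following is a plain PyTorch reference sketch of this form, not the fused kernel that FlexAttention generates. The helper `attention_with_score_mod` and the example `halve_scores` are hypothetical names introduced for illustration; the sketch assumes that `score_mod` is written with elementwise operations so it can be called once on broadcastable index tensors.

```python
import math
import torch

def attention_with_score_mod(q, k, v, score_mod):
    """Reference attention that applies score_mod to the scaled scores before softmax."""
    B, H, S_q, D = q.shape
    S_kv = k.size(-2)
    scores = q @ k.transpose(-2, -1) / math.sqrt(D)  # [B, H, S_q, S_kv]
    # Index tensors shaped so that they broadcast against scores.
    b = torch.arange(B, device=q.device).view(B, 1, 1, 1)
    h = torch.arange(H, device=q.device).view(1, H, 1, 1)
    q_idx = torch.arange(S_q, device=q.device).view(1, 1, S_q, 1)
    kv_idx = torch.arange(S_kv, device=q.device).view(1, 1, 1, S_kv)
    scores = score_mod(scores, b, h, q_idx, kv_idx)
    return torch.softmax(scores, dim=-1) @ v

def halve_scores(score, b, h, q_idx, kv_idx):
    # Hypothetical score_mod: equivalent to a softmax temperature of 2.
    return score / 2.0

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
out = attention_with_score_mod(q, k, v, halve_scores)
print(out.shape)
```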
This function allows users to modify attention scores before softmax. Researchers found that this function is sufficient to meet the needs of most users regarding attention variants.
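Specifically, score_mod has the signature score_mod(score, b, h, q_idx, kv_idx): it receives a single attention score together with its batch index b, head index h, query position q_idx, and key/value position kv_idx, and returns the modified score.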
Conceptually, applying this function is equivalent to the following loop:
for b in range(batch_size):
    for h in range(num_heads):
        for q_idx in range(sequence_length):
            for kv_idx in range(sequence_length):
                modified_scores[b, h, q_idx, kv_idx] = score_mod(scores[b, h, q_idx, kv_idx], b, h, q_idx, kv_idx)
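To make the loop concrete, the sketch below runs it literally and compares the result with a single broadcasted call. Both helpers and the example `head_scaled_distance` are hypothetical names for illustration; how FlexAttention itself vectorizes `score_mod` inside the kernel is not shown here.

```python
import torch

def apply_score_mod_loop(scores, score_mod):
    """Literal version of the four nested loops."""
    B, H, S_q, S_kv = scores.shape
    out = torch.empty_like(scores)
    for b in range(B):
        for h in range(H):
            for q_idx in range(S_q):
                for kv_idx in range(S_kv):
                    out[b, h, q_idx, kv_idx] = score_mod(
                        scores[b, h, q_idx, kv_idx], b, h, q_idx, kv_idx)
    return out

def apply_score_mod_broadcast(scores, score_mod):
    """Same result in one call, using index tensors that broadcast."""
    B, H, S_q, S_kv = scores.shape
    b = torch.arange(B).view(B, 1, 1, 1)
    h = torch.arange(H).view(1, H, 1, 1)
    q_idx = torch.arange(S_q).view(1, 1, S_q, 1)
    kv_idx = torch.arange(S_kv).view(1, 1, 1, S_kv)
    return score_mod(scores, b, h, q_idx, kv_idx)

def head_scaled_distance(score, b, h, q_idx, kv_idx):
    # Hypothetical score_mod: penalize distance, more strongly for later heads.
    return score - (h + 1) * abs(q_idx - kv_idx)

torch.manual_seed(0)
scores = torch.randn(2, 2, 3, 3)
looped = apply_score_mod_loop(scores, head_scaled_distance)
vectorized = apply_score_mod_broadcast(scores, head_scaled_distance)
print(torch.allclose(looped, vectorized))
```

The broadcasted form only works because the example uses elementwise operations; a `score_mod` with Python control flow on its arguments would need the loop form in this reference.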
The resulting API is surprisingly expressive.
Score Mod Examples
Full Attention
In this case, score_mod does nothing; it takes scores as input and returns them as is.
End-to-end usage then looks like this:
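A hedged sketch of end-to-end usage with a no-op `score_mod` follows. It assumes that `flex_attention` is importable from `torch.nn.attention.flex_attention`, which to my knowledge holds for recent PyTorch releases. The import guard, the CUDA check, and the tensor sizes are choices made for this example so that it still runs, with a plain reference only, where FlexAttention is unavailable.

```python
import math
import torch

try:
    # Assumed import path; present in recent PyTorch releases.
    from torch.nn.attention.flex_attention import flex_attention
except ImportError:
    flex_attention = None

def noop(score, b, h, q_idx, kv_idx):
    # Full (bidirectional) attention: return the score unchanged.
    return score

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
q, k, v = (torch.randn(1, 2, 128, 64, device=device, requires_grad=True)
           for _ in range(3))

# Plain reference: with a no-op score_mod this is ordinary attention.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
ref = torch.softmax(scores, dim=-1) @ v

if flex_attention is not None and device == "cuda":
    out = flex_attention(q, k, v, score_mod=noop)
    out.sum().backward()  # the backward pass comes from autograd
    print("max abs difference vs reference:", (out - ref).abs().max().item())
else:
    print("FlexAttention path skipped; reference output shape:", tuple(ref.shape))
```

For performance the call is normally wrapped in `torch.compile`; the uncompiled call above is only meant to show the calling convention.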
Relative Position Encoding
A common attention variant is relative position encoding. Instead of encoding absolute positions in the queries and keys, it adjusts each score based on the distance between the query and the key.
Note that, unlike typical implementations, this does not require materializing an S×S bias tensor. Instead, FlexAttention computes the bias values on the fly inside the kernel, which significantly reduces memory use and improves performance.
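One plausible way to write such a `score_mod` is sketched below, using the signed offset `q_idx - kv_idx` as the bias (an assumption for illustration; real models typically use a learned or scaled function of the distance). The materialized `bias` matrix is built only to check the result and is exactly the S×S tensor that the `score_mod` formulation avoids storing.

```python
import torch

def relative_positional(score, b, h, q_idx, kv_idx):
    # Bias each score by the signed offset between query and key positions.
    return score + (q_idx - kv_idx)

S = 6
scores = torch.zeros(1, 1, S, S)
q_idx = torch.arange(S).view(1, 1, S, 1)
kv_idx = torch.arange(S).view(1, 1, 1, S)
modified = relative_positional(scores, 0, 0, q_idx, kv_idx)

# The "typical" approach: build the full S x S bias tensor up front.
bias = (torch.arange(S)[:, None] - torch.arange(S)[None, :]).float()
print(torch.equal(modified[0, 0], bias))
```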
Soft-capping
Soft-capping is a technique used in Gemma 2 and Grok-1 to keep attention logits from growing too large. In FlexAttention, it takes the following form:
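A minimal sketch of tanh soft-capping as a `score_mod` is given below. The cap value `SOFTCAP = 20.0` is an assumption chosen for illustration, not a value stated in this article.

```python
import torch

SOFTCAP = 20.0  # assumed cap value, for illustration only

def soft_cap(score, b, h, q_idx, kv_idx):
    # Squash scores smoothly into (-SOFTCAP, SOFTCAP); small scores barely change.
    return SOFTCAP * torch.tanh(score / SOFTCAP)

scores = torch.tensor([-1000.0, -5.0, 0.0, 5.0, 1000.0])
capped = soft_cap(scores, 0, 0, 0, 0)
print(capped)
```

Scores far outside the cap saturate near ±SOFTCAP, while scores well inside it pass through almost unchanged, so the ordering of scores is preserved.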
Causal Mask
Bidirectional attention is the simplest case, but the decoder in "Attention is All You Need" and most LLMs use causal attention, in which each token can attend only to itself and the tokens before it. With the score_mod API, this can be expressed as:
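A sketch of a causal `score_mod` follows: positions where the key comes after the query are set to negative infinity, so they receive zero weight after softmax. The small all-zero score matrix is only there to make the resulting pattern easy to read.

```python
import torch

def causal(score, b, h, q_idx, kv_idx):
    # Keep the score when the key is at or before the query, otherwise mask it.
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

S = 4
scores = torch.zeros(1, 1, S, S)
q_idx = torch.arange(S).view(1, 1, S, 1)
kv_idx = torch.arange(S).view(1, 1, 1, S)
probs = torch.softmax(causal(scores, 0, 0, q_idx, kv_idx), dim=-1)
print(probs[0, 0])  # lower-triangular: row i is uniform over positions 0..i
```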
Sliding Window + Causal
Mistral (https://arxiv.org/abs/2310.06825) popularized sliding window attention (also known as local attention), in which each query token attends only to the most recent 1024 tokens. It is usually combined with causal attention.
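The combination can be sketched as a single `score_mod` that requires both conditions to hold. The window here is shrunk to 3 so the pattern is visible on a short sequence (the text above uses 1024), and counting the current token as part of the window is an assumption of this sketch.

```python
import torch

SLIDING_WINDOW = 3  # 1024 in the text above; small here for readability

def sliding_window_causal(score, b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx                       # no attention to future tokens
    in_window = q_idx - kv_idx < SLIDING_WINDOW    # only the most recent tokens
    return torch.where(causal & in_window, score, -float("inf"))

S = 8
scores = torch.zeros(1, 1, S, S)
q_idx = torch.arange(S).view(1, 1, S, 1)
kv_idx = torch.arange(S).view(1, 1, 1, S)
probs = torch.softmax(sliding_window_causal(scores, 0, 0, q_idx, kv_idx), dim=-1)
print((probs[0, 0] > 0).int())  # banded lower-triangular pattern
```

Most entries of the resulting pattern are masked out, which is the kind of sparsity the article says FlexAttention can exploit.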
The team benchmarked FlexAttention against F.scaled_dot_product_attention with a sliding window mask and against FA2 with a causal mask. The results show that FlexAttention is significantly faster not only than F.scaled_dot_product_attention but also than FA2 with a causal mask.
Performance
Overall, FlexAttention performs nearly as well as handwritten Triton kernels. Its generality does come at a small cost, however: users must tolerate some additional latency.
FlexAttention reaches 90% of FlashAttention2 performance in the forward pass and 85% in the backward pass. It currently uses a deterministic algorithm that recomputes more intermediates than FlashAttention2, and the team plans to improve FlexAttention's backward algorithm to narrow this gap.
Reference Link: https://pytorch.org/blog/flexattention/