Machine Heart Editorial Team
Lightning Attention-2 is a new type of linear attention mechanism that aligns the training and inference costs of long sequences with those of a 1K sequence length.
The limited sequence length of large language models significantly restricts their use in artificial intelligence applications such as multi-turn dialogue, long-text comprehension, and the processing and generation of multimodal data. The root cause is that the Transformer architecture adopted by current large language models has computational complexity quadratic in sequence length: as sequences grow longer, the required computational resources grow quadratically. Handling long sequences efficiently has long been one of the core challenges for large language models.
Previous methods mostly focused on letting large language models adapt to longer sequences at inference time. Examples include using ALiBi or similar relative position encodings so that the model can handle input lengths different from those seen in training, or applying interpolation to RoPE and similar relative position encodings and then fine-tuning an already trained model to extend its sequence length. These methods give large models only a limited ability to model long sequences and do not reduce the actual training and inference costs.
The OpenNLPLab team set out to solve the long-sequence problem of large language models once and for all. They proposed and open-sourced Lightning Attention-2, a new linear attention mechanism that brings the training and inference costs of long sequences in line with those of a 1K sequence length. Until GPU memory becomes the bottleneck, increasing the sequence length does not slow down model training, which makes pre-training on sequences of unlimited length possible. Likewise, the inference cost for ultra-long texts matches, or is even lower than, that for 1K tokens, which could substantially reduce the current inference costs of large language models. As shown in the figure below, for model sizes of 400M, 1B, and 3B, the training speed of LLaMA powered by FlashAttention2 drops rapidly as sequence length increases, while the speed of TransNormerLLM powered by Lightning Attention-2 remains virtually unchanged.


- Paper: Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models
- Paper URL: https://arxiv.org/pdf/2401.04658.pdf
- Open Source URL: https://github.com/OpenNLPLab/lightning-attention
Introduction to Lightning Attention-2
Keeping a large model's pre-training speed constant across different sequence lengths sounds like an impossible task. In fact, it is achievable as long as the attention mechanism's computational complexity is linear in sequence length. Since linear attention appeared in 2020, researchers have worked to make its practical efficiency match its theoretical linear complexity. Before 2023, most work on linear attention focused on matching the accuracy of Transformers, and by mid-2023, improved linear attention mechanisms finally achieved accuracy comparable to state-of-the-art Transformer architectures. However, the key trick that makes linear attention linear, known as “left multiplication becomes right multiplication” (as shown in the figure below), is in practice far slower than direct left multiplication for causal (unidirectional) models. The reason is that causal right multiplication requires a cumulative sum (cumsum) carried out with many loop iterations, and the resulting heavy I/O makes it much less efficient than left multiplication.
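To make the “left multiplication becomes right multiplication” trick concrete, here is a minimal pure-Python sketch, assuming the simplest form of linear attention with no normalization and no decay (all function names are hypothetical, not taken from the Lightning Attention code base). Without a mask, $(QK^\top)V$ and $Q(K^\top V)$ agree by associativity; with a causal mask, right multiplication becomes a running prefix sum of $k_t^\top v_t$ that must be updated token by token, which is the cumsum loop described above.

```python
import random

def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix, both lists of lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

def left_mult(q, k, v, causal):
    """O = ((Q K^T) ⊙ M) V: materializes the N x N score matrix, O(N^2 d)."""
    scores = matmul(q, transpose(k))
    if causal:  # M is lower triangular: keep column j <= row i, zero the rest
        scores = [[s if j <= i else 0.0 for j, s in enumerate(row)]
                  for i, row in enumerate(scores)]
    return matmul(scores, v)

def right_mult_bidirectional(q, k, v):
    """O = Q (K^T V): only a d x d intermediate is formed, O(N d^2)."""
    return matmul(q, matmul(transpose(k), v))

def right_mult_causal(q, k, v):
    """Causal right multiplication as a running sum (cumsum) of k_t^T v_t.
    Each step depends on the previous one, so the loop over t is sequential."""
    d, e = len(k[0]), len(v[0])
    kv = [[0.0] * e for _ in range(d)]  # prefix sum of outer products k_s^T v_s
    out = []
    for t in range(len(q)):
        for a in range(d):
            for b in range(e):
                kv[a][b] += k[t][a] * v[t][b]
        out.append([sum(q[t][a] * kv[a][b] for a in range(d)) for b in range(e)])
    return out

def close(x, y, tol=1e-9):
    """Element-wise comparison of two equally shaped matrices."""
    return all(abs(p - r) < tol for row_x, row_y in zip(x, y) for p, r in zip(row_x, row_y))

def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
N, D = 8, 4
Q, K, V = (random_matrix(N, D, rng) for _ in range(3))
print(close(left_mult(Q, K, V, causal=False), right_mult_bidirectional(Q, K, V)))  # True
print(close(left_mult(Q, K, V, causal=True), right_mult_causal(Q, K, V)))  # True
```

The sequential dependence of `kv` from one step to the next is what makes causal right multiplication hard to run efficiently, even though its operation count is linear in N.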

To better understand the ideas behind Lightning Attention-2, let's first review the formula for traditional softmax attention: $O = \mathrm{softmax}\big((QK^\top) \odot M\big)V$, where $Q$, $K$, $V$, $M$, and $O$ are the query, key, value, mask, and output matrices, respectively. In unidirectional tasks (such as GPT), $M$ is a lower-triangular matrix of ones; in bidirectional tasks (such as BERT), there is no mask, so $M$ can be ignored.
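For reference, the formula can be sketched in pure Python as follows. This is an illustrative, hypothetical helper rather than code from the paper; it applies the mask the way implementations usually do, by leaving positions $j > i$ out of each row's softmax, and it omits the $1/\sqrt{d}$ scaling to mirror the formula above. Row $i$ touches $i + 1$ keys, so the total cost is $O(N^2 d)$.

```python
import math

def softmax_attention(q, k, v, causal=True):
    """O = softmax((Q K^T) ⊙ M) V for matrices given as lists of lists.

    The causal mask is realized by excluding positions j > i from row i's
    softmax. No 1/sqrt(d) scaling, to mirror the formula in the text."""
    n = len(q)
    out = []
    for i in range(n):
        visible = range(i + 1) if causal else range(n)
        scores = [sum(a * b for a, b in zip(q[i], k[j])) for j in visible]
        m = max(scores)  # subtract the row maximum for numerical stability
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * v[j][c] for j, w in zip(visible, weights)) / z
                    for c in range(len(v[0]))])
    return out

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
print(softmax_attention(Q, K, V)[0])  # [1.0, 2.0]: token 0 can only see itself
```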
The authors summarize the overall idea of Lightning Attention-2 in the following three points:
1. One of the core ideas of linear attention is to remove the computationally expensive softmax operator, so that the attention formula can be written as $O = \big((QK^\top) \odot M\big)V$. In unidirectional tasks, however, the mask $M$ forces this expression to be evaluated by left multiplication (forming $QK^\top$ first), so it still cannot reach $O(N)$ complexity. In bidirectional tasks there is no mask, and the formula simplifies to $O = (QK^\top)V$. The brilliance of linear attention is that, simply by using the associativity of matrix multiplication, this can be rewritten as $O = Q(K^\top V)$, which is called right multiplication, while the former is called left multiplication. Left multiplication forms the $N \times N$ matrix $QK^\top$ at $O(N^2 d)$ cost, whereas right multiplication only forms the $d \times d$ matrix $K^\top V$ at $O(N d^2)$ cost. As Figure 2 illustrates, this lets linear attention reach an attractive $O(N)$ complexity in bidirectional tasks.
2. However, as decoder-only GPT-style models have become the de facto standard for LLMs, using the right-multiplication property of linear attention to accelerate unidirectional tasks has become an urgent problem. To address it, the authors take a “divide and conquer” approach: the attention matrix is split into diagonal blocks and off-diagonal blocks, and the two are computed differently. As shown in Figure 3, Lightning Attention-2 applies tiling, a common technique in computer science, dividing the Q, K, and V matrices into the same number of blocks. Within each block (intra-block), the mask still applies, so the computation keeps using left multiplication, which is quadratic in the block size; between blocks (inter-block), no mask is needed, so right multiplication can be used, with cost linear in sequence length. Once both parts are computed, they are simply summed to give the linear attention output $O_i$ of the $i$-th block. Meanwhile, the KV state is accumulated for use in the next block's computation. Because the quadratic part is confined to each block, the overall cost grows linearly with sequence length for a fixed block size, and the block size of the tiling governs the trade-off between the intra-block and inter-block parts.
3. Observant readers may notice that the process above is only the algorithmic side of Lightning Attention-2. It is named “Lightning” because the authors also paid close attention to how efficiently the algorithm runs on GPU hardware. Inspired by the FlashAttention series, the GPU implementation moves the split tensors $Q_i$, $K_i$, $V_i$ from the slower but larger HBM into the faster but smaller SRAM for computation, greatly reducing memory I/O overhead. Once a block finishes its linear attention computation, its output $O_i$ is written back to HBM, and the process repeats until all blocks have been processed.
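The tiling idea in points 1 and 2 can be sketched in pure Python as below. This is a simplified, hypothetical illustration of the algorithm only, with no normalization, no decay, and none of the HBM/SRAM data movement from point 3; the function names are made up for this example. Each block combines a masked left multiplication inside the block with a right multiplication against the accumulated KV state, and the result is checked against the one-shot masked computation.

```python
import random

def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix, both lists of lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def lower_triangular(x):
    """Apply the causal mask M: keep entries with column <= row, zero the rest."""
    return [[val if j <= i else 0.0 for j, val in enumerate(row)] for i, row in enumerate(x)]

def tiled_causal_linear_attention(q, k, v, block):
    """Causal linear attention computed block by block (hypothetical helper).

    intra-block: masked left multiplication, O(B^2 d) per block
    inter-block: right multiplication with the accumulated K^T V, O(B d^2)
    """
    d, e = len(k[0]), len(v[0])
    kv = [[0.0] * e for _ in range(d)]  # sum of K_j^T V_j over all earlier blocks
    out = []
    for s in range(0, len(q), block):
        qi, ki, vi = q[s:s + block], k[s:s + block], v[s:s + block]
        intra = matmul(lower_triangular(matmul(qi, transpose(ki))), vi)
        inter = matmul(qi, kv)
        out.extend(add(intra, inter))            # O_i = intra + inter
        kv = add(kv, matmul(transpose(ki), vi))  # state for the next block
    return out

def reference_causal_linear_attention(q, k, v):
    """O = ((Q K^T) ⊙ M) V computed in one shot, for checking."""
    return matmul(lower_triangular(matmul(q, transpose(k))), v)

rng = random.Random(42)
N, D = 10, 3
Q, K, V = ([[rng.uniform(-1, 1) for _ in range(D)] for _ in range(N)] for _ in range(3))
ref = reference_causal_linear_attention(Q, K, V)
for B in (1, 3, 4, 10):  # block sizes that do and do not divide N
    got = tiled_causal_linear_attention(Q, K, V, B)
    # prints the block size followed by True
    print(B, max(abs(x - y) for rg, rr in zip(got, ref) for x, y in zip(rg, rr)) < 1e-9)
```

On a GPU, each loop iteration would correspond to loading one tile of Q, K, and V into SRAM; in this sketch the loop simply runs sequentially on the CPU.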
Readers interested in more details can study Algorithm 1 and Algorithm 2 in this article, as well as the detailed derivations in the paper. The algorithms and derivations treat the forward and backward passes of Lightning Attention-2 separately, which helps build a deeper understanding.
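As a compact summary of the forward computation described in points 2 and 3, here is a simplified form that omits the normalization and decay terms the paper's full algorithms may include. Split $Q, K, V \in \mathbb{R}^{N \times d}$ into $T = N/B$ blocks $Q_i, K_i, V_i \in \mathbb{R}^{B \times d}$, and let $M$ be the $B \times B$ lower-triangular mask of ones:

```latex
\begin{aligned}
\mathrm{KV}_0 &= 0,\\
O_i &= \underbrace{\big[(Q_i K_i^\top) \odot M\big] V_i}_{\text{intra-block, left multiplication}}
      + \underbrace{Q_i\,\mathrm{KV}_{i-1}}_{\text{inter-block, right multiplication}},\\
\mathrm{KV}_i &= \mathrm{KV}_{i-1} + K_i^\top V_i, \qquad i = 1, \dots, T.
\end{aligned}
```

Here $\mathrm{KV}_{i-1} = \sum_{j<i} K_j^\top V_j$, so stacking the $O_i$ reproduces $O = \big((QK^\top) \odot M\big)V$ row block by row block. Each block costs $O(B^2 d)$ for the intra-block term and $O(B d^2)$ for the inter-block term and state update, for a total of $O(NBd + Nd^2)$, which is linear in $N$ for a fixed block size.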



Accuracy Comparison of Lightning Attention-2
Researchers first compared the accuracy of Lightning Attention-2 and Lightning Attention-1 on a small (400M-parameter) model. As shown in the figure below, the two are almost indistinguishable.

Next, researchers compared TransNormerLLM powered by Lightning Attention-2 (TNL-LA2) with other advanced non-Transformer architectures and with LLaMA powered by FlashAttention2, all trained on the same corpus at the 1B and 3B scales. As shown in the figure below, TNL-LA2 followed a loss trend similar to LLaMA's and achieved a lower loss. This experiment indicates that Lightning Attention-2 achieves language-modeling accuracy comparable to state-of-the-art Transformer architectures.

On large language model tasks, researchers compared TNL-LA2 15B with Pythia models of similar size on common LLM benchmarks. As shown in the table below, when trained on the same number of tokens, TNL-LA2 slightly outperformed the softmax-attention-based Pythia on commonsense reasoning and on multiple-choice comprehensive evaluations.

Speed Comparison of Lightning Attention-2
Researchers compared the speed and memory usage of a single Lightning Attention-2 module with those of Lightning Attention-1 and FlashAttention2. As shown in the figure below, unlike the other two, Lightning Attention-2's runtime grows strictly linearly with sequence length. In memory usage, all three show similar trends, because FlashAttention2 and Lightning Attention-1 also have approximately linear memory footprints, but Lightning Attention-2 uses less memory than either.
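To see why the runtime curves diverge, a rough operation count helps. The sketch below counts multiply-accumulates per token for naive softmax attention versus the tiled scheme; it is an idealized model with hypothetical sizes (head dimension 128, block size 64), ignores softmax, normalization, and memory traffic, and is not a reproduction of the measurements in the figure.

```python
def softmax_attention_macs(n, d):
    """Multiply-accumulates for naive softmax attention: Q K^T and the
    (scores) V product are each n * n * d (masking and softmax ignored)."""
    return 2 * n * n * d

def tiled_linear_attention_macs(n, d, block):
    """Multiply-accumulates for the tiled scheme, assuming block divides n.
    Per block: Q_i K_i^T and the masked product with V_i (2 * B * B * d),
    plus Q_i KV and the K_i^T V_i state update (2 * B * d * d)."""
    blocks = n // block
    return blocks * (2 * block * block * d + 2 * block * d * d)

d, block = 128, 64  # hypothetical head dimension and block size
for n in (1024, 4096, 16384, 65536):
    print(n, softmax_attention_macs(n, d) // n, tiled_linear_attention_macs(n, d, block) // n)
# 1024 262144 49152
# 4096 1048576 49152
# 16384 4194304 49152
# 65536 16777216 49152
```

Under this count, the per-token cost of the tiled scheme, $2Bd + 2d^2$, does not depend on sequence length, while that of softmax attention grows linearly with it. The count alone favors a small block size, but a smaller block also means more sequential steps over blocks, which is one side of the block-size trade-off discussed earlier.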

The authors note that this work focuses on the training speed of linear attention networks, achieving training speeds on arbitrarily long sequences similar to those on 1K sequences, and says little about inference speed. This is because at inference time linear attention can be converted losslessly into an RNN form, in which each new token is generated at constant cost. For Transformers, by contrast, the cost of generating the current token grows with the number of preceding tokens.
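The RNN form can be sketched as follows, again assuming plain linear attention without normalization or decay; the class and function names are hypothetical. The decoder keeps a fixed $d \times e$ state, so each step costs the same no matter how many tokens came before, and its outputs match a full recomputation over the prefix.

```python
import random

class RecurrentLinearAttention:
    """Decode-time (RNN-mode) causal linear attention, hypothetical sketch.
    The whole history is folded into a fixed d x e state, so each new
    token costs O(d * e) regardless of how many tokens came before."""

    def __init__(self, d, e):
        self.state = [[0.0] * e for _ in range(d)]  # running sum of k_s^T v_s

    def step(self, q, k, v):
        for a, k_a in enumerate(k):
            for b, v_b in enumerate(v):
                self.state[a][b] += k_a * v_b
        return [sum(q[a] * self.state[a][b] for a in range(len(q)))
                for b in range(len(v))]

def full_recompute(q, k, v, t):
    """Output for position t by attending over tokens 0..t (cost grows with t)."""
    return [sum(sum(x * y for x, y in zip(q[t], k[s])) * v[s][b] for s in range(t + 1))
            for b in range(len(v[0]))]

rng = random.Random(7)
N, D = 6, 3
Q, K, V = ([[rng.uniform(-1, 1) for _ in range(D)] for _ in range(N)] for _ in range(3))
decoder = RecurrentLinearAttention(D, D)
outputs = [decoder.step(Q[t], K[t], V[t]) for t in range(N)]
print(all(abs(x - y) < 1e-9 for t in range(N)
          for x, y in zip(outputs[t], full_recompute(Q, K, V, t))))  # True
```

A softmax Transformer, by contrast, must attend over its whole KV cache at every step, so its per-token cost grows with the length of the context.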
The authors also compared the inference speed of TransNormerLLM-7B powered by Lightning Attention-1 with that of common 7B models. As shown in the figure below, at roughly the same parameter count, the Lightning Attention-1 model's throughput is four times that of Baichuan and more than 3.5 times that of ChatGLM, a clear inference-speed advantage.

Lightning Attention-2 is a significant advance in linear attention mechanisms: it can replace traditional softmax attention with comparable accuracy and higher speed, provides sustainable scalability for ever larger models, and offers a path to efficiently processing sequences of unlimited length. In future work, the OpenNLPLab team plans to explore sequence-parallel algorithms based on linear attention to overcome the current GPU memory barrier.