Distilling Llama3 into Hybrid Linear RNN with Mamba


The MLNLP community is a well-known machine learning and natural language processing community in China and abroad, whose members include NLP master's and doctoral students, university faculty, and industry researchers.
The community's vision is to promote communication and progress between academia and industry in natural language processing and machine learning, in China and abroad, with a particular focus on helping beginner students make progress.
Reprinted from | Machine Heart

The key to the Transformer's success in deep learning is the attention mechanism, which allows Transformer-based models to focus on the relevant parts of the input sequence and thereby achieve better contextual understanding. The drawback of attention, however, is its high computational cost, which grows quadratically with input length, making it difficult for Transformers to handle very long texts.

The recent emergence of Mamba has changed this situation, enabling computation that scales linearly with context length. With the release of Mamba, state space models (SSMs) can compete with, and even surpass, Transformers at small and medium scales while maintaining linear scaling with sequence length, which gives Mamba favorable deployment characteristics.

In simple terms, Mamba introduces a simple yet effective selection mechanism that reparameterizes the SSM based on the input, allowing the model to filter out irrelevant information while retaining necessary, relevant information indefinitely.
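To make the selection mechanism concrete, here is a minimal sketch (illustrative shapes and names, not the official Mamba implementation): the step size delta and the matrices B and C are computed from the input itself, which is what lets the state update keep relevant information and discard the rest.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSMSketch(nn.Module):
    """Illustrative selective state-space layer (not the optimized Mamba kernel)."""

    def __init__(self, d_model: int, d_state: int = 16):
        super().__init__()
        self.A_log = nn.Parameter(torch.randn(d_model, d_state))  # input-independent decay
        # Selection mechanism: the step size and the B/C matrices are functions of the input.
        self.to_delta = nn.Linear(d_model, d_model)
        self.to_B = nn.Linear(d_model, d_state)
        self.to_C = nn.Linear(d_model, d_state)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        bsz, seq_len, d_model = x.shape
        A = -torch.exp(self.A_log)                          # (d_model, d_state), negative for stability
        delta = F.softplus(self.to_delta(x))                # (bsz, seq_len, d_model), input-dependent
        B, C = self.to_B(x), self.to_C(x)                   # (bsz, seq_len, d_state) each
        h = x.new_zeros(bsz, d_model, A.shape[-1])          # recurrent state
        outputs = []
        for t in range(seq_len):                            # sequential scan, for clarity only
            dA = torch.exp(delta[:, t, :, None] * A)        # discretized transition, (bsz, d_model, d_state)
            dB = delta[:, t, :, None] * B[:, t, None, :]    # discretized input matrix
            h = dA * h + dB * x[:, t, :, None]              # selective state update
            outputs.append((h * C[:, t, None, :]).sum(-1))  # readout y_t = C_t h_t
        return torch.stack(outputs, dim=1)                  # (bsz, seq_len, d_model)
```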

Recently, a paper titled “The Mamba in the Llama: Distilling and Accelerating Hybrid Models” demonstrated that large transformers can be distilled into large hybrid linear RNNs by reusing the weights of the attention layers, requiring minimal additional computation while retaining most of their generative quality.

The resulting hybrid model contains a quarter of the attention layers, achieves performance comparable to the original Transformer on chat benchmarks, and outperforms open-source hybrid Mamba models trained from scratch on trillions of tokens on both chat and general benchmarks. The study also proposes a hardware-aware speculative decoding algorithm that accelerates inference for Mamba and hybrid models.


Paper link: https://arxiv.org/pdf/2408.15237

The best-performing model from this study was distilled from Llama3-8B-Instruct, achieving a length-controlled win rate of 29.61 against GPT-4 on AlpacaEval 2 and a score of 7.35 on MT-Bench, surpassing the best instruction-tuned linear RNN models.

Method

Knowledge distillation (KD) is a model compression technique for transferring knowledge from a large model (the teacher) to a smaller model (the student) by training the student network to mimic the behavior of the teacher network. This study aims to distill the Transformer into a linear RNN while matching the performance of the original language model.
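As a reminder of what the distillation objective generally looks like (a generic sketch, not the paper's exact training code), the student is trained to match the teacher's next-token distribution, typically with a temperature-scaled KL term on the logits:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits: torch.Tensor,
                      teacher_logits: torch.Tensor,
                      temperature: float = 2.0) -> torch.Tensor:
    """Generic token-level KD loss: KL between softened teacher and student distributions."""
    t = temperature
    log_p_student = F.log_softmax(student_logits / t, dim=-1)
    p_teacher = F.softmax(teacher_logits / t, dim=-1)
    # batchmean KL, rescaled by T^2 as in standard knowledge distillation
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (t * t)
```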

The study proposed a multi-level distillation method that combines progressive distillation, supervised fine-tuning (SFT), and direct preference optimization (DPO). Compared to ordinary distillation, this approach achieves better perplexity and downstream evaluation results.

The study assumes that most of the Transformer's knowledge is retained in the MLP layers carried over from the original model and focuses on the fine-tuning and alignment stages of distilling the LLM. During this stage, the MLP layers are kept frozen while the Mamba layers are trained.
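Concretely, this stage amounts to something like the sketch below (hypothetical; real parameter names depend on the implementation): freeze everything inherited from the Transformer and leave only the new Mamba layers trainable.

```python
import torch.nn as nn

def freeze_all_but_mamba(model: nn.Module) -> None:
    """Freeze the inherited Transformer weights (MLPs, embeddings, norms);
    train only the newly inserted Mamba / linear-RNN layers.

    The substring check is illustrative -- actual parameter names vary by codebase.
    """
    for name, param in model.named_parameters():
        param.requires_grad = "mamba" in name.lower()
```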


The study argues that there is an inherent connection between linear RNNs and the attention mechanism: removing the softmax yields a linearized form of attention.

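As a rough sketch of the idea (standard linear-attention notation, not necessarily the paper's exact formula), dropping the softmax turns causal attention into a cumulative sum that a linear RNN can compute with a fixed-size state:

$$
y_t = \frac{1}{\sqrt{d}} \sum_{s \le t} \big(q_t^{\top} k_s\big)\, v_s
\qquad\Longleftrightarrow\qquad
H_t = H_{t-1} + v_t k_t^{\top}, \quad y_t = \frac{1}{\sqrt{d}}\, H_t\, q_t .
$$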

However, linearizing attention can lead to degraded model capability. To design an effective distilled linear RNN, the study aims to stay as close to the original Transformer parameterization as possible while effectively scaling the capacity of the linear RNN. The study does not attempt to make the new model capture the exact original attention function but uses the linearized form as a starting point for distillation.

As shown in Algorithm 1, the study directly feeds the standard Q, K, V heads from the attention mechanism into the Mamba discretization, and then applies the resulting linear RNN. This can be seen as a rough initialization using linear attention, allowing the model to learn richer interactions through expanded hidden states.
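One hedged reading of this initialization (an illustrative sketch, not Algorithm 1 verbatim; head splitting, gating, and the exact discretization are omitted, and all names here are assumptions): the linear projections that produce K, Q, and V are reused as, respectively, the input matrix, the readout matrix, and the SSM input, and a newly learned, input-dependent step size discretizes the recurrence as in Mamba.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaFromAttentionInit(nn.Module):
    """Sketch: initialize a Mamba-style linear RNN from an attention block's Q/K/V projections.

    Assumes attn_q and attn_k project to d_state and attn_v projects to d_model;
    grouped-query head handling and the output projection are omitted for brevity.
    """

    def __init__(self, attn_q: nn.Linear, attn_k: nn.Linear, attn_v: nn.Linear, d_model: int):
        super().__init__()
        self.W_C = attn_q                                  # Q -> readout matrix C_t
        self.W_B = attn_k                                  # K -> input matrix B_t
        self.W_x = attn_v                                  # V -> SSM input stream
        self.A_log = nn.Parameter(torch.zeros(d_model))    # new learned decay
        self.to_delta = nn.Linear(d_model, d_model)        # new learned step size

    def step(self, h_prev: torch.Tensor, u_t: torch.Tensor):
        # u_t: (bsz, d_model) current token representation; h_prev: (bsz, d_model, d_state)
        B_t, C_t, x_t = self.W_B(u_t), self.W_C(u_t), self.W_x(u_t)
        delta = F.softplus(self.to_delta(u_t))             # (bsz, d_model), input-dependent step size
        dA = torch.exp(-torch.exp(self.A_log) * delta)     # discretized decay, (bsz, d_model)
        h_t = dA[..., None] * h_prev + (delta * x_t)[..., None] * B_t[:, None, :]
        y_t = (h_t * C_t[:, None, :]).sum(-1)              # (bsz, d_model)
        return h_t, y_t
```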


The study directly replaces the Transformer attention heads with fine-tuned linear RNN layers while keeping the Transformer MLP layers unchanged and untrained. This approach also has to handle other components, such as the keys and values shared across heads in grouped-query attention. The research team notes that this architecture differs from those used in many Mamba systems, and that this initialization allows any attention block to be replaced with a linear RNN block.
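For grouped-query attention specifically, one illustrative way to handle the shared heads (an assumption about the mechanics, not necessarily what the released code does) is to repeat each shared K/V head so that every query head gets its own copy before the weights are reused:

```python
import torch

def expand_kv_heads(kv: torch.Tensor, n_query_heads: int) -> torch.Tensor:
    """Repeat grouped K/V heads so each query head has a matching K/V head.

    kv: (batch, n_kv_heads, seq_len, head_dim); requires n_query_heads % n_kv_heads == 0.
    """
    n_kv_heads = kv.shape[1]
    return kv.repeat_interleave(n_query_heads // n_kv_heads, dim=1)
```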


The study also proposed a new algorithm for speculative decoding of linear RNNs using hardware-aware multi-step generation.

Algorithm 2 and Figure 2 show the complete algorithm. This method only retains one RNN hidden state in the cache for verification and delays advancing it based on the success of the multi-step kernel. Since the distilled model contains transformer layers, the study also extends speculative decoding to Attention/RNN hybrid architectures. In this setup, the RNN layers perform verification according to Algorithm 2, while the Transformer layers only perform parallel verification.
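At a high level, the verification step can be pictured with the simplified sketch below (not Algorithm 2 itself: the hardware-aware fused kernel, sampling-based acceptance, and the attention-layer handling are all omitted, and `multi_step` is an assumed interface). Only one cached RNN state is kept, the draft tokens are scored in a single multi-step call, and the state is advanced only over the accepted prefix.

```python
import torch

@torch.no_grad()
def verify_draft_greedy(rnn_model, cached_state, tokens):
    """Simplified greedy verification for speculative decoding with a linear RNN.

    tokens[0] is the last committed token, tokens[1:] are the draft proposals.
    rnn_model.multi_step(state, tokens) is an assumed fused call returning
    (logits for each position, new_state); the cached state is the only state kept.
    """
    logits, _ = rnn_model.multi_step(cached_state, tokens)   # (len(tokens), vocab_size)
    predicted = logits.argmax(dim=-1)                        # target model's greedy choices
    n_accepted = 0
    for i in range(1, len(tokens)):
        if predicted[i - 1].item() != tokens[i]:
            break                                            # first mismatch stops acceptance
        n_accepted = i
    # Advance the single cached state only over the committed token plus accepted drafts.
    _, cached_state = rnn_model.multi_step(cached_state, tokens[: n_accepted + 1])
    next_token = predicted[n_accepted].item()                # correction or bonus token
    return next_token, n_accepted, cached_state
```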


To validate the effectiveness of this method, the study used Mamba 7B and Mamba 2.8B as target models for inference. The results are shown in Table 1.


Figure 3 shows the performance characteristics of the multi-step kernel itself.


Acceleration on the H100 GPU. The proposed algorithm shows strong performance on Ampere GPUs, as shown in Table 1, but it faces significant challenges on H100 GPUs, mainly because GEMM operations are so fast that the overhead of caching and recomputation becomes more pronounced. In fact, a simple implementation of the algorithm (using several separate kernel calls) achieved considerable acceleration on the 3090 GPU but no acceleration at all on the H100.

Experiments and Results

The study conducted experiments using two LLM chat models: Zephyr-7B, which is fine-tuned from the Mistral 7B model, and Llama-3 Instruct 8B. For the linear RNN model, the study used hybrid versions of Mamba and Mamba2 with the proportion of attention layers set to 50%, 25%, 12.5%, and 0%, where the 0% configuration is referred to as the pure Mamba model. Mamba2 is a variant of the Mamba architecture designed primarily for recent GPU architectures.
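To picture the 50% / 25% / 12.5% configurations, here is a hedged sketch of one possible layer placement (illustrative only; the paper's exact placement of the remaining attention layers may differ), which keeps attention in evenly spaced layers and replaces the rest with Mamba blocks:

```python
def attention_layer_indices(n_layers: int, attention_fraction: float) -> list:
    """Evenly spaced layers that keep attention; all other layers become Mamba blocks."""
    if attention_fraction <= 0:
        return []                                  # 0% attention: pure Mamba model
    stride = round(1 / attention_fraction)
    return [i for i in range(n_layers) if i % stride == 0]

# For a 32-layer model:
#   attention_layer_indices(32, 0.5)   -> every 2nd layer keeps attention (16 layers)
#   attention_layer_indices(32, 0.25)  -> every 4th layer (8 layers)
#   attention_layer_indices(32, 0.125) -> every 8th layer (4 layers)
```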

Evaluation on Chat Benchmarks

Table 2 shows the performance of the models on chat benchmarks, with the main comparison being large Transformer models. The results show:

The distilled hybrid Mamba model (50% attention) achieved MT-Bench scores similar to the teacher model and slightly outperformed the teacher in both length-controlled win rate and overall win rate on AlpacaEval.

The distilled hybrid Mamba models (25% and 12.5% attention) performed slightly worse than the teacher on MT-Bench, but on AlpacaEval they still surpassed some large Transformers with more parameters.

The accuracy of the distilled pure (0% attention) Mamba model declined noticeably.

Notably, the distilled hybrid models outperformed Falcon Mamba, which was trained from scratch on more than 5 trillion tokens.


General Benchmark Evaluation

Zero-shot evaluation. Table 3 shows the zero-shot performance of Mamba and Mamba2 models distilled from different teacher models on the LM Eval benchmark. The hybrid Mamba-Llama3 and Mamba2-Llama3 models distilled from Llama-3 Instruct 8B perform better than the open-source TRI Mamba and NVIDIA Mamba models trained from scratch.


Benchmark evaluation. Table 4 shows that the distilled hybrid models match the best open-source linear RNN models on the Open LLM Leaderboard while outperforming the corresponding open-source instruction-tuned models on GSM8K and CRUX.


Hybrid Speculative Decoding

For the 50% and 25% distilled models, the study achieved over 1.8 times acceleration on Zephyr-Hybrid compared to the non-speculative baseline.

The experiments also showed that the 4-layer draft model achieved a higher acceptance rate, but its larger size also increased the additional overhead. Future work will focus on shrinking these draft models.


Comparison with Other Distillation Methods: Table 6 (left) compares the perplexity of different model variants. The models were distilled for one epoch using Ultrachat as seed prompts, and perplexity was compared. The results show that removing more attention layers makes perplexity worse. The study also compared its distillation method with previous baselines and found that the new method shows less degradation, whereas the Distill Hyena model, which was trained with a much smaller model on the WikiText dataset, shows a larger perplexity degradation.

Table 6 (right) shows that using SFT or DPO alone does not yield significant improvements, while using SFT + DPO produces the best scores.
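For reference, the standard DPO objective used in this kind of alignment step (the general form, not this paper's specific hyperparameters) contrasts a preferred response with a rejected one under the trained policy and a frozen reference policy:

$$
\mathcal{L}_{\text{DPO}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\!\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]
$$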


Table 7 presents ablation studies across several model variants. Table 7 (left) shows distillation results with different initializations, and Table 7 (right) shows that progressive distillation and interleaving attention layers with Mamba layers bring only marginal benefits.


Table 8 compares the performance of hybrid models under two different initialization methods; the results confirm that initializing from the attention weights is crucial.


Table 9 compares the performance of models with and without Mamba blocks. Models with Mamba blocks significantly outperformed those without Mamba blocks. This confirms that adding Mamba layers is crucial, and the performance improvement is not solely due to the remaining attention mechanism.


Interested readers can refer to the original paper for more details.
