Mastering Linear State Space: Building a Mamba Neural Network from Scratch

Author: Kuang Ji
Reviewed by: Los
In deep learning, sequence modeling remains a challenging task, typically addressed with models such as LSTMs and Transformers. However, these models carry substantial computational costs: Transformers scale quadratically with sequence length, and recurrent models are memory-hungry and hard to train on long sequences. Mamba is a sequence modeling architecture built on selective state space models whose cost grows linearly with sequence length, designed to improve both the efficiency and the effectiveness of sequence modeling. This article walks through an implementation of Mamba in PyTorch, explaining the key technical ideas and the code behind them.

■1.1 Transformer:
Transformers are built around the attention mechanism, which lets any position in a feature sequence interact dynamically with any other position. Causal (masked) attention restricts each position to itself and earlier positions, which makes it well suited to autoregressive modeling. As a result, Transformers can model every element of a sequence well, but the price is high: in a standard implementation, both computation and memory grow with the square of the sequence length (O(L²)).
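A minimal single-head sketch makes the quadratic cost concrete: the score matrix has one entry per pair of positions. This is an illustration only (no batching, no learned query/key/value projections; the function name is hypothetical):

```python
import torch

def causal_attention(q, k, v):
    # q, k, v: (L, d). The score matrix is (L, L): quadratic in sequence length.
    L, d = q.shape
    scores = q @ k.T / d ** 0.5                        # (L, L)
    mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))   # block attention to future positions
    weights = torch.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ v, weights

torch.manual_seed(0)
q = k = v = torch.randn(6, 4)
out, w = causal_attention(q, k, v)
print(out.shape, w.shape)  # torch.Size([6, 4]) torch.Size([6, 6])
```

Doubling L quadruples the size of `w`, which is exactly the L² growth described above.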
■1.2 Recurrent Neural Networks (RNN):
RNNs update their hidden state in sequence order, using only the current input and the previous hidden state. At inference time this allows them to process sequences of arbitrary length with constant memory. However, this simplicity is also a drawback: a fixed-size state limits their ability to remember long-term dependencies. Moreover, even with innovations such as LSTMs, training with backpropagation through time (BPTT) consumes memory proportional to the sequence length and can suffer from vanishing or exploding gradients.
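A minimal vanilla-RNN loop (weights are random, names hypothetical) shows why inference memory is constant: only the current `h` is kept between steps. During training, BPTT must also keep every intermediate `h` for the backward pass, which is where the memory cost mentioned above comes from.

```python
import torch

def rnn_forward(x, W_h, W_x):
    # x: (L, d_in). Only the current hidden state survives between steps,
    # so the state's memory does not grow with L.
    h = torch.zeros(W_h.shape[0])
    for x_t in x:                          # strictly sequential over time
        h = torch.tanh(W_h @ h + W_x @ x_t)
    return h

torch.manual_seed(0)
L, d_in, d_h = 1000, 3, 5
h_final = rnn_forward(torch.randn(L, d_in), 0.5 * torch.randn(d_h, d_h), torch.randn(d_h, d_in))
print(h_final.shape)  # torch.Size([5])
```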
■1.3 State Space Models (S4):
State space models (SSMs) such as S4 strike a balance between computational and memory costs. Because they can be computed either as a linear recurrence or as a convolution over the sequence, they capture long-range dependencies more effectively than RNNs while being more memory-efficient than Transformers.
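For reference, the standard SSM formulation maps an input x(t) to an output y(t) through a latent state h(t), with state matrix A, input matrix B and output matrix C. Discretizing with a step size Δ (zero-order hold) turns it into a linear recurrence:

$$
h'(t) = A\,h(t) + B\,x(t), \qquad y(t) = C\,h(t)
$$

$$
\bar{A} = \exp(\Delta A), \qquad \bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right)\Delta B
$$

$$
h_t = \bar{A}\,h_{t-1} + \bar{B}\,x_t, \qquad y_t = C\,h_t
$$

Since h_t depends only on h_{t-1} and x_t, evaluating the recurrence over a sequence of length L takes O(L) time. The code later in this article uses the common simplification \bar{B} \approx \Delta B.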

▲Figure 1|Development of Sequence Modeling Network Architecture ©️【Deep Blue AI】

■1.4 Methodology of the Mamba Architecture:
●Selective State Space: Mamba builds on state space models and introduces a new architectural design. Its selective state space, whose parameters depend on the input, lets the model decide which information to keep or discard, so it captures the relevant information in long sequences more efficiently and effectively.
●Linear Time Complexity: Unlike Transformers, Mamba's runtime grows linearly with sequence length. This makes it particularly well suited to tasks with very long sequences, where traditional models struggle.

▲Figure 2|Mamba Introduces Selective State Space ©️【Deep Blue AI】

Mamba adds a novel mechanism to traditional state space models, which it calls "selective state spaces." Instead of fixed, input-independent state transitions, Mamba computes the step size Δ and the matrices B and C from the current input. This relaxes the rigid state transitions of standard SSMs and makes the model more adaptable, somewhat like the gates of an LSTM. At the same time, Mamba retains the computational efficiency of state space models: with a hardware-aware parallel scan, it processes the entire sequence in a single forward pass in linear time.
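A toy, scalar illustration of the selection idea (all weights and inputs are made-up values, and the function name is hypothetical): Δ is computed from the input, so a large Δ drives exp(ΔA) toward 0 and lets the current input overwrite the state, while a small Δ keeps more of the previous state. A is negative, as in the Mamba paper, so the state decays instead of growing.

```python
import math

def softplus(z):
    return math.log1p(math.exp(z))

def selective_scan_scalar(xs, w_delta, A=-1.0, B=1.0, C=1.0):
    # Scalar state; delta_t = softplus(w_delta * x_t) depends on the input (the "selection")
    h, ys = 0.0, []
    for x in xs:
        delta = softplus(w_delta * x)
        A_bar = math.exp(delta * A)   # in (0, 1) because A < 0 and delta > 0
        B_bar = delta * B             # simplified discretization: B_bar ~= delta * B
        h = A_bar * h + B_bar * x
        ys.append(C * h)
    return ys

ys = selective_scan_scalar([1.0, 0.0, 0.0, 5.0], w_delta=2.0)
print(ys)
```

With x = 0 the step size is softplus(0) = ln 2, so exp(ΔA) = 0.5 and the state halves each step; the large input at the last step produces a large Δ that almost completely replaces the old state.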

■2.1 Import Required Libraries

After this brief introduction to the Mamba architecture, let's turn to the implementation, starting with the required libraries.
# Importing PyTorch related libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
from einops import rearrange
from tqdm import tqdm
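# Importing standard libraries
import math
import os
import urllib.request
from zipfile import ZipFile
# Tokenizer from Hugging Face transformers
from transformers import AutoTokenizer
# Debugging aid: makes backward() report the forward operation that produced NaN/inf gradients
# (it slows training noticeably, so disable it once the model trains correctly)
torch.autograd.set_detect_anomaly(True)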

■2.2 Setting Flags and Training Device

This section sets two flags that control how the Mamba code runs, plus the device used for training (GPU if available, otherwise CPU).
# Configuration flags
USE_MAMBA = 1  # train the Mamba model in train()
DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0  # 1 = experimental S6 branch that reuses the hidden state across batches
# Set the device used
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

■2.3 Set Initialization Hyperparameters

This subsection defines hyperparameters such as the model dimension (d_model), the state size, the sequence length, and the batch size, along with a few global variables shared by the S6 module and the training loop.
# Manually defined hyperparameters
d_model = 8
state_size = 128  # State size
seq_len = 100  # Sequence length
batch_size = 256  # Batch size
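last_batch_size = 81  # Size of the final, smaller batch (not used in the code shown here)
# Globals shared between the experimental branch of S6.forward() and train()
current_batch_size = batch_size
different_batch_size = False
h_new = None
temp_buffer = None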
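These values determine how large the intermediate tensors become. The S6 module defined below is created inside MambaBlock with 2*d_model channels, and it materializes 4-D tensors of shape (batch, seq_len, 2*d_model, state_size) for dA, dB and h. A quick back-of-the-envelope check, assuming float32 (4 bytes per element):

```python
batch_size, seq_len, d_model, state_size = 256, 100, 8, 128
d_inner = 2 * d_model              # MambaBlock passes 2*d_model to S6
bytes_per_float32 = 4

elements = batch_size * seq_len * d_inner * state_size
mib = elements * bytes_per_float32 / 2**20
print(f"{elements:,} elements per 4-D tensor = {mib:.0f} MiB")
# 52,428,800 elements per 4-D tensor = 200 MiB
```

Each S6 module holds several such tensors (including buffers preallocated in `__init__`), and the model stacks three blocks, so memory grows quickly with `batch_size` and `state_size`.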

■2.4 Define S6 Module

The S6 module (a selective SSM: an S4-style state space model with input-dependent parameters, computed with a scan) is the core component of the Mamba architecture. It combines linear projections with a discretization step to process the input feature sequence, and it is responsible for capturing the temporal dynamics of the sequence, which is central to sequence modeling tasks such as language modeling.
# Define S6 module (selective state space layer)
class S6(nn.Module):
    def __init__(self, seq_len, d_model, state_size, device):
        super(S6, self).__init__()
        # Input-dependent projections: fc1 -> delta, fc2 -> B, fc3 -> C
        self.fc1 = nn.Linear(d_model, d_model, device=device)
        self.fc2 = nn.Linear(d_model, state_size, device=device)
        self.fc3 = nn.Linear(d_model, state_size, device=device)
        # Store sizes
        self.seq_len = seq_len
        self.d_model = d_model
        self.state_size = state_size
        # State matrix A: state_size values per channel
        self.A = nn.Parameter(F.normalize(torch.ones(d_model, state_size, device=device), p=2, dim=-1))
        # Xavier init overwrites the normalized values above. Note: the paper keeps A
        # negative (A = -exp(A_log)); here A can take either sign.
        nn.init.xavier_uniform_(self.A)
        # Preallocated buffers sized with the global batch_size; B, C, delta, dA and dB
        # are overwritten in forward(), h is reused only by the experimental branch
        self.B = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
        self.C = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
        self.delta = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)
        self.dA = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
        self.dB = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
        # Internal state h and output y
        self.h = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
        self.y = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)

    # Discretization function
    def discretization(self):
        # Discretization as described on page 28 of the Mamba paper
        # Simplified input matrix: dB = delta * B, shape (b, l, d, n)
        self.dB = torch.einsum("bld,bln->bldn", self.delta, self.B)
        # dA = torch.matrix_exp(A * delta)  # matrix_exp() only supports square matrices
        # A is diagonal per channel, so an elementwise exp gives dA = exp(delta * A)
        self.dA = torch.exp(torch.einsum("bld,dn->bldn", self.delta, self.A))
        return self.dA, self.dB

    # Forward propagation
    def forward(self, x):
        # Refer to Algorithm 2 in the Mamba paper: B, C and delta depend on the input x
        self.B = self.fc2(x)                   # (b, l, n)
        self.C = self.fc3(x)                   # (b, l, n)
        self.delta = F.softplus(self.fc1(x))   # (b, l, d), positive step sizes
        # Discretization
        self.discretization()
        if DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
            # Experimental branch: reuse the hidden state h stored from the previous batch
            # (train() writes it back from temp_buffer after each step)
            global current_batch_size, different_batch_size, temp_buffer
            current_batch_size = x.shape[0]
            if self.h.shape[0] != current_batch_size:
                different_batch_size = True
                # Slice h to match the current (smaller) batch
                h_new = torch.einsum('bldn,bldn->bldn', self.dA, self.h[:current_batch_size, ...]) + rearrange(x, "b l d -> b l d 1") * self.dB
            else:
                different_batch_size = False
                h_new = torch.einsum('bldn,bldn->bldn', self.dA, self.h) + rearrange(x, "b l d -> b l d 1") * self.dB

            # Project the state to the output: y has shape (b, l, d)
            self.y = torch.einsum('bln,bldn->bld', self.C, h_new)

            # Save h_new so train() can copy it into self.h outside the autograd graph
            temp_buffer = h_new.detach().clone() if not self.h.requires_grad else h_new.clone()
            return self.y
        else:
            # Default branch: h starts at zero for every batch, so h at each position
            # equals x * dB and does not depend on earlier positions (no scan over time)
            h = torch.zeros(x.size(0), self.seq_len, self.d_model, self.state_size, device=x.device)
            h = torch.einsum('bldn,bldn->bldn', self.dA, h) + rearrange(x, "b l d -> b l d 1") * self.dB
            # Project the state to the output: y has shape (b, l, d)
            y = torch.einsum('bln,bldn->bld', self.C, h)
            return y
The S6 module inherits from nn.Module and is a key part of the Mamba model: it computes the input-dependent parameters (delta, B and C), discretizes them, and produces the output sequence in the forward pass.
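In the default branch above (DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0), h starts from zeros and is updated once for all positions in parallel, so the state at position l does not depend on earlier positions. For comparison, here is a sketch of a sequential selective scan in the spirit of Algorithm 2 of the paper, which carries the state from one position to the next. The function name is hypothetical; it reuses the tensor layouts of S6 (x: (b, l, d); dA, dB: (b, l, d, n); C: (b, l, n)). The Python loop is O(L) but slow; the official implementation uses a fused, hardware-aware parallel scan instead.

```python
import torch

def selective_scan_sequential(x, dA, dB, C):
    # x: (b, l, d); dA, dB: (b, l, d, n); C: (b, l, n) -> y: (b, l, d)
    b, l, d, n = dA.shape
    h = torch.zeros(b, d, n, dtype=x.dtype, device=x.device)
    ys = []
    for t in range(l):
        # h_t = dA_t * h_{t-1} + dB_t * x_t  (elementwise per channel d and state n)
        h = dA[:, t] * h + dB[:, t] * x[:, t, :, None]
        # y_t = C_t . h_t  (contract over the state dimension n)
        ys.append(torch.einsum("bn,bdn->bd", C[:, t], h))
    return torch.stack(ys, dim=1)

torch.manual_seed(0)
b, l, d, n = 2, 5, 3, 4
x = torch.randn(b, l, d)
dA = torch.rand(b, l, d, n)   # values in (0, 1), like exp(delta * A) with A < 0
dB = torch.randn(b, l, d, n)
C = torch.randn(b, l, n)
y = selective_scan_sequential(x, dA, dB, C)
print(y.shape)  # torch.Size([2, 5, 3])
```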

■2.5 Define MambaBlock Module

The MambaBlock module is a custom neural network module and the basic building block of the Mamba model. It normalizes its input with RMSNorm, projects it to 2*d_model channels, and passes it through a 1D convolution, a SiLU (Swish) activation, a linear layer and the S6 module. A second branch, labeled as a residual connection in the code, applies a linear layer and SiLU to the normalized input and multiplies the SSM output elementwise, acting as a gate. A final linear projection maps the result back to d_model. Together, these layers let MambaBlock capture relevant patterns and features in the input sequence.
# Define MambaBlock module
class MambaBlock(nn.Module):
    def __init__(self, seq_len, d_model, state_size, device):
        super(MambaBlock, self).__init__()
        # Expand from d_model to 2*d_model channels, and project back at the end
        self.inp_proj = nn.Linear(d_model, 2*d_model, device=device)
        self.out_proj = nn.Linear(2*d_model, d_model, device=device)
        # Gating branch (called a "residual connection" here; it multiplies rather than adds)
        self.D = nn.Linear(d_model, 2*d_model, device=device)
        # Flag the bias so an optimizer that checks this attribute can skip weight decay
        # (the plain AdamW used later does not check it)
        self.out_proj.bias._no_weight_decay = True
        # Initialize bias
        nn.init.constant_(self.out_proj.bias, 1.0)
        # S6 operates on the expanded 2*d_model channels
        self.S6 = S6(seq_len, 2*d_model, state_size, device)
        # 1D convolution: with seq_len as the channel count, the kernel slides along the feature axis
        self.conv = nn.Conv1d(seq_len, seq_len, kernel_size=3, padding=1, device=device)
        # Add linear layer
        self.conv_linear = nn.Linear(2*d_model, 2*d_model, device=device)
        # Normalization
        self.norm = RMSNorm(d_model, device=device)
    # Forward propagation
    def forward(self, x):
        # Refer to Figure 3 in the Mamba paper
        x = self.norm(x)
        x_proj = self.inp_proj(x)              # (b, l, 2*d_model)
        # 1D convolution operation
        x_conv = self.conv(x_proj)
        x_conv_act = F.silu(x_conv)  # Swish activation
        # Linear operation
        x_conv_out = self.conv_linear(x_conv_act)
        # S6 module operation
        x_ssm = self.S6(x_conv_out)
        x_act = F.silu(x_ssm)  # Swish activation
        # Gating branch: SiLU(D(x)) multiplies the SSM output elementwise
        x_residual = F.silu(self.D(x))
        x_combined = x_act * x_residual
        x_out = self.out_proj(x_combined)      # back to (b, l, d_model)
        return x_out
The MambaBlock module encapsulates the core functionalities of Mamba, including input projections, 1D convolutions, and S6 modules.
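In the block above, nn.Conv1d(seq_len, seq_len, ...) receives a (batch, seq_len, 2*d_model) tensor, so PyTorch treats the sequence positions as channels and slides the kernel along the feature axis. The Mamba paper instead uses a short causal convolution along the time axis, applied separately to each channel. A sketch of that variant, assuming the same (batch, length, channels) layout; the class name and kernel_size=4 are illustrative choices:

```python
import torch
import torch.nn as nn

class CausalDepthwiseConv1d(nn.Module):
    def __init__(self, channels, kernel_size=4):
        super().__init__()
        # groups=channels: one filter per channel (depthwise)
        # padding=kernel_size-1 pads both ends; the right side is trimmed in forward()
        self.conv = nn.Conv1d(channels, channels, kernel_size,
                              groups=channels, padding=kernel_size - 1)

    def forward(self, x):
        # x: (batch, length, channels); Conv1d expects (batch, channels, length)
        L = x.shape[1]
        y = self.conv(x.transpose(1, 2))[..., :L]   # keep the first L outputs -> causal
        return y.transpose(1, 2)

torch.manual_seed(0)
conv = CausalDepthwiseConv1d(channels=6)
x = torch.randn(2, 10, 6)
y = conv(x)
x_future = x.clone()
x_future[:, 7:] += 1.0        # change only positions 7..9
y_future = conv(x_future)
print(torch.allclose(y[:, :7], y_future[:, :7]))  # True: earlier outputs are unaffected
```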

■2.6 Define Mamba Model

The Mamba class defines the overall architecture as a stack of three MambaBlock modules applied in sequence, with the output of each block serving as the input to the next. This sequential processing lets the model capture complex patterns and relationships in the input data. Stacking modules is a common design in deep learning architectures, because it enables the model to learn hierarchical representations of the data.
# Define Mamba model
class Mamba(nn.Module):
    def __init__(self, seq_len, d_model, state_size, device):
        super(Mamba, self).__init__()
        self.mamba_block1 = MambaBlock(seq_len, d_model, state_size, device)
        self.mamba_block2 = MambaBlock(seq_len, d_model, state_size, device)
        self.mamba_block3 = MambaBlock(seq_len, d_model, state_size, device)
    def forward(self, x):
        x = self.mamba_block1(x)
        x = self.mamba_block2(x)
        x = self.mamba_block3(x)
        return x
This class defines the entire Mamba model, linking multiple MambaBlock modules to form the architecture of the overall algorithm model.
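The Mamba class chains the blocks directly, without a skip connection around each block. Deep stacks are usually easier to train with residual connections of the form x + block(x) (each MambaBlock above already normalizes its own input, so this gives a pre-norm layout), and the reference implementation keeps such a residual stream around each block. A hedged sketch of a residual stack that could wrap MambaBlock-style modules; ResidualStack is a hypothetical name, and any module mapping (batch, length, d_model) to the same shape works:

```python
import torch
import torch.nn as nn

class ResidualStack(nn.Module):
    def __init__(self, blocks):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = x + block(x)   # skip connection around every block
        return x

# Stand-in blocks; in the article these would be MambaBlock instances
blocks = [nn.Linear(8, 8) for _ in range(3)]
for blk in blocks:
    nn.init.zeros_(blk.weight)
    nn.init.zeros_(blk.bias)
stack = ResidualStack(blocks)
x = torch.randn(4, 100, 8)
print(torch.equal(stack(x), x))  # True: blocks that output zero leave the signal unchanged
```

The skip path guarantees that the input signal (and its gradient) can always pass through the stack, even when a block's output is small.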

■2.7 Define RMSNorm Module

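The RMSNorm module is a custom normalization layer that inherits from PyTorch's nn.Module. It rescales each feature vector by the reciprocal of its root mean square and then by a learned per-feature weight; unlike LayerNorm, it does not subtract the mean. Normalizing activations this way helps stabilize and speed up training.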
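class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float=1e-5, device=None):
        super().__init__()
        self.eps = eps
        # Learned per-feature scale; device=None uses PyTorch's default device (normally the CPU)
        self.weight = nn.Parameter(torch.ones(d_model, device=device))
    def forward(self, x):
        # x / sqrt(mean(x^2) over the last dimension + eps), then scale by weight
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
        return output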
The RMSNorm module is a commonly used technique in neural network architectures for normalization.
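As a quick sanity check of the formula x / sqrt(mean(x²) + eps) * weight: with weight = 1, every feature vector of the output has a root mean square of almost exactly 1, whatever the scale of the input. A standalone sketch (rms_norm is a hypothetical helper mirroring RMSNorm.forward):

```python
import torch

def rms_norm(x, weight, eps=1e-5):
    # Scale each feature vector by the reciprocal of its root mean square
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

torch.manual_seed(0)
x = 10 * torch.randn(4, 100, 8)
y = rms_norm(x, torch.ones(8))
rms = y.pow(2).mean(-1).sqrt()
print(rms.min().item(), rms.max().item())  # both very close to 1.0
```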

Next, let's instantiate the Mamba model and run a forward pass on randomly generated data to check that the tensor shapes work out.
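# Create simulated data: (batch_size, seq_len, d_model)
x = torch.rand(batch_size, seq_len, d_model, device=device)
# Create Mamba algorithm model
mamba = Mamba(seq_len, d_model, state_size, device)
# Normalize the input (optional, since each MambaBlock also applies RMSNorm);
# pass device so the weight lives on the same device as x
norm = RMSNorm(d_model, device=device)
x = norm(x)
# Forward propagation
test_output = mamba(x)
print(f"test_output.shape = {test_output.shape}")
# Expected: test_output.shape = torch.Size([256, 100, 8])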

■3.1 Data Preparation and Training Function

The Enwiki8Dataset class (not included in this listing) is a custom dataset that inherits from PyTorch's Dataset class and supplies the tokenized enwik8 data for sequence modeling tasks such as language modeling. The helper below pads a batch of sequences along the sequence dimension.
# Define padding function: pads a 3-D tensor (batch, seq_len, features) along the sequence dimension
def pad_sequences_3d(sequences, max_len=None, pad_value=0):
    # Get the dimensions of the tensor
    batch_size, seq_len, feature_size = sequences.shape
    # By default, add one padding position at the end
    if max_len is None:
        max_len = seq_len + 1
    # Tensor filled with pad_value, with the same dtype and device as the input
    padded_sequences = torch.full((batch_size, max_len, feature_size), fill_value=pad_value, dtype=sequences.dtype, device=sequences.device)
    # Copy the original sequences into the first seq_len positions
    padded_sequences[:, :seq_len, :] = sequences
    return padded_sequences
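The Enwiki8Dataset class itself is not shown. Judging only from how it is used later (Enwiki8Dataset(train_data), where train_data is a dict with 'input_ids' and 'attention_mask', and batch['input_ids'] inside train), a minimal reconstruction could look like this. It is an assumption for illustration, not the author's code:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class Enwiki8Dataset(Dataset):
    # Hypothetical reconstruction: wraps a dict of tensors that share their first dimension
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data["input_ids"])

    def __getitem__(self, idx):
        # One sample per key, e.g. {'input_ids': ..., 'attention_mask': ...}
        return {key: val[idx] for key, val in self.data.items()}

data = {
    "input_ids": torch.randint(0, 100, (10, 16)),
    "attention_mask": torch.ones(10, 16, dtype=torch.long),
}
loader = DataLoader(Enwiki8Dataset(data), batch_size=4, shuffle=False)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([4, 16])
```

PyTorch's default collate function stacks the per-sample dicts into one dict of batched tensors, which is the `batch['input_ids']` / `batch['attention_mask']` format the train function expects.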
The train function trains the Mamba model for one epoch and returns the average training loss. Its signature and parameters are:
def train(model, tokenizer, data_loader, optimizer, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=False):
●model: the neural network model to be trained (here, Mamba);
●tokenizer: the tokenizer used to process the input data (its pad_token_id is used for padding);
●data_loader: an iterable that yields batches of training data;
●optimizer: the optimization algorithm used to update the model weights;
●criterion: the loss function used to evaluate model performance;
●device: the device on which the model runs (CPU or GPU);
●max_grad_norm: the threshold for gradient clipping, to prevent exploding gradients;
●DEBUGGING_IS_ON: a flag that enables printing gradient norms for debugging.
# Define train function
def train(model, tokenizer, data_loader, optimizer, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=False):
    model.train()
    total_loss = 0
    for batch in data_loader:
        optimizer.zero_grad()
        input_data = batch['input_ids'].clone().to(device)
        attention_mask = batch['attention_mask'].clone().to(device)  # loaded but not used below
        # Next-token prediction: the target is the input shifted left by one position
        target = input_data[:, 1:]
        input_data = input_data[:, :-1]
        # Pad back to the original length (pad_sequences_3d unpacks three dimensions,
        # so the batches must be 3-D tensors here)
        input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
        target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)
        if USE_MAMBA:
            output = model(input_data)
            loss = criterion(output, target)
        loss.backward(retain_graph=True)  # Keep the autograd graph after backward (as in the original code)
        # Clip gradients parameter by parameter, skipping the out_proj biases
        for name, param in model.named_parameters():
            if 'out_proj.bias' not in name:
                # Gradient clipping operation
                torch.nn.utils.clip_grad_norm_(param, max_norm=max_grad_norm)
        if DEBUGGING_IS_ON:
            for name, parameter in model.named_parameters():
                if parameter.grad is not None:
                    print(f"{name} gradient: {parameter.grad.data.norm(2)}")
                else:
                    print(f"{name} has no gradient")
        if USE_MAMBA and DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
            # Experimental branch: temp_buffer holds the state written by the most recent
            # S6 call, i.e. the last block (Mamba has no top-level S6 attribute)
            model.mamba_block3.S6.h[:current_batch_size, ...].copy_(temp_buffer)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(data_loader)
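The training loop below also calls evaluate and calculate_perplexity, which are not part of the listing. Perplexity is the exponential of the mean cross-entropy loss (in nats). A hedged sketch of both helpers, assuming the model returns vocabulary logits of shape (batch, length, vocab_size) and that batches carry 'input_ids' as in train. Note that the Mamba class above returns d_model features, so it would need an output projection to the vocabulary for this sketch to apply directly:

```python
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def calculate_perplexity(loss):
    # Perplexity = exp(mean cross-entropy)
    return math.exp(loss)

@torch.no_grad()
def evaluate(model, data_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    for batch in data_loader:
        input_ids = batch["input_ids"].to(device)
        inputs, targets = input_ids[:, :-1], input_ids[:, 1:]   # next-token prediction
        logits = model(inputs)                                   # (batch, length, vocab)
        loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        total_loss += loss.item()
    return total_loss / len(data_loader)

# Usage with a stand-in "language model" that maps token ids directly to logits
torch.manual_seed(0)
vocab_size = 50
toy_model = nn.Embedding(vocab_size, vocab_size)
samples = [{"input_ids": torch.randint(0, vocab_size, (12,))} for _ in range(8)]
val_loader = DataLoader(samples, batch_size=4)
val_loss = evaluate(toy_model, val_loader, nn.CrossEntropyLoss(), "cpu")
print(val_loss, calculate_perplexity(val_loss))
```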

■3.2 Model Training Loop

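# Load pre-tokenized data if it was cached, otherwise tokenize and cache it
# (tokenizer, load_enwiki8_dataset, encode_dataset, Enwiki8Dataset, evaluate and
#  calculate_perplexity are defined elsewhere and not shown in this listing)
encoded_inputs_file = 'encoded_inputs_mamba.pt'
if os.path.exists(encoded_inputs_file):
    print("Loading pre-tokenized data...")
    encoded_inputs, attention_mask = torch.load(encoded_inputs_file)
else:
    print("Tokenizing raw data...")
    enwiki8_data = load_enwiki8_dataset()
    encoded_inputs, attention_mask = encode_dataset(tokenizer, enwiki8_data)
    # Save both tensors so attention_mask is also available when loading from the cache
    torch.save((encoded_inputs, attention_mask), encoded_inputs_file)
    print("Finished tokenizing data")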
# Combine data
data = {
    'input_ids': encoded_inputs,
    'attention_mask': attention_mask
}
# Split training and validation sets
total_size = len(data['input_ids'])
train_size = int(total_size * 0.8)
train_data = {key: val[:train_size] for key, val in data.items()}
val_data = {key: val[train_size:] for key, val in data.items()}
train_dataset = Enwiki8Dataset(train_data)
val_dataset = Enwiki8Dataset(val_data)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# Initialize model
model = Mamba(seq_len, d_model, state_size, device).to(device)
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-6)
# Number of training epochs
num_epochs = 25
for epoch in tqdm(range(num_epochs)):
    train_loss = train(model, tokenizer, train_loader, optimizer, criterion, device, max_grad_norm=10.0, DEBUGGING_IS_ON=False)
    val_loss = evaluate(model, val_loader, criterion, device)
    val_perplexity = calculate_perplexity(val_loss)
    print(f'Epoch: {epoch+1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Perplexity: {val_perplexity:.4f}')
The code above loads (or creates and caches) the tokenized dataset, splits it 80/20 into training and validation sets, builds the data loaders, initializes the model, the cross-entropy loss and the AdamW optimizer, and then trains for 25 epochs, reporting the validation loss and perplexity after each epoch.

This article has walked through the complete code for building Mamba from scratch. Readers can use the explanations and code here to take the Mamba model from theory to practice. Working through the implementation not only deepens understanding of Mamba's internal workings but also illustrates the practical steps of designing a new model architecture. With this knowledge, readers can try Mamba in their own projects or go further in developing new AI models.

References:

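[1] Gu, A., Dao, T. Mamba: Linear-Time Sequence Modeling with Selective State Spaces. https://arxiv.org/abs/2312.00752
[2] Official Mamba implementation (state-spaces/mamba). https://github.com/state-spaces/mamba
[3] PyTorch tutorial: Datasets & DataLoaders. https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
[4] enwik8 dataset on Hugging Face. https://huggingface.co/datasets/enwik8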

【Deep Blue AI】's original content is the result of the author's personal effort. Please respect the original work: for reprints, contact us for authorization and credit 【Deep Blue AI】 WeChat official account as the source; otherwise, legal action will be taken.
