| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from enum import Enum |
| from dataclasses import dataclass, field |
| from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states |
| from mamba_ssm.ops.triton.selective_state_update import selective_state_update |
|
|
| |
| from .causal_conv1d_compilable import causal_conv1d_fn, causal_conv1d_update |
| from .ssm_compilable import mamba_chunk_scan_combined |
| |
|
|
| from .norms import build_norm |
| from .attn import AttentionLayer |
| from .attn import precompute_freqs_cis |
| from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined |
|
|
|
|
| class InitStdFactor(Enum): |
| DISABLED = "disabled" |
| GLOBAL_DEPTH = "global_depth" |
| CURRENT_DEPTH = "current_depth" |
| DIM_RATIO = "dim_ratio" |
|
|
|
|
| @dataclass |
| class InitConfig: |
| dt_max: float = 0.1 |
| dt_min: float = 0.001 |
|
|
| dt_init_floor: float = 1e-4 |
|
|
| A_init_min: float = 1 |
| A_init_max: float = 16 |
|
|
|
|
| DEFAULT_INIT_CONFIG = InitConfig() |
|
|
|
|
| @dataclass |
| class BaseMambaConfig: |
| """ |
| Configuration for the Mamba family of models. |
| """ |
| dim: int = 512 |
| num_layers: int = 8 |
| num_heads: int = 8 |
|
|
| state_dim: int = 128 |
| num_groups: int = 1 |
| conv_size: int | None = 4 |
|
|
| bias: bool = False |
| conv_bias: bool = True |
| dt_bias: bool = False |
| D_has_head_dim: bool = False |
| learnable_init_states: bool = False |
|
|
| mlp_scale: int = 2 |
| multiple_of: int = 256 |
|
|
| norm_eps: float = 1e-6 |
| norm_type: str = "rmsnorm" |
| |
| |
| ssm_chunk_size: int = 256 |
| use_mem_eff_path: bool = False |
|
|
| |
| init_use_depth: bool = False |
| init_base_std: float | None = None |
| init_std_factor: str = "disabled" |
| init_config: InitConfig = field(default_factory=InitConfig) |
|
|
|
|
| class SSM(nn.Module): |
| """ |
| State Space Model (SSM) implementation with selective state updates and convolution. |
| |
| Implements the core SSM computation with support for both training and inference modes. |
| During inference, uses cached states for efficient token-by-token generation. |
| """ |
| def __init__(self, config: BaseMambaConfig) -> None: |
| """Initialize SSM parameters and layers. |
| Args: |
| config: Configuration containing model hyperparameters |
| """ |
| super().__init__() |
| self.config = config |
| vars(self).update(vars(config)) |
|
|
| assert self.dim > 0, "Model dimension (config.dim) must be positive" |
| assert self.num_heads > 0, "Number of heads (config.num_heads) must be positive" |
| assert self.state_dim > 0, "State dimension (config.state_dim) must be positive" |
|
|
| if self.mlp_scale is None: |
| raise ValueError( |
| "mlp_scale must be set to a valid float (e.g. 2.0) " |
| "to determine hidden_dim in SSM." |
| ) |
| assert self.mlp_scale > 0, "mlp_scale must be > 0" |
|
|
| self.hidden_dim = int(self.mlp_scale * self.dim) |
| self.hidden_dim = config.multiple_of * ( |
| (self.hidden_dim + self.multiple_of - 1) // self.multiple_of |
| ) |
| |
| assert self.hidden_dim % self.num_heads == 0, ( |
| f"Hidden dim {self.hidden_dim} not divisible by num_heads={self.num_heads}." |
| ) |
|
|
| self.head_dim = self.hidden_dim // self.num_heads |
|
|
| self.dt_limit_kwargs = {} |
| dt_limit = (self.init_config.dt_min, self.init_config.dt_max) |
| if dt_limit != (0.0, float("inf")): |
| self.dt_limit_kwargs = dict(dt_limit=dt_limit) |
|
|
| |
| d_input = ( |
| 2 * self.hidden_dim |
| + 2 * self.num_groups * self.state_dim |
| + self.num_heads |
| ) |
|
|
| self.input = nn.Linear(self.dim, d_input, bias=self.bias) |
|
|
| |
| if self.conv_size is not None: |
| conv_dim = self.hidden_dim + 2 * self.num_groups * self.state_dim |
|
|
| |
| |
| self.conv1d = nn.Conv1d( |
| in_channels=conv_dim, |
| out_channels=conv_dim, |
| kernel_size=self.conv_size, |
| groups=conv_dim, |
| bias=self.conv_bias, |
| padding=self.conv_size - 1 |
| ) |
|
|
| if config.dt_bias: |
| self.dt_bias = nn.Parameter(torch.empty(self.num_heads)) |
| else: |
| self.dt_bias = nn.Parameter(torch.zeros(self.num_heads), requires_grad=False) |
|
|
| self.A_log = nn.Parameter(torch.empty(self.num_heads)) |
|
|
| if config.D_has_head_dim: |
| self.D = nn.Parameter(torch.ones(self.num_heads, self.head_dim)) |
| else: |
| self.D = nn.Parameter(torch.ones(self.num_heads)) |
| |
| if self.learnable_init_states: |
| self.init_states = nn.Parameter(torch.zeros(self.num_heads, self.head_dim, self.state_dim)) |
|
|
| self.norm = build_norm(config.norm_type, dim=self.hidden_dim, eps=self.norm_eps) |
| self.output = nn.Linear(self.hidden_dim, self.dim, bias=self.bias) |
|
|
| def _causal_conv( |
| self, |
| zxbcdt: torch.Tensor, |
| tok_idx: torch.Tensor | None = None, |
| cu_seqlens: torch.Tensor | None = None, |
| ssm_impl: str = "ssm" |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| |
| """Processes input through causal convolution path, handling both full sequence and incremental cases. |
| |
| This function implements two processing modes: |
| 1. Full sequence ("ssm"): Used during training and initial prompt processing. |
| 2. Incremental ("ssm_update"): Used during token-by-token generation. |
| |
| Args: |
| zxbcdt: Input tensor containing concatenated [z, x, B, C, dt] components |
| tok_idx: Token indices for sequence processing. Required for "ssm" mode. |
| Defaults to None. |
| cu_seqlens: Cumulative sequence lengths for variable length processing. |
| Used only in "ssm" mode with caching. Defaults to None. |
| ssm_impl: Implementation mode, either "ssm" for full sequence processing |
| or "ssm_update" for incremental generation. Defaults to "ssm". |
| |
| Returns: |
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| Tuple containing separated components (z, x, B, C, dt), where: |
| - z: Gating branch |
| - x: Main branch |
| - B, C: SSM state matrices (analogous to K, Q in attention) |
| - dt: Time delta values |
| |
| Notes: |
| - When using "ssm" mode during inference, a cache should be pre-initialized |
| externally. This design allows for flexible caching strategies without |
| modifying model code. |
| - The "ssm_update" mode requires a cache to exist and will use it for |
| incremental state updates during generation. |
| - B, C components correspond to Key, Query in the SSM/attention duality. |
| """ |
| |
| z, xBC, dt = torch.split( |
| zxbcdt, |
| [ |
| self.hidden_dim, |
| self.hidden_dim + 2 * self.num_groups * self.state_dim, |
| self.num_heads, |
| ], |
| dim=-1, |
| ) |
|
|
| if ssm_impl == "ssm": |
| if hasattr(self, "cache"): |
| conv_varlen_states = causal_conv1d_varlen_states( |
| xBC.squeeze(0), |
| cu_seqlens, |
| state_len=self.cache.conv_cache.shape[-1], |
| ) |
| self.cache.conv_cache.copy_(conv_varlen_states) |
|
|
| xBC = causal_conv1d_fn( |
| x=xBC.transpose(1, 2), |
| weight=self.conv1d.weight.squeeze(1), |
| bias=self.conv1d.bias, |
| activation="silu", |
| seq_idx=tok_idx, |
| ).transpose(1, 2) |
| elif ssm_impl == "ssm_update": |
| xBC = causal_conv1d_update( |
| x=xBC.squeeze(0), |
| conv_state=self.cache.conv_cache, |
| weight=self.conv1d.weight.squeeze(1), |
| bias=self.conv1d.bias, |
| activation="silu", |
| ).unsqueeze(0) |
| else: |
| raise NotImplementedError(f"SSM implementation {ssm_impl} not supported") |
|
|
| |
| x, B, C = torch.split( |
| xBC, |
| [ |
| self.hidden_dim, |
| self.num_groups * self.state_dim, |
| self.num_groups * self.state_dim, |
| ], |
| dim=-1, |
| ) |
|
|
| return z, x, B, C, dt |
|
|
| def _non_causal_conv(self, zxbcdt: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| z, x, B, C, dt = torch.split( |
| zxbcdt, |
| [ |
| self.hidden_dim, |
| self.hidden_dim, |
| self.num_groups * self.state_dim, |
| self.num_groups * self.state_dim, |
| self.num_heads, |
| ], |
| dim=-1, |
| ) |
| return z, x, B, C, dt |
|
|
| def _fwd(self, x, dt, A, B, C, tok_idx, cu_seqlens, initial_states): |
| """ |
| For training |
| |
| Returns: |
| (bsz, seq_len, num_heads, head_dim) |
| """ |
| y = mamba_chunk_scan_combined( |
| x, |
| dt, |
| A, |
| B, |
| C, |
| dt_bias=self.dt_bias, |
| dt_softplus=True, |
| chunk_size=self.ssm_chunk_size, |
| D=self.D, |
| z=None, |
| seq_idx=tok_idx, |
| cu_seqlens=cu_seqlens, |
| initial_states=initial_states, |
| **self.dt_limit_kwargs, |
| ) |
|
|
| if hasattr(self, "cache"): |
| y, varlen_states = y |
| self.cache.state_cache.copy_(varlen_states) |
|
|
| return y |
| |
| def _step(self, x, seq_len, dt, A, B, C): |
| """ |
| For inference / generation. |
| """ |
| x = x.squeeze(0) |
| A = A[..., None, None].expand(self.num_heads, self.head_dim, self.state_dim) |
| dt = dt.permute(1, 2, 0).expand(seq_len, self.num_heads, self.head_dim) |
| D = self.D |
| if D is not None and D.dim() == 1: |
| D = D.unsqueeze(1).expand(self.num_heads, self.head_dim) |
| B, C = B.squeeze(0), C.squeeze(0) |
| y = selective_state_update( |
| self.cache.state_cache, |
| x, |
| dt, |
| A, |
| B, |
| C, |
| D, |
| z=None, |
| dt_bias=( |
| torch.zeros(self.num_heads, self.head_dim).to(x) |
| if self.dt_bias is None |
| else self.dt_bias.unsqueeze(1).expand(self.num_heads, self.head_dim) |
| ), |
| dt_softplus=True, |
| ).unsqueeze(0) |
|
|
| return y |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| tok_idx: torch.Tensor | None = None, |
| cu_seqlens: torch.Tensor | None = None, |
| ssm_impl: str = "ssm", |
| ) -> torch.Tensor: |
| bsz, seq_len, _ = x.shape |
|
|
| zxbcdt = self.input(x) |
|
|
| A = -torch.exp(self.A_log.float()) |
| initial_states = ( |
| self.init_states.expand(bsz, -1, -1, -1) |
| if self.learnable_init_states else None |
| ) |
|
|
| |
| if self.conv_size is not None: |
| |
| |
| if self.use_mem_eff_path: |
| out = mamba_split_conv1d_scan_combined( |
| zxbcdt, |
| self.conv1d.weight.squeeze(1), |
| self.conv1d.bias, |
| self.dt_bias, |
| A, |
| D=self.D, |
| chunk_size=self.ssm_chunk_size, |
| seq_idx=tok_idx, |
| activation="silu", |
| rmsnorm_weight=self.norm.weight, |
| rmsnorm_eps=self.norm.eps, |
| outproj_weight=self.output.weight, |
| outproj_bias=self.output.bias, |
| headdim=self.head_dim, |
| ngroups=self.num_groups, |
| norm_before_gate=False, |
| initial_states=initial_states, |
| **self.dt_limit_kwargs, |
| ) |
| return out |
| else: |
| |
| z, x, B, C, dt = self._causal_conv(zxbcdt) |
| else: |
| |
| z, x, B, C, dt = self._non_causal_conv(zxbcdt) |
|
|
| x = x.view(bsz, seq_len, self.num_heads, self.head_dim) |
| B = B.view(bsz, seq_len, self.num_groups, self.state_dim) |
| C = C.view(bsz, seq_len, self.num_groups, self.state_dim) |
|
|
| |
| if ssm_impl == "ssm": |
| |
| y = self._fwd(x, dt, A, B, C, tok_idx, cu_seqlens, initial_states) |
| elif ssm_impl == "ssm_update": |
| y = self._step(x, seq_len, dt, A, B, C) |
| else: |
| raise NotImplementedError(f"SSM implementation {ssm_impl} not supported") |
|
|
| y = y.view(bsz, seq_len, self.hidden_dim) |
|
|
| |
| |
| |
| y = self.norm(y * F.silu(z)) |
| out = self.output(y) |
|
|
| return out |
|
|
| @torch.inference_mode() |
| def reset_parameters(self, init_std, factor) -> None: |
| config = self.config |
| init_config = config.init_config |
| if init_config is None: |
| init_config = DEFAULT_INIT_CONFIG |
|
|
| |
| in_init_std = init_std or (self.dim ** (-0.5)) |
| out_init_std = init_std or (self.hidden_dim ** (-0.5)) |
| out_init_std = out_init_std / factor |
|
|
| nn.init.trunc_normal_( |
| self.input.weight, |
| mean=0.0, |
| std=in_init_std, |
| a=-3 * in_init_std, |
| b=3 * in_init_std, |
| ) |
|
|
| nn.init.trunc_normal_( |
| self.output.weight, |
| mean=0.0, |
| std=out_init_std, |
| a=-3 * out_init_std, |
| b=3 * out_init_std, |
| ) |
|
|
| |
| if self.dt_bias is not None and self.dt_bias.requires_grad: |
| log_dt_min = math.log(init_config.dt_min) |
| log_dt_max = math.log(init_config.dt_max) |
| |
| |
| log_dt = torch.rand(self.num_heads, device=self.dt_bias.device) * (log_dt_max - log_dt_min) + log_dt_min |
| dt = torch.exp(log_dt) |
| dt = torch.clamp(dt, min=init_config.dt_init_floor) |
|
|
| |
| inv_dt = dt + torch.log(-torch.expm1(-dt)) |
| self.dt_bias.copy_(inv_dt) |
|
|
| elif self.dt_bias is not None: |
| |
| self.dt_bias.fill_(0.0) |
|
|
| |
| if self.conv_size is not None: |
| conv_std = init_std or (self.conv_size ** (-0.5)) |
| nn.init.trunc_normal_( |
| self.conv1d.weight, |
| mean=0.0, |
| std=conv_std, |
| a=-3 * conv_std, |
| b=3 * conv_std, |
| ) |
| if self.conv1d.bias is not None: |
| nn.init.zeros_(self.conv1d.bias) |
|
|
| |
| if self.learnable_init_states: |
| self.init_states.zero_() |
|
|
| |
| self.A_log.uniform_(init_config.A_init_min, init_config.A_init_max) |
| self.A_log.log_() |
|
|
| if self.D is not None: |
| self.D.data.fill_(1.0) |
|
|
| |
| self.norm.reset_parameters() |
|
|
|
|
| class MambaBlock(nn.Module): |
| def __init__(self, config: BaseMambaConfig): |
| super().__init__() |
| self.norm = build_norm(config.norm_type, dim=config.dim, eps=config.norm_eps) |
| self.ssm = SSM(config) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| tok_idx: torch.Tensor | None, |
| cu_seqlens: torch.Tensor | None, |
| ssm_impl: str = "ssm", |
| ) -> torch.Tensor: |
| x = x + self.ssm(self.norm(x), tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl) |
| return x |
|
|
| @torch.inference_mode() |
| def init_weights(self, init_std=None, factor=1.0): |
| self.norm.reset_parameters() |
| self.ssm.reset_parameters(init_std, factor) |
|
|
|
|
| class BaseMamba(nn.Module): |
| def __init__(self, config: BaseMambaConfig): |
| super().__init__() |
| assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible num_heads ({self.num_heads})" |
| self.head_dim = config.dim // config.num_heads |
|
|
| self.model_dim = config.dim |
| self.init_base_std = config.init_base_std |
|
|
| self.init_config = config.init_config |
| self.init_std_factor = InitStdFactor(config.init_std_factor) |
|
|
| |
| self.register_buffer("freqs_cis", precompute_freqs_cis( |
| head_dim=self.head_dim, |
| max_seq_len=config.seq_len, |
| theta=config.theta, |
| ), persistent=True) |
|
|
| self.layers = nn.ModuleList() |
| for layer_idx in range(config.num_layers): |
| |
| if layer_idx % 2 == 0: |
| self.layers.append(MambaBlock(config)) |
| else: |
| self.layers.append( |
| AttentionLayer(config) |
| if config.use_attn |
| else (MambaBlock(config)) |
| ) |
|
|
| def _unwrap(self, layer: nn.Module) -> nn.Module: |
| """Helper function to find the underlying layer name (if wrapped in DDP or FSDP)""" |
| while hasattr(layer, "module"): |
| layer = layer.module |
| return layer |
|
|
| def forward( |
| self, |
| h: torch.Tensor, |
| tok_idx: torch.Tensor | None, |
| cu_seqlens: torch.Tensor | None, |
| ssm_impl: str = "ssm", |
| ) -> torch.Tensor: |
| for layer in self.layers: |
| unwrapped_layer = self._unwrap(layer) |
| if isinstance(unwrapped_layer, MambaBlock): |
| h = unwrapped_layer(h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl) |
| elif isinstance(unwrapped_layer, AttentionLayer): |
| h = unwrapped_layer(h, self.freqs_cis) |
| else: |
| raise ValueError(f"ERROR: Unexpected layer type: {type(unwrapped_layer).__name__}") |
| return h |
|
|
| @torch.inference_mode() |
| def reset_parameters(self): |
| pass |
|
|
| @torch.inference_mode() |
| def init_weights(self): |
| self.reset_parameters() |
| for depth, layer in enumerate(self.layers): |
| factor = { |
| InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5, |
| InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5, |
| InitStdFactor.DIM_RATIO: self.model_dim / 4096, |
| InitStdFactor.DISABLED: 1.0, |
| }[self.init_std_factor] |
|
|
| if not hasattr(layer, "attn"): |
| layer.init_weights(self.init_base_std, factor) |
|
|
|
|
| @dataclass |
| class Mamba2Config(BaseMambaConfig): |
| seed: int = 1337 |
|
|
| vocab_size: int = -1 |
| seq_len: int = 8192 |
| window_size: int = 1024 |
| weight_tying: bool = False |
| torch_dtype: torch.dtype = torch.bfloat16 |
|
|
| loss_reduction: str = "mean" |
|
|
| use_attn: bool = True |
| use_alibi: bool = True |
| dropout: float = 0.0 |
| softcap: float = 50.0 |
| theta: float = 10000.0 |
|
|
| device: torch.device = None |
| dtype: torch.dtype = torch.bfloat16 |
|
|
|
|
| class Mamba2(BaseMamba): |
| def __init__(self, config: Mamba2Config) -> None: |
| super().__init__(config) |
| self.weight_tying = config.weight_tying |
| self.loss_reduction = config.loss_reduction |
|
|
| assert config.vocab_size > 0, "vocab_size must be set and > 0" |
|
|
| self.tok_emb = torch.nn.Embedding(config.vocab_size, config.dim) |
|
|
| self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps) |
|
|
| self.output = nn.Linear( |
| config.dim, |
| config.vocab_size, |
| bias=False, |
| ) |
|
|
| if config.weight_tying: |
| self.output.weight = self.tok_emb.weight |
|
|
| print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,)) |
|
|
| def _get_num_params(self): |
| n_params = sum(p.numel() for p in self.parameters()) |
|
|
| if hasattr(self, "pos_emb") and self.pos_emb is not None: |
| n_params -= self.pos_emb.weight.numel() |
|
|
| return n_params |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| target: torch.Tensor | None = None, |
| tok_idx: torch.Tensor | None = None, |
| cu_seqlens: torch.Tensor | None = None, |
| ssm_impl: str = "ssm", |
| ) -> torch.Tensor: |
| h = self.tok_emb(x) |
| h = super().forward(h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl) |
| logits = self.output(self.norm(h)) |
| return logits |
|
|
| @torch.inference_mode() |
| def reset_parameters(self, init_std=None): |
| |
| super().reset_parameters() |
| init_std = init_std or (self.model_dim ** (-0.5)) |
| self.norm.reset_parameters() |
| nn.init.trunc_normal_( |
| self.tok_emb.weight, |
| mean=0.0, |
| std=init_std, |
| a=-3 * init_std, |
| b=3 * init_std, |
| ) |
| if not self.weight_tying: |
| nn.init.trunc_normal_( |
| self.output.weight, |
| mean=0.0, |
| std=init_std, |
| a=-3 * init_std, |
| b=3 * init_std, |
| ) |
|
|
| @torch.inference_mode() |
| def init_weights(self, buffer_device: torch.device = None): |
| """ |
| Initialize model parameters and optionally compute buffers on a specific device. |
| |
| Args: |
| buffer_device (torch.device, optional): If provided, any large or precomputed |
| buffers (like RoPE frequency tensors) will be allocated or re-created on |
| this device during initialization. This can avoid overhead from transferring |
| buffers between CPU and GPU after creation. If None, buffers default to the |
| device of the first parameter or CPU. |
| |
| Usage: |
| - Pass a GPU device (e.g., ``torch.device('cuda')``) when you want to ensure |
| buffers are created directly on GPU, preventing extra transfers. |
| - Pass a CPU device (e.g., ``torch.device('cpu')``) if you want to keep |
| large buffers in CPU memory (common in CPU-offload or pipeline-parallel setups). |
| - Leave it as ``None`` to rely on the model’s existing parameter device or |
| the default PyTorch device context. |
| |
| When / Why: |
| - Useful in distributed or pipeline-parallel training where parameters may |
| initially live on CPU, but you still need certain buffers on GPU to avoid |
| overhead during forward passes. |
| - Prevents large re-allocations or re-copies when big buffers (like RoPE |
| frequency tables) are needed per rank. |
| """ |
| super().init_weights() |
|
|
| @classmethod |
| def from_model_args(cls, config: Mamba2Config) -> "Mamba2": |
| """ |
| Initialize a Mamba model from a MambaConfig object. |
| |
| Args: |
| config (MambaConfig): Mamba configuration arguments. |
| |
| Returns: |
| Mamba: Mamba-2 model. |
| """ |
| return cls(config) |
|
|
|
|
| if __name__ == '__main__': |
| import json |
|
|
| config_path = "config.json" |
|
|
| with open(config_path, "r") as f: |
| config_data = json.load(f) |
|
|
| if torch.cuda.is_available(): |
| device = torch.device("cuda") |
| elif torch.backends.mps.is_available(): |
| device = torch.device("mps") |
| else: |
| device = torch.device("cpu") |
| print("Device:", device) |
|
|
| torch_dtype = getattr(torch, config_data["torch_dtype"]) |
| print("Torch dtype:", torch_dtype) |
|
|
| dim = config_data["dim"] |
| num_heads = config_data["num_heads"] |
| num_layers = config_data["num_layers"] |
| vocab_size = config_data["vocab_size"] |
| bias = config_data["bias"] |
| state_dim = config_data["state_dim"] |
| num_groups = config_data["num_groups"] |
| conv_size = config_data.get("conv_size") |
| use_mem_eff_path = config_data.get("use_mem_eff_path") |
| dt_bias = config_data["dt_bias"] |
| D_has_head_dim = config_data["D_has_head_dim"] |
| learnable_init_states = config_data["learnable_init_states"] |
| ssm_chunk_size = config_data["ssm_chunk_size"] |
| weight_tying = config_data["weight_tying"] |
| mlp_scale = config_data.get("mlp_scale") |
| multiple_of = config_data["multiple_of"] |
| norm_eps = config_data["norm_eps"] |
| init_use_depth = config_data["init_use_depth"] |
| init_base_std = config_data.get("init_base_std") |
| init_std_factor = config_data["init_std_factor"] |
| use_attn = config_data["use_attn"] |
| softcap = config_data["softcap"] |
| torch_compile = config_data["torch_compile"] |
|
|
| configs = Mamba2Config( |
| dim=dim, |
| num_layers=num_layers, |
| num_heads=num_heads, |
| vocab_size=vocab_size, |
| bias=bias, |
| torch_dtype=torch_dtype, |
| state_dim=state_dim, |
| num_groups=num_groups, |
| conv_size=conv_size, |
| use_mem_eff_path=use_mem_eff_path, |
| dt_bias=dt_bias, |
| D_has_head_dim=D_has_head_dim, |
| learnable_init_states=learnable_init_states, |
| ssm_chunk_size=ssm_chunk_size, |
| weight_tying=weight_tying, |
| mlp_scale=mlp_scale, |
| multiple_of=multiple_of, |
| norm_eps=norm_eps, |
| init_use_depth=init_use_depth, |
| init_base_std=init_base_std, |
| init_std_factor=init_std_factor, |
| use_attn=use_attn, |
| softcap=softcap, |
| ) |
|
|
| print("Configs:") |
| for key, value in vars(configs).items(): |
| print(f" {key}: {value}") |
|
|
| model = Mamba2(configs).to(device=device, dtype=torch_dtype) |
|
|
| x = torch.randint( |
| 0, configs.vocab_size, |
| (config_data["bsz"], config_data["seq_len"]), |
| dtype=torch.long |
| ).to(device) |
|
|
| outputs = model(x) |
|
|
| print("Output shape:", outputs.shape) |
| print("Sample output:", outputs[0, 0, :10]) |
| print("Mean of Mamba output: ", outputs.mean().item()) |
| print("Stddev of Mamba output: ", outputs.std().item()) |