# chunkable-mamba2

Custom Mamba2 model and configuration classes for 🤗 Transformers that add support for vertically chunked inference, which processes input sequences in fixed-size vertical chunks through all model layers, keeping memory usage constant regardless of sequence length.
## What this repository provides

- `ChunkableMamba2Config`: extends `Mamba2Config` with a `use_mem_eff_path` option for the memory-efficient CUDA kernel path.
- `ChunkableMamba2Model`: extends `Mamba2Model` with a chunkable mixer and cache that correctly propagate the recurrent states across vertical chunks (simultaneous `seq_idx` + `initial_states` support).
- `chunkable_mamba_split_conv1d_scan_combined`: a modified `mamba_split_conv1d_scan_combined` kernel wrapper that passes cache parameters through the SSD scan so that conv and SSM states are properly initialized and exported during chunked inference.
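To illustrate how vertically chunked inference ties these pieces together, here is a minimal driver-loop sketch. The exact `ChunkableMamba2Model` call signature may differ; the `cache_params`/`use_cache` names below mirror the standard Mamba2 interface in Transformers and are an assumption:

```python
# Sketch of a vertically chunked forward pass (hypothetical driver loop).
# Each fixed-size chunk runs through all model layers; the returned cache
# carries the conv and SSM states into the next chunk, so peak memory is
# bounded by chunk_size rather than the full sequence length.
def chunked_forward(model, input_ids, chunk_size=512):
    cache = None
    last_hidden = None
    for start in range(0, input_ids.shape[1], chunk_size):
        chunk = input_ids[:, start:start + chunk_size]
        # cache_params seeds the recurrent/conv states from the previous chunk
        outputs = model(chunk, cache_params=cache, use_cache=True)
        cache = outputs.cache_params
        last_hidden = outputs.last_hidden_state
    return last_hidden
```

Because the cache is threaded through every iteration, the result matches an unchunked pass over the full sequence while only ever materializing `chunk_size` tokens of activations at a time.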
## Usage

This repository is designed to be referenced directly from Hugging Face model configs via `auto_map`, so that models can be loaded with `trust_remote_code=True` without any local installation:
```json
"auto_map": {
  "AutoConfig": "dynatrace-oss/chunkable-mamba2--configuration_chunkable_mamba2.ChunkableMamba2Config",
  "AutoModel": "dynatrace-oss/chunkable-mamba2--modeling_chunkable_mamba2.ChunkableMamba2Model"
}
```
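The `--` separator in these references tells Transformers to fetch the code from a different repository (`dynatrace-oss/chunkable-mamba2`) than the model checkpoint itself. For illustration, a small hypothetical helper that splits such a reference into its parts, mirroring what Transformers does internally during resolution:

```python
def parse_auto_map_ref(ref: str) -> tuple[str, str, str]:
    """Split an auto_map reference of the form
    'repo_id--module.ClassName' into (repo_id, module, class_name).

    Illustrative helper only; Transformers performs this resolution
    itself when a model is loaded with trust_remote_code=True.
    """
    repo_id, target = ref.split("--", 1)
    module, class_name = target.rsplit(".", 1)
    return repo_id, module, class_name
```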
## Models
This code was created for the following embedding models:
## Requirements

Requires `transformers>=5.5.0` due to a breaking change to the Mamba2 cache introduced in v5.5.0 (transformers#44950).

```shell
pip install transformers kernels einops
```
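Since the Mamba2 cache change is a hard compatibility boundary, it can be worth checking the installed version up front. A simplified sketch that compares numeric release components (for real checks, `packaging.version.Version` is the more robust choice):

```python
def meets_min_transformers(version: str, minimum: str = "5.5.0") -> bool:
    """Return True if `version` is at least `minimum`.

    Simplified comparison of the first three numeric components of a
    dotted version string; pre-release suffixes are ignored.
    """
    def as_tuple(v: str) -> tuple:
        return tuple(int(part) for part in v.split(".")[:3] if part.isdigit())

    return as_tuple(version) >= as_tuple(minimum)
```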
## Open Source Integration Roadmap

Our goal is to integrate all necessary changes upstream to simplify the adoption of vertically chunked inference for other models:
⚪ Planned | 🟡 In Progress | 🟢 Integrated
- ⚪ **causal-conv1d**: Enable simultaneous `seq_idx` + `initial_states` (required for recurrent processing of chunks with left padding)
- ⚪ **mamba-ssm**: Use `seq_idx` + `initial_states` in `mamba_split_conv1d_scan_combined` and export final states
- ⚪ **kernels-community**: Propagate changes in `causal-conv1d` and `mamba-ssm` to their kernel hub equivalents in the `kernels-community` repositories
- ⚪ **transformers**: Use the updated `mamba_split_conv1d_scan_combined` with cache params during inference (currently only used during training, not configurable, and problematic with left padding)
This list will be updated as integration progresses.
## License
Apache-2.0