# GeneMamba: Foundation Model for Single-Cell Analysis
A Hugging Face compatible implementation of GeneMamba, a foundational state-space model (Mamba) designed for advanced single-cell RNA-seq analysis.
## Overview
GeneMamba is a state-space model (SSM) based on Mamba architecture optimized for single-cell gene expression analysis. The model:
- Takes ranked gene sequences as input (genes sorted by expression level)
- Outputs cell embeddings suitable for clustering, classification, and batch integration
- Supports multiple downstream tasks including cell type annotation and next-token pretraining
- Is compatible with Hugging Face Transformers for easy integration into existing pipelines
## Key Features

- ✅ **Efficient Sequence Processing**: SSM-based architecture with linear complexity
- ✅ **Cell Representation Learning**: Direct cell embeddings without intermediate steps
- ✅ **Multi-task Support**: Classification, next-token pretraining, and embeddings in one model
- ✅ **Hugging Face Integration**: Standard `from_pretrained()` and `save_pretrained()` interface
- ✅ **Production Ready**: Pretrained checkpoints available on the Hugging Face Hub
## Datasets
The pretraining dataset and downstream datasets can be found in the official GeneMamba GitHub repository:
https://github.com/MineSelf2016/GeneMamba
## Installation

### Option 1: Install from Source

```bash
cd GeneMamba_HuggingFace
pip install -e .
```

### Option 2: Install from PyPI (coming soon)

```bash
pip install genemamba-hf
```
### Dependencies

- Python >= 3.9
- PyTorch >= 2.0
- Transformers >= 4.40.0
- mamba-ssm >= 2.2.0

Install all dependencies:

```bash
pip install -r requirements.txt
```
## Quick Start

### Phase 1: Extract Cell Embeddings
This is the most common use case. Extract single-cell embeddings for downstream analysis:
```python
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel

# Load pretrained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True,
)
model = AutoModel.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True,
)

# Prepare input: ranked gene sequences
# Shape: (batch_size, seq_len) with gene Ensembl IDs as token IDs
batch_size, seq_len = 8, 2048
input_ids = torch.randint(2, 25426, (batch_size, seq_len))

# Extract cell embeddings
outputs = model(input_ids)
cell_embeddings = outputs.pooled_embedding  # shape: (8, 512)

print(f"Cell embeddings shape: {cell_embeddings.shape}")
# Output: Cell embeddings shape: torch.Size([8, 512])
```
#### Key Points

- **Input format**: Ranked sequences of gene token IDs (genes sorted by expression, descending)
- **Recommended embedding**: Always use `outputs.pooled_embedding` for downstream tasks
- **Pooling method**: Default is mean pooling over the sequence (see `config.embedding_pooling`)
- **Sequence length**: Maximum 2048; shorter sequences are auto-padded
- **Token vocabulary**: Based on Ensembl Gene IDs (e.g., `ENSG00000000003`)
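The padding behavior can be reproduced manually if you build inputs without the tokenizer. A minimal sketch, assuming a pad token ID of 1 (the value the tokenizer reports in Phase 3) and using placeholder gene token IDs:

```python
import torch

MAX_LEN, PAD_ID = 2048, 1  # pad_token_id is 1, per the tokenizer

def pad_ranked_sequence(token_ids, max_len=MAX_LEN, pad_id=PAD_ID):
    """Truncate a ranked gene-token list to max_len, or right-pad with pad_id."""
    token_ids = token_ids[:max_len]
    return token_ids + [pad_id] * (max_len - len(token_ids))

cell = list(range(2, 1202))         # 1200 ranked gene tokens (placeholder IDs)
padded = pad_ranked_sequence(cell)  # length 2048, tail filled with PAD_ID
input_ids = torch.tensor([padded])  # shape: (1, 2048)
```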
#### Use Cases for Cell Embeddings
- Clustering: KMeans, Leiden, etc.
- Visualization: UMAP, t-SNE
- Classification: Logistic regression with frozen embeddings
- Batch integration: Evaluate with batch correction metrics
- Retrieval: Find similar cells or genes
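As a concrete example of the retrieval use case, nearest-neighbour search over pooled embeddings reduces to cosine similarity. A sketch with random vectors standing in for real `cell_embeddings`:

```python
import numpy as np

# Random vectors stand in for pooled cell embeddings of shape (n_cells, hidden)
rng = np.random.default_rng(0)
cell_embeddings = rng.standard_normal((100, 512)).astype(np.float32)

# L2-normalize so that dot products equal cosine similarities
normed = cell_embeddings / np.linalg.norm(cell_embeddings, axis=1, keepdims=True)
sims = normed @ normed.T         # pairwise cosine similarity, (100, 100)
np.fill_diagonal(sims, -np.inf)  # exclude self-matches
nearest = sims.argmax(axis=1)    # index of the most similar cell per cell
```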
### Phase 2: Downstream Tasks
Use GeneMamba for cell type annotation and other sequence classification tasks:
```python
import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset

# Load model with a classification head
model = AutoModelForSequenceClassification.from_pretrained(
    "mineself2016/GeneMamba",
    num_labels=10,  # number of cell types
    trust_remote_code=True,
)

# Prepare dataset
class GeneExpressionDataset(Dataset):
    def __init__(self, input_ids, labels):
        self.input_ids = input_ids
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "labels": self.labels[idx],
        }

# Example data
X_train = torch.randint(2, 25426, (1000, 2048))
y_train = torch.randint(0, 10, (1000,))
train_dataset = GeneExpressionDataset(X_train, y_train)

# Fine-tune with Trainer
trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
        save_strategy="epoch",
    ),
    train_dataset=train_dataset,
)
trainer.train()
```
#### Classification Variants

The model also supports:

- **Binary classification**: `num_labels=2`
- **Multi-class**: `num_labels=N`
- **Multi-label**: Use `BCEWithLogitsLoss` in a custom training loop
- **Regression**: Modify the head (custom implementation needed)
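For the multi-label case, the loss step might look like the following sketch (toy tensors stand in for model logits and multi-hot labels; in practice this would sit inside a custom training loop):

```python
import torch
import torch.nn as nn

batch_size, num_labels = 4, 10
logits = torch.randn(batch_size, num_labels)                     # raw model outputs
targets = torch.randint(0, 2, (batch_size, num_labels)).float()  # multi-hot labels

# BCEWithLogitsLoss applies sigmoid internally, one independent decision per label
loss = nn.BCEWithLogitsLoss()(logits, targets)

probs = torch.sigmoid(logits)  # per-label probabilities
preds = (probs > 0.5).long()   # threshold each label independently
```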
### Phase 3: Train from Scratch
Train a new GeneMamba model with next-token prediction. If a checkpoint exists, resume automatically; otherwise start from scratch.
```python
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoConfig, AutoModelForMaskedLM, Trainer, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint

tokenizer = AutoTokenizer.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True,
)
print("vocab_size:", tokenizer.vocab_size)  # 25426
print("unk/pad:", tokenizer.unk_token_id, tokenizer.pad_token_id)  # 0, 1
print("cls/mask:", tokenizer.cls_token_id, tokenizer.mask_token_id)  # None, None

# Build the model config (no local modeling file import required)
config = AutoConfig.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)
config.vocab_size = 25426
config.hidden_size = 512
config.num_hidden_layers = 24
config.max_position_embeddings = 2048
config.mamba_mode = "mean"

# Resume if a checkpoint exists
output_dir = "./from_scratch_pretrain"
checkpoint_dir = Path(output_dir) / "checkpoint-last"
if checkpoint_dir.exists():
    resume_from_checkpoint = str(checkpoint_dir)
elif Path(output_dir).exists():
    resume_from_checkpoint = get_last_checkpoint(output_dir)
else:
    resume_from_checkpoint = None

if resume_from_checkpoint is not None:
    model = AutoModelForMaskedLM.from_pretrained(
        resume_from_checkpoint,
        trust_remote_code=True,
        local_files_only=True,
    )
else:
    model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)

class NextTokenTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        input_ids = inputs["input_ids"]
        outputs = model(input_ids=input_ids)
        # Shift so that position t predicts token t+1
        shift_logits = outputs.logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous().to(shift_logits.device)
        loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        return (loss, outputs) if return_outputs else loss

trainer = NextTokenTrainer(
    model=model,
    args=TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
    ),
    train_dataset=train_dataset,  # e.g., the GeneExpressionDataset from Phase 2
)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
```
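The shift logic in `compute_loss` can be checked in isolation. This toy sketch (vocabulary of 10) shows position t's logits being scored against token t+1:

```python
import torch
import torch.nn.functional as F

vocab, seq_len = 10, 6
input_ids = torch.tensor([[2, 5, 7, 1, 3, 9]])
logits = torch.randn(1, seq_len, vocab)  # stand-in for model(input_ids).logits

shift_logits = logits[:, :-1, :].contiguous()  # predictions at positions 0..4
shift_labels = input_ids[:, 1:].contiguous()   # targets are tokens 1..5
loss = F.cross_entropy(
    shift_logits.view(-1, vocab),
    shift_labels.view(-1),
)
```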
## Model Variants

We provide several pretrained checkpoint sizes:

| Model Name | Layers | Hidden Size | Parameters | Download |
|---|---|---|---|---|
| GeneMamba-24l-512d | 24 | 512 | ~170M | 🤗 Hub |
| GeneMamba-24l-768d | 24 | 768 | ~380M | 🤗 Hub |
| GeneMamba-48l-512d | 48 | 512 | ~340M | 🤗 Hub |
| GeneMamba-48l-768d | 48 | 768 | ~750M | 🤗 Hub |
All models share the same tokenizer (25,426 Ensembl Gene IDs + special tokens).
## Architecture

### Model Components

```
GeneMambaModel (Backbone)
├── Embedding Layer (vocab_size × hidden_size)
├── MambaMixer (Bidirectional SSM processing)
│   ├── EncoderLayer 0
│   ├── EncoderLayer 1
│   ├── ...
│   └── EncoderLayer N-1
├── RMSNorm (Layer Normalization)
└── Output: Pooled Embedding (batch_size × hidden_size)

Task-Specific Heads:
├── GeneMambaForSequenceClassification
│   └── Linear(hidden_size → num_labels)
└── GeneMambaForMaskedLM
    └── Linear(hidden_size → vocab_size)
```
### Key Design Choices

- **Bidirectional Mamba Blocks**: Bidirectionality significantly improves the gene-rank reconstruction task
- **Pooling Strategy**: Forward and backward states are combined via multiple aggregation modes (mean/sum/concat/gate)
- **Regularization**: Dropout on the classification head
- **Activation**: No explicit activation (Mamba uses internal gating)
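The aggregation modes can be illustrated on toy tensors (the model's actual fusion lives inside its Mamba blocks, and the gate is learned in practice, random here):

```python
import torch

B, L, H = 2, 16, 512
h_fwd = torch.randn(B, L, H)  # forward-direction hidden states
h_bwd = torch.randn(B, L, H)  # backward-direction hidden states

mean_out = (h_fwd + h_bwd) / 2                  # "mean"
sum_out = h_fwd + h_bwd                         # "sum"
concat_out = torch.cat([h_fwd, h_bwd], dim=-1)  # "concat": doubles hidden size
gate = torch.sigmoid(torch.randn(B, L, H))      # "gate": random stand-in for a learned gate
gate_out = gate * h_fwd + (1 - gate) * h_bwd
```

Note that `concat` doubles the embedding dimension, while the other modes keep it at `hidden_size`.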
## Important Notes ⚠️

### Input Format

GeneMamba expects a very specific input format:

- Each cell is represented as a ranked sequence of genes
- Genes should be sorted by expression value in descending order
- Use Ensembl Gene IDs as tokens (e.g., `ENSG00000000003`)
- Sequences are padded/truncated to `max_position_embeddings` (default 2048)
Example preparation:
```python
import numpy as np
import scanpy as sc

# Load scRNA-seq data
adata = sc.read_h5ad("data.h5ad")

# Map column index -> Ensembl Gene ID
# (assumes adata.var_names holds Ensembl IDs; adapt if yours holds gene symbols)
gene_id_mapping = dict(enumerate(adata.var_names))

# For each cell, rank genes by expression
gene_ids = []
for cell_idx in range(adata.n_obs):
    expression = adata.X[cell_idx].toarray().flatten()  # assumes sparse adata.X
    ranked_indices = np.argsort(-expression)  # descending order
    ranked_gene_ids = [gene_id_mapping[idx] for idx in ranked_indices[:2048]]
    gene_ids.append(ranked_gene_ids)

# Convert to token IDs
input_ids = tokenizer(gene_ids, return_tensors="pt", padding=True)["input_ids"]
```
### Limitations

- **Gene vocabulary**: Only the 25,426 genes in the Ensembl-based vocabulary can be directly tokenized
- **Sequence order**: Expects ranked order; random order will degrade performance
- **Batch size**: Larger batches (32-64) are recommended for better convergence
- **GPU memory**: The base model needs ~10GB for `batch_size=32`; larger variants need more
## Examples

See the `examples/` directory for complete scripts:

1. `1_extract_embeddings.py` - Extract cell embeddings
2. `2_finetune_classification.py` - Cell type annotation
3. `3_pretrain_from_scratch.py` - Train from scratch (next-token + optional resume)
## Citation

If you find GeneMamba useful in your research, please cite:

```bibtex
@article{qi2025genemamba,
  title={GeneMamba: An Efficient and Effective Foundation Model on Single Cell Data},
  author={Qi, Cong and Fang, Hanzhang and Jiang, Siqi and Song, Xun and Hu, Tianxing and Zhi, Wei},
  journal={arXiv preprint arXiv:2504.16956},
  year={2025}
}
```
## Troubleshooting

### `trust_remote_code=True` Error

This is expected for custom models. Either:

- Set `trust_remote_code=True` (safe if loading from the official repo)
- Or use `sys.path.insert(0, '.')` if loading local code
### Old Cached Code / Shape Mismatch

If you still see old loading errors after an update, force-refresh the files from the Hub:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True,
    force_download=True,
)
```

You can also clear the local cache if needed:

```bash
rm -rf ~/.cache/huggingface/hub/models--mineself2016--GeneMamba
```
### Out of Memory (OOM)

Reduce the batch size:

```python
args = TrainingArguments(
    per_device_train_batch_size=8,  # reduce from 32
    ...
)
```
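If even small batches are too large, `Trainer` also supports `gradient_accumulation_steps` in `TrainingArguments` to keep the effective batch size high. The equivalent manual loop, sketched with a toy linear model standing in for GeneMamba:

```python
import torch

model = torch.nn.Linear(16, 2)  # toy stand-in for the real model
opt = torch.optim.AdamW(model.parameters(), lr=2e-5)
accum_steps = 4  # 4 micro-batches of 8 ~ effective batch size 32

opt.zero_grad()
for step in range(accum_steps):
    x = torch.randn(8, 16)
    y = torch.randint(0, 2, (8,))
    # Divide by accum_steps so the summed gradient matches one big batch
    loss = torch.nn.functional.cross_entropy(model(x), y) / accum_steps
    loss.backward()  # gradients accumulate across micro-batches
opt.step()
```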
### Tokenizer Not Found

Make sure the tokenizer files are in the same directory:

```
GeneMamba_repo/
├── config.json
├── model.safetensors
├── tokenizer.json          ← Required
├── tokenizer_config.json   ← Required
└── ...
```
Last Updated: March 2026