GeneMamba: Foundation Model for Single-Cell Analysis

A Hugging Face-compatible implementation of GeneMamba, a foundation model built on the Mamba state-space architecture for single-cell RNA-seq analysis.

Overview

GeneMamba is a state-space model (SSM) based on the Mamba architecture and optimized for single-cell gene expression analysis. The model:

  • Takes ranked gene sequences as input (genes sorted by expression level)
  • Outputs cell embeddings suitable for clustering, classification, and batch integration
  • Supports next-token pretraining and downstream tasks such as cell type annotation
  • Is compatible with Hugging Face Transformers for easy integration into existing pipelines

Key Features

✅ Efficient Sequence Processing: SSM-based architecture with linear complexity
✅ Cell Representation Learning: Direct cell embedding without intermediate steps
✅ Multi-task Support: Classification, next-token pretraining, and embeddings in one model
✅ Hugging Face Integration: Standard from_pretrained() and save_pretrained() interface
✅ Production Ready: Pretrained checkpoints available on Hugging Face Hub


Datasets

The pretraining dataset and downstream datasets can be found in the official GeneMamba GitHub repository:

https://github.com/MineSelf2016/GeneMamba


Installation

Option 1: Install from Source

cd GeneMamba_HuggingFace
pip install -e .

Option 2: Install from PyPI (coming soon)

pip install genemamba-hf

Dependencies

  • Python >= 3.9
  • PyTorch >= 2.0
  • Transformers >= 4.40.0
  • mamba-ssm >= 2.2.0

Install all dependencies:

pip install -r requirements.txt

Quick Start

Phase 1: Extract Cell Embeddings

This is the most common use case. Extract single-cell embeddings for downstream analysis:

import torch
from transformers import AutoTokenizer, AutoModel

# Load pretrained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True
)
model = AutoModel.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True
)
model.eval()  # disable dropout for inference

# Prepare input: ranked gene sequences
# Shape: (batch_size, seq_len); each entry is the tokenizer's integer ID for an Ensembl gene ID.
# Random IDs serve only as a placeholder here (IDs 0 and 1 are reserved for unk/pad).
batch_size, seq_len = 8, 2048
input_ids = torch.randint(2, 25426, (batch_size, seq_len))

# Extract cell embeddings (no gradients needed for inference)
with torch.no_grad():
    outputs = model(input_ids)
cell_embeddings = outputs.pooled_embedding  # shape: (8, 512)

print(f"Cell embeddings shape: {cell_embeddings.shape}")
# Output: Cell embeddings shape: torch.Size([8, 512])

Key Points

  • Input format: Ranked sequences of gene token IDs (genes sorted by expression descending)
  • Recommended embedding: Always use outputs.pooled_embedding for downstream tasks
  • Pooling method: Default is mean pooling over sequence (see config.embedding_pooling)
  • Sequence length: Maximum 2048; shorter sequences are auto-padded
  • Token vocabulary: Based on Ensembl Gene IDs (e.g., ENSG00000000003)
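Mean pooling averages the per-gene hidden states into a single vector for each cell. The sketch below shows a padding-aware version in NumPy for illustration. It assumes an attention mask in which 1 marks real tokens. It is not taken from the GeneMamba source, where pooling is controlled by config.embedding_pooling.

```python
import numpy as np

def masked_mean_pool(hidden_states: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token states per cell, ignoring padding positions.

    hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len), 1 = real token.
    """
    mask = attention_mask[..., None].astype(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)                    # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1.0, None)                  # avoid divide-by-zero
    return summed / counts

# Toy check: the padded position (mask 0) does not affect the average
h = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
m = np.array([[1, 1, 0]])
print(masked_mean_pool(h, m))  # [[2. 3.]]
```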

Use Cases for Cell Embeddings

  • Clustering: KMeans, Leiden, etc.
  • Visualization: UMAP, t-SNE
  • Classification: Logistic regression with frozen embeddings
  • Batch integration: Evaluate with batch correction metrics
  • Retrieval: Find similar cells or genes
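For retrieval, a common approach is to rank cells by the cosine similarity of their embeddings. The following is a minimal NumPy sketch. It assumes the embeddings have already been moved to CPU and converted with .numpy(). top_k_similar is a hypothetical helper, not part of the GeneMamba package.

```python
import numpy as np

def top_k_similar(embeddings: np.ndarray, query_idx: int, k: int = 5) -> np.ndarray:
    """Return indices of the k cells most similar (cosine) to cell `query_idx`, excluding itself."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)  # L2-normalize each embedding
    sims = unit @ unit[query_idx]                     # cosine similarity to the query
    sims[query_idx] = -np.inf                         # exclude the query cell itself
    return np.argsort(-sims)[:k]

# Toy embeddings standing in for outputs.pooled_embedding (n_cells x hidden_size)
emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])
print(top_k_similar(emb, query_idx=0, k=2))  # [1 2]
```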

Phase 2: Downstream Tasks

Use GeneMamba for cell type annotation and other sequence classification tasks:

import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset

# Load model with classification head
model = AutoModelForSequenceClassification.from_pretrained(
    "mineself2016/GeneMamba",
    num_labels=10,  # number of cell types
    trust_remote_code=True
)

# Prepare dataset
class GeneExpressionDataset(Dataset):
    def __init__(self, input_ids, labels):
        self.input_ids = input_ids
        self.labels = labels
    
    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "labels": self.labels[idx]
        }

# Example data
X_train = torch.randint(2, 25426, (1000, 2048))
y_train = torch.randint(0, 10, (1000,))

train_dataset = GeneExpressionDataset(X_train, y_train)

# Fine-tune with Trainer
trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
        save_strategy="epoch",
    ),
    train_dataset=train_dataset,
)

trainer.train()
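To track annotation accuracy during fine-tuning, Trainer accepts a compute_metrics function together with an eval_dataset. Below is a minimal sketch. It assumes a held-out dataset built like train_dataset and a model whose predictions are a single logits array. You would pass it as Trainer(..., eval_dataset=..., compute_metrics=compute_metrics).

```python
import numpy as np

def compute_metrics(eval_pred):
    """Accuracy for cell type annotation; Trainer passes (logits, labels)."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)  # predicted cell type per cell
    return {"accuracy": float((predictions == labels).mean())}

# Toy check with 3 cells and 2 cell types
logits = np.array([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2]])
labels = np.array([0, 1, 1])
print(compute_metrics((logits, labels)))  # {'accuracy': 0.6666666666666666}
```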

Classification Variants

The model also supports:

  • Binary classification: num_labels=2
  • Multi-class: num_labels=N
  • Multi-label: Use BCEWithLogitsLoss with multi-hot float targets in a custom training loop
  • Regression: Requires a custom head (not provided)
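For multi-label annotation, each cell's labels become a multi-hot target vector. Every output logit is then scored with an independent sigmoid/binary cross-entropy term, which is what torch.nn.BCEWithLogitsLoss computes. The NumPy sketch below illustrates that target format and loss. The helper names are hypothetical. In a real training loop you would apply BCEWithLogitsLoss to the classification head's logits with float targets.

```python
import numpy as np

def multi_hot(label_sets, num_labels):
    """Turn per-cell lists of label indices into a (n_cells, num_labels) 0/1 matrix."""
    y = np.zeros((len(label_sets), num_labels), dtype=np.float32)
    for i, labels in enumerate(label_sets):
        y[i, list(labels)] = 1.0
    return y

def bce_with_logits(logits, targets):
    """Numerically stable mean binary cross-entropy, the same formula BCEWithLogitsLoss uses."""
    # max(x, 0) - x * y + log(1 + exp(-|x|))
    loss = np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
    return float(loss.mean())

targets = multi_hot([[0, 2], [1]], num_labels=3)
print(targets)  # [[1. 0. 1.]
                #  [0. 1. 0.]]
print(round(bce_with_logits(np.zeros((2, 3)), targets), 4))  # 0.6931
```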

Phase 3: Train from Scratch

Train a new GeneMamba model with next-token prediction. If a checkpoint exists in the output directory, training resumes from it automatically; otherwise it starts from scratch.

import torch
from pathlib import Path
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoConfig, AutoModelForMaskedLM, Trainer, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint

tokenizer = AutoTokenizer.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True,
)

print("vocab_size:", tokenizer.vocab_size)  # 25426
print("unk/pad:", tokenizer.unk_token_id, tokenizer.pad_token_id)  # 0, 1
print("cls/mask:", tokenizer.cls_token_id, tokenizer.mask_token_id)  # None, None

# Build model config (no local modeling file import required)
config = AutoConfig.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)
config.vocab_size = 25426
config.hidden_size = 512
config.num_hidden_layers = 24
config.max_position_embeddings = 2048
config.mamba_mode = "mean"

# Fresh weights; when resuming, trainer.train() restores weights,
# optimizer, and scheduler state from the checkpoint
model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)

# Placeholder pretraining data: replace with real ranked-gene token IDs
class RankedGeneDataset(Dataset):
    def __init__(self, input_ids):
        self.input_ids = input_ids

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {"input_ids": self.input_ids[idx]}

train_dataset = RankedGeneDataset(torch.randint(2, 25426, (1000, 2048)))

# Resume if a checkpoint exists (get_last_checkpoint fails on a missing directory)
output_dir = "./from_scratch_pretrain"
checkpoint_dir = Path(output_dir) / "checkpoint-last"

if checkpoint_dir.exists():
    resume_from_checkpoint = str(checkpoint_dir)
elif Path(output_dir).is_dir():
    resume_from_checkpoint = get_last_checkpoint(output_dir)  # None if no checkpoint-* folder
else:
    resume_from_checkpoint = None

class NextTokenTrainer(Trainer):
    # **kwargs absorbs num_items_in_batch, which newer Transformers versions pass
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        input_ids = inputs["input_ids"]
        outputs = model(input_ids=input_ids)
        # Position t predicts token t+1
        shift_logits = outputs.logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous().to(shift_logits.device)
        loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=tokenizer.pad_token_id,  # do not learn to predict padding
        )
        return (loss, outputs) if return_outputs else loss

trainer = NextTokenTrainer(
    model=model,
    args=TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
    ),
    train_dataset=train_dataset,
)

trainer.train(resume_from_checkpoint=resume_from_checkpoint)
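The shift in compute_loss above pairs the logits at position t with the token at position t+1. The NumPy sketch below reproduces that bookkeeping on a toy sequence. It also excludes padding targets from the average, which is the effect of cross_entropy's ignore_index argument. It assumes pad_token_id = 1, as printed by the tokenizer above, and it is not code from the GeneMamba repository.

```python
import numpy as np

PAD_ID = 1  # pad_token_id reported by the GeneMamba tokenizer

def shifted_targets(input_ids: np.ndarray, pad_id: int = PAD_ID):
    """Pair each position with the next token; return (targets, keep_mask)."""
    targets = input_ids[:, 1:]  # token t+1 is the target for position t
    keep = targets != pad_id    # padding targets are excluded from the loss
    return targets, keep

def next_token_loss(logits: np.ndarray, input_ids: np.ndarray, pad_id: int = PAD_ID) -> float:
    """Mean cross-entropy over non-pad targets; logits: (batch, seq_len, vocab)."""
    shift_logits = logits[:, :-1, :]  # the last position has no next token
    targets, keep = shifted_targets(input_ids, pad_id)
    # log-softmax with max-subtraction for numerical stability
    z = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    nll = -np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return float(nll[keep].mean())

ids = np.array([[5, 7, 9, 1, 1]])  # three genes followed by two pad tokens
targets, keep = shifted_targets(ids)
print(targets.tolist(), keep.tolist())  # [[7, 9, 1, 1]] [[True, True, False, False]]
uniform = np.zeros((1, 5, 10))  # equal logits over a toy 10-token vocabulary
print(round(next_token_loss(uniform, ids), 4))  # 2.3026 (= ln 10)
```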

Model Variants

We provide several pretrained checkpoint sizes:

| Model Name | Layers | Hidden Size | Parameters | Download |
|---|---|---|---|---|
| GeneMamba-24l-512d | 24 | 512 | ~170M | 🤗 Hub |
| GeneMamba-24l-768d | 24 | 768 | ~380M | 🤗 Hub |
| GeneMamba-48l-512d | 48 | 512 | ~340M | 🤗 Hub |
| GeneMamba-48l-768d | 48 | 768 | ~750M | 🤗 Hub |

All models share the same tokenizer: a 25,426-token vocabulary of Ensembl Gene IDs plus special tokens (unk = 0, pad = 1).


Architecture

Model Components

GeneMambaModel (Backbone)
├── Embedding Layer          (vocab_size × hidden_size)
├── MambaMixer               (Bidirectional SSM processing)
│   ├── EncoderLayer 0
│   ├── EncoderLayer 1
│   ├── ...
│   └── EncoderLayer N-1
├── RMSNorm                  (Layer Normalization)
└── Output: Pooled Embedding (batch_size × hidden_size)

Task-Specific Heads:
├── GeneMambaForSequenceClassification
│   └── Linear(hidden_size → num_labels)
└── GeneMambaForMaskedLM
    └── Linear(hidden_size → vocab_size)
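The embedding layer is a lookup table of vocab_size × hidden_size weights, so its size follows directly from the shared 25,426-token vocabulary. As a worked example, counting only this table and not the Mamba layers or task heads:

```latex
\begin{aligned}
25{,}426 \times 512 &= 13{,}018{,}112 \\
25{,}426 \times 768 &= 19{,}527{,}168
\end{aligned}
```

That is about 13.0M embedding parameters for the 512d variants and 19.5M for the 768d variants.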

Key Design Choices

  • Bidirectional Mamba Block: Processing each gene sequence in both directions significantly improves performance on the gene rank reconstruction task
  • Pooling Strategy: The bidirectional Mamba supports multiple aggregation modes (mean/sum/concat/gate)
  • Regularization: Dropout on classification head
  • Activation: No explicit activation (Mamba uses internal gating)
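The aggregation modes listed above combine the outputs of the forward and backward Mamba passes. The NumPy sketch below illustrates one plausible reading of each mode. The function, the gate parameterization, and the assumption that backward outputs are flipped back into forward order are all hypothetical and may differ from the GeneMamba implementation.

```python
import numpy as np

def combine_directions(h_fwd, h_bwd, mode="mean", gate_w=None):
    """Merge forward and backward hidden states, each (seq_len, hidden).

    h_bwd is assumed to be already flipped back into forward order.
    gate_w: (2 * hidden, hidden) weights, used only for mode="gate" (hypothetical).
    """
    if mode == "mean":
        return (h_fwd + h_bwd) / 2
    if mode == "sum":
        return h_fwd + h_bwd
    if mode == "concat":
        return np.concatenate([h_fwd, h_bwd], axis=-1)  # doubles the feature size
    if mode == "gate":
        g = 1.0 / (1.0 + np.exp(-np.concatenate([h_fwd, h_bwd], axis=-1) @ gate_w))
        return g * h_fwd + (1.0 - g) * h_bwd  # per-feature convex mix of the two directions
    raise ValueError(f"unknown mode: {mode}")

f = np.array([[1.0, 2.0]])
b = np.array([[3.0, 4.0]])
print(combine_directions(f, b, "mean"))          # [[2. 3.]]
print(combine_directions(f, b, "concat").shape)  # (1, 4)
print(combine_directions(f, b, "gate", gate_w=np.zeros((4, 2))))  # [[2. 3.]]
```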

Important Notes ⚠️

Input Format

GeneMamba expects a very specific input format:

  1. Each cell is represented as a ranked sequence of genes
  2. Genes should be sorted by expression value in descending order
  3. Use Ensembl Gene IDs as tokens (e.g., ENSG00000000003)
  4. Sequences are padded/truncated to max_position_embeddings (default 2048)
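Steps 1, 2, and the truncation in step 4 can be illustrated on a single toy cell. This is a minimal NumPy sketch. The placeholder gene IDs are hypothetical, and stable tie-breaking is an assumption rather than documented GeneMamba behavior.

```python
import numpy as np

def rank_genes(expression, gene_ids, max_len=2048):
    """Return gene IDs ordered by expression (highest first), truncated to max_len."""
    order = np.argsort(-np.asarray(expression), kind="stable")  # ties keep input order
    return [gene_ids[i] for i in order[:max_len]]

# Hypothetical four-gene cell (placeholder IDs for illustration)
genes = ["ENSG_A", "ENSG_B", "ENSG_C", "ENSG_D"]
expr = [0.0, 5.2, 1.1, 5.2]
print(rank_genes(expr, genes, max_len=3))  # ['ENSG_B', 'ENSG_D', 'ENSG_C']
```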

Example preparation:

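import numpy as np
import scanpy as sc
import scipy.sparse as sp
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)

# Load scRNA-seq data
adata = sc.read_h5ad("data.h5ad")

# Assumes adata.var_names holds Ensembl Gene IDs; map gene symbols to Ensembl IDs first if not
gene_id_mapping = np.asarray(adata.var_names)

# For each cell, rank genes by expression
gene_ids = []
for cell_idx in range(adata.n_obs):
    row = adata.X[cell_idx]
    # adata.X may be stored as a sparse or a dense matrix
    expression = row.toarray().ravel() if sp.issparse(row) else np.asarray(row).ravel()
    ranked_indices = np.argsort(-expression)  # Descending order
    ranked_gene_ids = [gene_id_mapping[idx] for idx in ranked_indices[:2048]]
    gene_ids.append(ranked_gene_ids)

# Convert Ensembl IDs to token IDs, then pad to the longest sequence in the batch
token_ids = [tokenizer.convert_tokens_to_ids(ids) for ids in gene_ids]
input_ids = tokenizer.pad({"input_ids": token_ids}, padding=True, return_tensors="pt")["input_ids"]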

Limitations

  • Gene vocabulary: Only genes in the tokenizer vocabulary (25,426 tokens, including special tokens) can be tokenized directly; other Ensembl genes have no token of their own
  • Sequence order: Expects ranked order; random order will degrade performance
  • Batch size: Larger batches (32-64) recommended for better convergence
  • GPU memory: Base model needs ~10GB for batch_size=32; larger variants need more

Examples

See the examples/ directory for complete scripts:

  • 1_extract_embeddings.py - Extract cell embeddings
  • 2_finetune_classification.py - Cell type annotation
  • 3_pretrain_from_scratch.py - Train from scratch (next-token + optional resume)

Citation

If you find GeneMamba useful in your research, please cite:

@article{qi2025genemamba,
  title={GeneMamba: An Efficient and Effective Foundation Model on Single Cell Data},
  author={Qi, Cong and Fang, Hanzhang and Jiang, Siqi and Song, Xun and Hu, Tianxing and Zhi, Wei},
  journal={arXiv preprint arXiv:2504.16956},
  year={2025}
}

Troubleshooting

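trust_remote_code Error

GeneMamba ships custom modeling code, so Transformers refuses to load it unless you opt in. Either:

  1. Pass trust_remote_code=True (safe when loading from the official repo)
  2. Or, when loading from a local copy of the code, use sys.path.insert(0, '.') so the modeling files can be imported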

Old Cached Code / Shape Mismatch

If you still see loading errors from old code after an update, force a fresh download from the Hub:

from transformers import AutoModel
model = AutoModel.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True,
    force_download=True,
)

You can also clear the local cache if needed:

rm -rf ~/.cache/huggingface/hub/models--mineself2016--GeneMamba

Out of Memory (OOM)

Reduce batch size:

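from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,  # Reduce from 32
    # ... keep your other arguments unchanged
)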

Tokenizer Not Found

Make sure the tokenizer files are in the same directory as the model files:

GeneMamba_repo/
├── config.json
├── model.safetensors
├── tokenizer.json           ← Required
├── tokenizer_config.json    ← Required
└── ...

Last Updated: March 2026

Model size: 65.7M params (Safetensors, F32)