# ESM-2 for Apple MLX

From-scratch MLX implementation of Meta's ESM-2 protein language model, optimized for Apple Silicon. Supports masked language modeling and residue–residue contact prediction across all six ESM-2 model sizes.

GitHub: [josephjojoe/esm-mlx](https://github.com/josephjojoe/esm-mlx)

## Available Weights
| File | Model | Layers | Hidden Dim | Heads | Parameters |
|---|---|---|---|---|---|
| `esm2_t6_8M_UR50D.safetensors` | ESM-2 8M | 6 | 320 | 20 | 8M |
| `esm2_t12_35M_UR50D.safetensors` | ESM-2 35M | 12 | 480 | 20 | 35M |
| `esm2_t30_150M_UR50D.safetensors` | ESM-2 150M | 30 | 640 | 20 | 150M |
| `esm2_t33_650M_UR50D.safetensors` | ESM-2 650M | 33 | 1280 | 20 | 650M |
| `esm2_t36_3B_UR50D.safetensors` | ESM-2 3B | 36 | 2560 | 40 | 3B |
| `esm2_t48_15B_UR50D.safetensors` | ESM-2 15B | 48 | 5120 | 40 | 15B |
Weights are converted from the official PyTorch checkpoints to safetensors format.
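At its core, the conversion is a key-remapping pass over the PyTorch state dict before serializing to safetensors. A minimal sketch of the idea, assuming fairseq-style ESM-2 checkpoint names; the target names and the exact mapping used by this repo's converter may differ, and plain lists stand in for tensors here:

```python
# Hypothetical key remapping from fairseq-style ESM-2 checkpoint names to
# flat MLX-side names. Illustrative only: the repo's real converter may use
# different target names.
def remap_key(key: str) -> str:
    """Map one fairseq-style parameter name to a hypothetical MLX-side name."""
    key = key.replace("encoder.sentence_encoder.", "")
    key = key.replace("self_attn.", "attention.")
    key = key.replace("embed_tokens.", "embedding.")
    return key

def convert(state_dict: dict) -> dict:
    # Drop inverse-frequency rotary buffers; these can be recomputed at load time.
    return {
        remap_key(k): v
        for k, v in state_dict.items()
        if not k.endswith("inv_freq")
    }

# Stand-in "checkpoint" with list values in place of tensors.
ckpt = {
    "encoder.sentence_encoder.embed_tokens.weight": [0.0],
    "encoder.sentence_encoder.layers.0.self_attn.q_proj.weight": [0.0],
    "encoder.sentence_encoder.layers.0.self_attn.rot_emb.inv_freq": [0.0],
}
print(sorted(convert(ckpt)))
# -> ['embedding.weight', 'layers.0.attention.q_proj.weight']
```

In a real conversion the remapped dict would be passed to `safetensors`' save function rather than printed.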
## Usage

```shell
pip install "esm-mlx @ git+https://github.com/josephjojoe/esm-mlx.git"
```

```python
import mlx.core as mx
from esm_mlx import ESM2, Tokenizer

model = ESM2.from_pretrained("esm2_t33_650M_UR50D")  # auto-downloads weights
tok = Tokenizer()

tokens = tok.encode("MKTAYIAKQRQISFVKSHFSRQLE")
out = model(tokens)
logits = out["logits"]  # (1, seq_len, vocab_size)
```
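For masked language modeling, the logits are decoded by taking the argmax over the vocabulary at the masked position. A toy sketch of just that decoding step, using a hand-made logits array and an illustrative six-token vocabulary rather than the real tokenizer's 33-token one:

```python
# Toy decode step: given per-position logit rows (fabricated floats here, not
# real model output) and a hypothetical vocabulary, recover the most likely
# residue at a masked position.
VOCAB = ["<pad>", "<mask>", "A", "K", "M", "T"]  # illustrative, not the real ESM vocab

def predict_at(logit_rows, position):
    row = logit_rows[position]
    best = max(range(len(row)), key=row.__getitem__)
    return VOCAB[best]

# Pretend the model saw "M K <mask> A" and is most confident about "T" at index 2.
fake_logits = [
    [0.0, 0.0, 0.1, 0.2, 2.0, 0.1],  # position 0 -> "M"
    [0.0, 0.0, 0.1, 2.0, 0.2, 0.1],  # position 1 -> "K"
    [0.0, 0.0, 0.3, 0.1, 0.2, 2.5],  # position 2 (masked) -> "T"
    [0.0, 0.0, 2.0, 0.1, 0.2, 0.1],  # position 3 -> "A"
]
print(predict_at(fake_logits, 2))  # -> T
```

With the real model, the same argmax would be applied to `logits[0, i]` for the masked index `i`, then mapped back through the tokenizer's vocabulary.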
## Benchmarks

ESM-2 650M on an M2 Pro (16 GB), MLX 0.30.6 vs. PyTorch 2.10.0 on the MPS backend. Median latency over 50 iterations after 10 warm-up passes.

### Float32
| Batch | Seq Len | MLX | PyTorch MPS | Speedup |
|---|---|---|---|---|
| 1 | 64 | 43.9 ms | 43.7 ms | 1.00x |
| 1 | 128 | 63.0 ms | 66.9 ms | 1.06x |
| 1 | 256 | 104.1 ms | 129.4 ms | 1.24x |
| 1 | 512 | 190.4 ms | 242.1 ms | 1.27x |
| 1 | 1024 | 378.5 ms | 511.1 ms | 1.35x |
| 4 | 256 | 340.7 ms | 462.8 ms | 1.36x |
| 4 | 512 | 670.0 ms | 1005.5 ms | 1.50x |
| 4 | 1024 | 1409.3 ms | 2439.9 ms | 1.73x |
| 8 | 256 | 646.6 ms | 930.1 ms | 1.44x |
| 8 | 512 | 1305.0 ms | 2057.8 ms | 1.58x |
| 8 | 1024 | 2783.9 ms | 4935.6 ms | 1.77x |
### Float16
| Batch | Seq Len | MLX | PyTorch MPS | Speedup |
|---|---|---|---|---|
| 8 | 64 | 149.6 ms | 218.2 ms | 1.46x |
| 8 | 128 | 273.3 ms | 414.8 ms | 1.52x |
| 8 | 256 | 522.1 ms | 868.0 ms | 1.66x |
| 8 | 512 | 1039.9 ms | 2012.4 ms | 1.94x |
| 8 | 1024 | 2186.3 ms | 5266.7 ms | 2.41x |
| 16 | 64 | 272.2 ms | 412.0 ms | 1.51x |
| 16 | 128 | 514.4 ms | 807.5 ms | 1.57x |
| 16 | 256 | 1006.8 ms | 1738.5 ms | 1.73x |
| 16 | 512 | 2051.9 ms | 4123.9 ms | 2.01x |
| 16 | 1024 | 4349.1 ms | 14759.5 ms | 3.39x |
| 32 | 64 | 509.0 ms | 797.1 ms | 1.57x |
| 32 | 128 | 989.5 ms | 1626.6 ms | 1.64x |
| 32 | 256 | 1985.4 ms | 3573.3 ms | 1.80x |
| 32 | 512 | 4081.2 ms | 8295.5 ms | 2.03x |
| 32 | 1024 | 8678.4 ms | OOM | — |
FP16 widens the gap significantly. The 3.39x result at batch=16, seq=1024 likely reflects PyTorch MPS thrashing near its memory ceiling; it OOMs entirely one step later at batch=32. MLX's unified-memory allocation avoids this cliff and continues to scale linearly up to batch=192 at seq=1024, sustaining ~3,784 tok/s on 16 GB.
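The measurement protocol above (10 warm-up passes, then the median of 50 timed iterations) can be reproduced with a small harness. A sketch with a stand-in workload in place of a model forward pass; a real MLX benchmark would additionally need to force evaluation of the lazy graph (e.g. with `mx.eval`) inside the timed region:

```python
import statistics
import time

def median_latency_ms(fn, warmup=10, iters=50):
    """Median wall-clock latency of fn() in milliseconds, after warm-up passes."""
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()  # for MLX, force the lazy computation here so it is actually timed
        times.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(times)

# Stand-in CPU workload instead of a model forward pass.
workload = lambda: sum(i * i for i in range(10_000))
print(f"{median_latency_ms(workload):.2f} ms")
```

The median (rather than the mean) is used so that occasional scheduler hiccups and one-off compilation stalls do not skew the reported number.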
## License

MIT, same as the original ESM-2 weights.