# ESM-2 for Apple MLX

From-scratch MLX implementation of Meta's ESM-2 protein language model, optimized for Apple Silicon.

Supports masked language modeling and residue–residue contact prediction across all six ESM-2 model sizes.
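The repository's exact contact head isn't shown here, but the standard ESM-2 recipe (Rao et al.) feeds symmetrized, APC-corrected attention maps into a learned logistic regression over residue pairs. A minimal numpy sketch of that feature pipeline, with hypothetical shapes (the attention stack and sequence length are made up for illustration):

```python
import numpy as np

def symmetrize(x: np.ndarray) -> np.ndarray:
    # Attention maps are directional; contacts are symmetric, so fold
    # the (i -> j) and (j -> i) scores together.
    return x + x.swapaxes(-1, -2)

def apc(x: np.ndarray) -> np.ndarray:
    # Average Product Correction: subtract the outer product of row and
    # column sums divided by the total sum, removing background coupling.
    row = x.sum(axis=-1, keepdims=True)
    col = x.sum(axis=-2, keepdims=True)
    total = x.sum(axis=(-1, -2), keepdims=True)
    return x - row * col / total

# Dummy attention stack: (layers * heads, L, L) for a length-64 sequence,
# matching the 650M model's 33 layers x 20 heads.
rng = np.random.default_rng(0)
attn = rng.random((33 * 20, 64, 64))
features = apc(symmetrize(attn))  # input to a per-pair logistic regression
```

The symmetrize-then-APC order matters: APC of a symmetric matrix stays symmetric, so each (i, j) pair gets one consistent feature vector.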

GitHub: [josephjojoe/esm-mlx](https://github.com/josephjojoe/esm-mlx)

## Available Weights

| File | Model | Layers | Hidden Dim | Heads | Parameters |
|------|-------|--------|------------|-------|------------|
| esm2_t6_8M_UR50D.safetensors | ESM-2 8M | 6 | 320 | 20 | 8M |
| esm2_t12_35M_UR50D.safetensors | ESM-2 35M | 12 | 480 | 20 | 35M |
| esm2_t30_150M_UR50D.safetensors | ESM-2 150M | 30 | 640 | 20 | 150M |
| esm2_t33_650M_UR50D.safetensors | ESM-2 650M | 33 | 1280 | 20 | 650M |
| esm2_t36_3B_UR50D.safetensors | ESM-2 3B | 36 | 2560 | 40 | 3B |
| esm2_t48_15B_UR50D.safetensors | ESM-2 15B | 48 | 5120 | 40 | 15B |

Weights are converted from the official PyTorch checkpoints to safetensors format.
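The layer/width pairs above can be sanity-checked against the nominal parameter counts with the usual transformer rule of thumb of roughly 12·L·d² weights in the encoder body (4d² for attention projections plus 8d² for a 4x-expansion MLP per layer, ignoring embeddings and the LM head):

```python
def approx_params(layers: int, hidden: int) -> int:
    # 4*d^2 for Q/K/V/output projections + 8*d^2 for the feed-forward
    # block, per layer; embeddings and LM head are ignored.
    return 12 * layers * hidden**2

print(f"{approx_params(33, 1280):,}")  # ESM-2 650M body: 648,806,400
print(f"{approx_params(36, 2560):,}")  # ESM-2 3B body: 2,831,155,200
```

The estimate lands within a few percent of the nominal sizes for the larger models; the 8M model deviates more because its embedding table is a larger fraction of the total.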

## Usage

```bash
pip install "esm-mlx @ git+https://github.com/josephjojoe/esm-mlx.git"
```

```python
import mlx.core as mx
from esm_mlx import ESM2, Tokenizer

model = ESM2.from_pretrained("esm2_t33_650M_UR50D")  # auto-downloads weights
tok = Tokenizer()

tokens = tok.encode("MKTAYIAKQRQISFVKSHFSRQLE")
out = model(tokens)
logits = out["logits"]  # (1, seq_len, vocab_size)
```
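From the logits, per-position amino-acid probabilities follow from a softmax over the vocabulary axis. A self-contained numpy sketch using dummy logits in place of `out["logits"]` (no weight download needed; the 33-token vocabulary size matches the ESM-2 alphabet):

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax: shift by the max before exponentiating.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Dummy logits standing in for out["logits"]: (batch, seq_len, vocab_size).
logits = np.random.default_rng(1).normal(size=(1, 26, 33))
probs = softmax(logits)                     # per-position distributions
top = probs.argmax(axis=-1)                 # most likely token per position
entropy = -(probs * np.log(probs)).sum(-1)  # per-position uncertainty
```

For masked language modeling, the same computation at a masked position gives the model's distribution over substitutions at that residue.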

## Benchmarks

ESM-2 650M on an M2 Pro (16 GB), comparing MLX 0.30.6 against PyTorch 2.10.0 on the MPS backend. Figures are median latency over 50 iterations after 10 warmup passes.

### Float32

| Batch | Seq Len | MLX | PyTorch MPS | Speedup |
|-------|---------|-----|-------------|---------|
| 1 | 64 | 43.9 ms | 43.7 ms | 1.00x |
| 1 | 128 | 63.0 ms | 66.9 ms | 1.06x |
| 1 | 256 | 104.1 ms | 129.4 ms | 1.24x |
| 1 | 512 | 190.4 ms | 242.1 ms | 1.27x |
| 1 | 1024 | 378.5 ms | 511.1 ms | 1.35x |
| 4 | 256 | 340.7 ms | 462.8 ms | 1.36x |
| 4 | 512 | 670.0 ms | 1005.5 ms | 1.50x |
| 4 | 1024 | 1409.3 ms | 2439.9 ms | 1.73x |
| 8 | 256 | 646.6 ms | 930.1 ms | 1.44x |
| 8 | 512 | 1305.0 ms | 2057.8 ms | 1.58x |
| 8 | 1024 | 2783.9 ms | 4935.6 ms | 1.77x |

### Float16

| Batch | Seq Len | MLX | PyTorch MPS | Speedup |
|-------|---------|-----|-------------|---------|
| 8 | 64 | 149.6 ms | 218.2 ms | 1.46x |
| 8 | 128 | 273.3 ms | 414.8 ms | 1.52x |
| 8 | 256 | 522.1 ms | 868.0 ms | 1.66x |
| 8 | 512 | 1039.9 ms | 2012.4 ms | 1.94x |
| 8 | 1024 | 2186.3 ms | 5266.7 ms | 2.41x |
| 16 | 64 | 272.2 ms | 412.0 ms | 1.51x |
| 16 | 128 | 514.4 ms | 807.5 ms | 1.57x |
| 16 | 256 | 1006.8 ms | 1738.5 ms | 1.73x |
| 16 | 512 | 2051.9 ms | 4123.9 ms | 2.01x |
| 16 | 1024 | 4349.1 ms | 14759.5 ms | 3.39x |
| 32 | 64 | 509.0 ms | 797.1 ms | 1.57x |
| 32 | 128 | 989.5 ms | 1626.6 ms | 1.64x |
| 32 | 256 | 1985.4 ms | 3573.3 ms | 1.80x |
| 32 | 512 | 4081.2 ms | 8295.5 ms | 2.03x |
| 32 | 1024 | 8678.4 ms | OOM | n/a |

FP16 widens the gap significantly. The 3.39x result at batch=16, seq=1024 likely reflects PyTorch MPS thrashing near its memory ceiling — it OOMs entirely one step later at batch=32. MLX's unified-memory allocation avoids this cliff and continues to scale linearly up to batch=192 at seq=1024, sustaining ~3,784 tok/s on 16 GB.
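The speedup and throughput figures are straightforward derivations from the latency tables; a short script recomputing them from two FP16 rows:

```python
# Recompute derived figures from two FP16 rows above:
# speedup = PyTorch latency / MLX latency; throughput = batch * seq / seconds.
rows = [  # (batch, seq_len, mlx_ms, pytorch_ms)
    (8, 1024, 2186.3, 5266.7),
    (16, 1024, 4349.1, 14759.5),
]
for batch, seq, mlx_ms, pt_ms in rows:
    speedup = pt_ms / mlx_ms
    tok_per_s = batch * seq / (mlx_ms / 1000.0)
    print(f"batch={batch}: {speedup:.2f}x, {tok_per_s:,.0f} tok/s")
# batch=8: 2.41x, 3,747 tok/s
# batch=16: 3.39x, 3,767 tok/s
```

Both rows land around 3,750-3,770 tok/s, consistent with the ~3,784 tok/s figure quoted at batch=192.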

## License

MIT — same as the original ESM-2 weights.
