---
license: mit
language:
- en
tags:
- fmri
- neuroscience
- brain
- foundation-model
- vision-transformer
- jepa
- burn
- rust
datasets:
- ukbiobank
pipeline_tag: feature-extraction
library_name: brainjepa-rs
---
# Brain-JEPA (safetensors)
Pretrained weights for **Brain-JEPA** (NeurIPS 2024, Spotlight) converted to safetensors format for use with [brainjepa-rs](https://github.com/eugenehp/brainjepa-rs).
## Model description
Brain-JEPA is a brain dynamics foundation model that maps parcellated fMRI time series (450 ROIs x T time points) to latent representations using a Vision Transformer with:
- **Brain gradient positioning** for spatial (ROI) embeddings
- **Temporal patch embedding** via 1D convolution along time
- **JEPA architecture** (Joint Embedding Predictive Architecture)
The encoder is a 12-layer ViT-Base (768-dim, 12 heads, ~86M params) pretrained on UK Biobank resting-state fMRI for 300 epochs.
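
The number of tokens the encoder sees follows from these shapes: each ROI's time series is cut into temporal patches, giving one token per (ROI, patch) pair. A minimal sketch of that arithmetic, assuming non-overlapping patches of 16 time points (the patch size listed under Architecture details) and that any trailing time points that do not fill a whole patch are dropped. The function name is illustrative and not part of brainjepa-rs:

```python
def num_tokens(n_rois: int, n_timepoints: int, patch_size: int = 16) -> int:
    """Number of encoder tokens for a parcellated fMRI input.

    Assumes non-overlapping temporal patches; a remainder shorter than
    one patch is ignored (an assumption, not confirmed by the source).
    """
    patches_per_roi = n_timepoints // patch_size
    return n_rois * patches_per_roi


# 450 ROIs x 160 time points -> 10 patches per ROI -> 4500 tokens
print(num_tokens(450, 160))
```

With the default input size this matches the `[4500, 768]` embedding shape reported elsewhere in this card.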
## Files
| File | Description | Shape info |
|---|---|---|
| `brainjepa.safetensors` | All weights (encoder + predictor + target_encoder) | 384 tensors, ~709 MB |
| `gradient_mapping_450.csv` | Brain gradient coordinates for positional embeddings | 450 rows x 30 columns |
### Weight key structure
Keys are prefixed by component (`encoder.`, `predictor.`, `target_encoder.`):
```
encoder.patch_embed.proj.weight [768, 1, 1, 16]
encoder.blocks.{i}.norm1.weight [768]
encoder.blocks.{i}.attn.qkv.weight [2304, 768]
encoder.blocks.{i}.attn.proj.weight [768, 768]
encoder.blocks.{i}.mlp.fc1.weight [3072, 768]
encoder.blocks.{i}.mlp.fc2.weight [768, 3072]
encoder.norm.weight [768]
...
```
For inference, use `target_encoder.*` keys (EMA-smoothed weights from pretraining).
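
To see how the keys split across components, one option is to group tensor names by their first dotted segment. The sketch below works on plain strings so it runs without the weights file; the helper name is hypothetical, and the `predictor.` key in the example is a made-up placeholder because this card does not list predictor key names:

```python
from collections import defaultdict
from typing import Dict, Iterable, List


def group_by_component(keys: Iterable[str]) -> Dict[str, List[str]]:
    """Group tensor names by the text before the first dot."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for key in keys:
        component, _, rest = key.partition(".")
        groups[component].append(rest)
    return dict(groups)


# Illustrative subset of key names in the format shown above.
example_keys = [
    "encoder.norm.weight",
    "predictor.blocks.0.norm1.weight",  # hypothetical predictor key
    "target_encoder.norm.weight",  # assumes target_encoder mirrors the encoder layout
    "target_encoder.blocks.0.attn.qkv.weight",
]
groups = group_by_component(example_keys)
print(sorted(groups))
print(groups["target_encoder"])
```

With the real file, the key names could be read via `safetensors.safe_open(path, framework="pt")` and its `keys()` method, then passed to the same helper.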
## Usage with brainjepa-rs (Rust)
```sh
# Install
git clone https://github.com/eugenehp/brainjepa-rs
cd brainjepa-rs
# Download weights from this repo
# Place brainjepa.safetensors and gradient_mapping_450.csv in data/
# Run inference (CPU)
cargo run --release --bin infer -- \
--weights data/brainjepa.safetensors \
--gradient data/gradient_mapping_450.csv \
--input data/fmri_sample.safetensors
# Run inference (GPU, Metal/Vulkan)
cargo run --release --no-default-features --features wgpu --bin infer -- \
--weights data/brainjepa.safetensors \
--gradient data/gradient_mapping_450.csv \
--input data/fmri_sample.safetensors
```
### Rust library
```rust
use brainjepa_rs::{BrainJepaEncoder, ModelConfig, DataConfig};

// `B` is the Burn backend type and `device` the matching device;
// both are chosen by the caller and are not defined in this snippet.
let (encoder, _) = BrainJepaEncoder::<B>::from_weights(
    "data/brainjepa.safetensors",
    "data/gradient_mapping_450.csv",
    &ModelConfig::default(),
    &DataConfig::default(),
    &device,
)?;
let result = encoder.encode_safetensors("data/fmri.safetensors")?;
// result.embeddings: [4500, 768] float32
```
## Usage with original Python code
These weights were converted from the original PyTorch checkpoint. To use them with the original code:
```python
import torch
from safetensors.torch import load_file
tensors = load_file("brainjepa.safetensors")
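# Keep only the target_encoder weights and strip the prefix
# (str.removeprefix requires Python 3.9+):
state_dict = {
    k.removeprefix("target_encoder."): v
    for k, v in tensors.items()
    if k.startswith("target_encoder.")
}
# `model` is an encoder instance built with the original Brain-JEPA code
# (its construction is not shown here).
model.load_state_dict(state_dict)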
```
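
By default `load_state_dict` raises an error on any missing or unexpected key, so it can help to compare the two key sets before loading. A small sketch that does not depend on PyTorch; the helper name is hypothetical and the example keys are placeholders:

```python
from typing import Iterable, List, Tuple


def diff_keys(
    expected: Iterable[str], provided: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key names, each sorted."""
    expected_set, provided_set = set(expected), set(provided)
    missing = sorted(expected_set - provided_set)
    unexpected = sorted(provided_set - expected_set)
    return missing, unexpected


missing, unexpected = diff_keys(
    expected=["norm.weight", "norm.bias"],
    provided=["norm.weight", "extra.weight"],
)
print(missing)     # ['norm.bias']
print(unexpected)  # ['extra.weight']
```

In practice this would be called as `diff_keys(model.state_dict().keys(), state_dict.keys())`; two empty lists mean the filtered weights line up with the model.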
## Conversion
Weights were converted from the original PyTorch checkpoint using:
```sh
python scripts/convert_weights.py \
--input jepa-ep300.pth.tar \
--output brainjepa.safetensors
```
The conversion script strips the `module.` prefix from DDP-wrapped state dicts, converts all tensors to float32, and saves in safetensors format.
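
The steps described above can be sketched as follows. This is not the actual `scripts/convert_weights.py`; it is a minimal illustration that assumes the checkpoint is a dict holding one state dict per component under the keys `encoder`, `predictor`, and `target_encoder` (an assumption about the checkpoint layout). All function names are hypothetical:

```python
from typing import Any, Dict

# Component prefixes listed in this card.
COMPONENTS = ("encoder", "predictor", "target_encoder")


def strip_ddp_prefix(key: str) -> str:
    """Drop a leading 'module.' left behind by DistributedDataParallel."""
    return key.removeprefix("module.")


def flatten_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Merge per-component state dicts into one flat dict with component prefixes."""
    flat = {}
    for component in COMPONENTS:
        for key, value in checkpoint.get(component, {}).items():
            flat[f"{component}.{strip_ddp_prefix(key)}"] = value
    return flat


def convert(input_path: str, output_path: str) -> None:
    """Load a checkpoint, cast tensors to float32, and write safetensors."""
    import torch  # imported lazily so the helpers above work without PyTorch
    from safetensors.torch import save_file

    # weights_only=False because a training checkpoint may also hold
    # non-tensor objects such as optimizer state (assumption).
    checkpoint = torch.load(input_path, map_location="cpu", weights_only=False)
    tensors = {
        name: tensor.to(torch.float32).contiguous()
        for name, tensor in flatten_checkpoint(checkpoint).items()
    }
    save_file(tensors, output_path)
```

Keeping the key handling separate from the tensor I/O makes the renaming logic easy to check on plain dicts before touching a large checkpoint.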
## Benchmark
Tested on Mac Mini M4 Pro (14 cores, 64 GB).
Input: `[1, 1, 450, 160]` (single sample, ViT-Base 86M params). Best-of-3 encode time.
| Backend | Encode time | Speedup vs PyTorch CPU |
|---|---|---|
| Rust — NdArray + Rayon (CPU) | 28,778 ms | 0.06x |
| Rust — NdArray + Accelerate (CPU) | 21,092 ms | 0.08x |
| Python — PyTorch (CPU) | 1,782 ms | 1.0x |
| Python — PyTorch MPS (GPU) | 581 ms | 3.1x |
| **Rust — wgpu f32 / Metal (GPU)** | **83 ms** | **21.5x** |
| **Rust — wgpu f16 / Metal (GPU)** | **85 ms** | **21.0x** |
The Rust wgpu GPU backends are ~7x faster than PyTorch MPS and ~21x faster
than PyTorch CPU.
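
The ratios in the table are the PyTorch CPU time divided by each backend's time. A short sketch that recomputes them from the reported timings; the dictionary keys are illustrative labels, not identifiers from either codebase:

```python
# Encode times in milliseconds, copied from the table above.
encode_ms = {
    "rust_ndarray_rayon": 28778,
    "rust_ndarray_accelerate": 21092,
    "pytorch_cpu": 1782,
    "pytorch_mps": 581,
    "rust_wgpu_f32": 83,
    "rust_wgpu_f16": 85,
}


def speedup(baseline_ms: float, candidate_ms: float) -> float:
    """How many times faster the candidate is than the baseline."""
    return baseline_ms / candidate_ms


for name, ms in encode_ms.items():
    ratio = speedup(encode_ms["pytorch_cpu"], ms)
    print(f"{name}: {ratio:.2f}x vs PyTorch CPU")

mps_ratio = speedup(encode_ms["pytorch_mps"], encode_ms["rust_wgpu_f32"])
print(f"wgpu f32 vs PyTorch MPS: {mps_ratio:.1f}x")
```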

## Architecture details
| Parameter | Value |
|---|---|
| Model | ViT-Base |
| Embedding dim | 768 |
| Encoder depth | 12 layers |
| Predictor depth | 6 layers |
| Attention heads | 12 |
| Head dim | 64 |
| MLP ratio | 4x (hidden=3072) |
| Patch size | 16 (temporal) |
| Input size | 450 ROIs x 160 time points |
| Output | 4500 patches x 768 dims |
| Normalization | LayerNorm (eps=1e-6) |
| Activation | GELU |
| Pretraining | 300 epochs on UK Biobank |
| Loss | Smooth L1 (JEPA representation matching) |
| Optimizer | AdamW (lr=1e-3, warmup=40 epochs, cosine decay) |
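
Several rows in this table are derived from the embedding dimension, and the same numbers appear in the weight shapes listed earlier. The sketch below recomputes them and gives a rough per-block parameter count. It assumes every linear layer has a bias and each LayerNorm has a weight and a bias, which this card does not confirm; `block_params` is a hypothetical helper:

```python
embed_dim = 768
num_heads = 12
mlp_ratio = 4
depth = 12

head_dim = embed_dim // num_heads   # 64
mlp_hidden = mlp_ratio * embed_dim  # 3072
qkv_rows = 3 * embed_dim            # 2304, first dim of attn.qkv.weight


def block_params(dim: int, ratio: int) -> int:
    """Approximate parameter count of one transformer block (biases assumed)."""
    hidden = ratio * dim
    attn = (3 * dim * dim + 3 * dim) + (dim * dim + dim)  # qkv + output projection
    mlp = (dim * hidden + hidden) + (hidden * dim + dim)  # fc1 + fc2
    norms = 2 * (2 * dim)                                 # norm1 + norm2
    return attn + mlp + norms


per_block = block_params(embed_dim, mlp_ratio)
print(head_dim, mlp_hidden, qkv_rows)
print(per_block, depth * per_block)
```

Under these assumptions the 12 encoder blocks alone account for about 85M parameters, close to the ~86M quoted for ViT-Base.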
## Source
Original paper and code:
> Zijian Dong, Ruilin Li, Yilei Wu, et al.
> **Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking.**
> NeurIPS 2024 (Spotlight). [arXiv:2409.19407](https://arxiv.org/abs/2409.19407)
- Paper: [arxiv.org/abs/2409.19407](https://arxiv.org/abs/2409.19407)
- Original code: [github.com/hzlab/Brain-JEPA](https://github.com/hzlab/Brain-JEPA)
- Rust inference: [github.com/eugenehp/brainjepa-rs](https://github.com/eugenehp/brainjepa-rs)