---
license: mit
language:
- en
tags:
- fmri
- neuroscience
- brain
- foundation-model
- vision-transformer
- jepa
- burn
- rust
datasets:
- ukbiobank
pipeline_tag: feature-extraction
library_name: brainjepa-rs
---

# Brain-JEPA (safetensors)

Pretrained weights for **Brain-JEPA** (NeurIPS 2024, Spotlight) converted to safetensors format for use with [brainjepa-rs](https://github.com/eugenehp/brainjepa-rs).

## Model description

Brain-JEPA is a brain dynamics foundation model that maps parcellated fMRI time series (450 ROIs x T time points) to latent representations using a Vision Transformer with:

- **Brain gradient positioning** for spatial (ROI) embeddings
- **Temporal patch embedding** via 1D convolution along time
- **JEPA architecture** (Joint Embedding Predictive Architecture)

The encoder is a 12-layer ViT-Base (768-dim, 12 heads, ~86M params) pretrained on UK Biobank resting-state fMRI for 300 epochs.
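
To make the temporal patching concrete, here is a minimal Python sketch of how a 450 x 160 input breaks into patches. It assumes non-overlapping windows of 16 time points (stride equal to the patch size) and ROI-major ordering, neither of which is stated above. `patchify` is an illustrative helper, not part of Brain-JEPA or brainjepa-rs, and the learned projection to 768 dimensions is left out.

```python
from typing import List

PATCH_SIZE = 16  # temporal patch length used by this model


def patchify(series: List[List[float]], patch_size: int = PATCH_SIZE) -> List[List[float]]:
    """Split each ROI time series into non-overlapping temporal windows.

    Assumptions: stride == patch_size, and patches are emitted ROI by ROI.
    """
    patches = []
    for roi in series:
        if len(roi) % patch_size != 0:
            raise ValueError("number of time points must be a multiple of patch_size")
        for start in range(0, len(roi), patch_size):
            patches.append(roi[start:start + patch_size])
    return patches


# Dummy input: 450 ROIs x 160 time points
fmri = [[0.0] * 160 for _ in range(450)]
patches = patchify(fmri)
print(len(patches), len(patches[0]))  # 4500 16
```

Each of the 4500 windows would then be projected to a 768-dimensional token by the patch embedding.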

## Files

| File | Description | Contents |
|---|---|---|
| `brainjepa.safetensors` | All weights (encoder + predictor + target_encoder) | 384 tensors, ~709 MB |
| `gradient_mapping_450.csv` | Brain gradient coordinates for positional embeddings | 450 rows x 30 columns |

### Weight key structure

Keys are prefixed by component (`encoder.`, `predictor.`, `target_encoder.`):

```
encoder.patch_embed.proj.weight          [768, 1, 1, 16]
encoder.blocks.{i}.norm1.weight          [768]
encoder.blocks.{i}.attn.qkv.weight       [2304, 768]
encoder.blocks.{i}.attn.proj.weight      [768, 768]
encoder.blocks.{i}.mlp.fc1.weight        [3072, 768]
encoder.blocks.{i}.mlp.fc2.weight        [768, 3072]
encoder.norm.weight                      [768]
...
```

For inference, use the `target_encoder.*` keys. These hold the EMA-smoothed weights from pretraining.
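
To check the key layout without loading any tensor data, the safetensors header can be read with the standard library alone. The sketch below relies on the safetensors file layout (an 8-byte little-endian header length followed by a JSON header); the helper names `read_safetensors_header` and `count_by_component` are illustrative and not part of any library.

```python
import json
import struct
from collections import Counter
from typing import Dict, Iterable


def read_safetensors_header(path: str) -> Dict[str, dict]:
    """Return the JSON header of a safetensors file without reading tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian u64
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form metadata entry
    return header


def count_by_component(names: Iterable[str]) -> Counter:
    """Count tensors per top-level prefix (the text before the first dot)."""
    return Counter(name.split(".", 1)[0] for name in names)


def shapes_for(header: Dict[str, dict], prefix: str) -> Dict[str, list]:
    """Map each key starting with `prefix` to its tensor shape."""
    return {k: v["shape"] for k, v in header.items() if k.startswith(prefix)}
```

Calling `count_by_component(read_safetensors_header("brainjepa.safetensors"))` should then report one count per component prefix, and `shapes_for(header, "target_encoder.")` lists the shapes of the inference weights.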

## Usage with brainjepa-rs (Rust)

```sh
# Clone the repository
git clone https://github.com/eugenehp/brainjepa-rs
cd brainjepa-rs

# Download weights from this repo
# Place brainjepa.safetensors and gradient_mapping_450.csv in data/

# Run inference (CPU)
cargo run --release --bin infer -- \
    --weights data/brainjepa.safetensors \
    --gradient data/gradient_mapping_450.csv \
    --input data/fmri_sample.safetensors

# Run inference (GPU, Metal/Vulkan)
cargo run --release --no-default-features --features wgpu --bin infer -- \
    --weights data/brainjepa.safetensors \
    --gradient data/gradient_mapping_450.csv \
    --input data/fmri_sample.safetensors
```

### Rust library

```rust
use brainjepa_rs::{BrainJepaEncoder, ModelConfig, DataConfig};

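// `B` is the Burn backend type and `device` the matching device;
// both are assumed to be defined by the caller and are not shown here.
let (encoder, _) = BrainJepaEncoder::<B>::from_weights(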
    "data/brainjepa.safetensors",
    "data/gradient_mapping_450.csv",
    &ModelConfig::default(),
    &DataConfig::default(),
    &device,
)?;
let result = encoder.encode_safetensors("data/fmri.safetensors")?;
// result.embeddings: [4500, 768] float32
```

## Usage with original Python code

These weights were converted from the original PyTorch checkpoint. To load them back into the original code:

```python
import torch
from safetensors.torch import load_file

tensors = load_file("brainjepa.safetensors")
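# Keep only the target_encoder weights and strip the prefix
# (str.removeprefix requires Python 3.9+):
state_dict = {
    k.removeprefix("target_encoder."): v
    for k, v in tensors.items()
    if k.startswith("target_encoder.")
}
# `model` must already be an encoder instance built with the original Brain-JEPA code.
model.load_state_dict(state_dict)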
```

## Conversion

Weights were converted from the original PyTorch checkpoint using:

```sh
python scripts/convert_weights.py \
    --input jepa-ep300.pth.tar \
    --output brainjepa.safetensors
```

The conversion script strips the `module.` prefix from DDP-wrapped state dicts, converts all tensors to float32, and saves in safetensors format.
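
The prefix-stripping step can be sketched in a few lines of Python. This is an illustration of the idea only, not the contents of `scripts/convert_weights.py`; the helper name `strip_ddp_prefix` is hypothetical, and the float32 cast and the safetensors write are omitted so the example runs without PyTorch.

```python
from typing import Dict, TypeVar

T = TypeVar("T")
DDP_PREFIX = "module."  # prefix added by DistributedDataParallel wrappers


def strip_ddp_prefix(state_dict: Dict[str, T]) -> Dict[str, T]:
    """Drop a leading 'module.' from every key; values pass through unchanged."""
    return {
        (key[len(DDP_PREFIX):] if key.startswith(DDP_PREFIX) else key): value
        for key, value in state_dict.items()
    }


# Placeholder values stand in for tensors.
example = {"module.blocks.0.norm1.weight": 1, "norm.weight": 2}
print(strip_ddp_prefix(example))
# {'blocks.0.norm1.weight': 1, 'norm.weight': 2}
```

Only a leading `module.` is removed, so a key that merely contains the substring elsewhere is left alone.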

## Benchmark

Tested on Mac Mini M4 Pro (14 cores, 64 GB).
Input: `[1, 1, 450, 160]` (single sample; ViT-Base, ~86M params). Times are the best of 3 encode runs.

| Backend | Encode time | Speedup vs PyTorch CPU |
|---|---|---|
| Rust — NdArray + Rayon (CPU) | 28,778 ms | 0.06x |
| Rust — NdArray + Accelerate (CPU) | 21,092 ms | 0.08x |
| Python — PyTorch (CPU) | 1,782 ms | 1.0x |
| Python — PyTorch MPS (GPU) | 581 ms | 3.1x |
| **Rust — wgpu f32 / Metal (GPU)** | **83 ms** | **21.5x** |
| **Rust — wgpu f16 / Metal (GPU)** | **85 ms** | **21.0x** |

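On this machine, the Rust wgpu GPU backends are ~7x faster than PyTorch MPS and ~21x faster
than PyTorch CPU. The Rust NdArray CPU backends, by contrast, are roughly 12x to 16x slower than PyTorch CPU.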

![benchmark](benchmark.png)
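
The speedup column is the PyTorch CPU time divided by each backend's time, so values below 1 mean slower than the baseline. A short Python sketch that recomputes it from the timings in the table (the dictionary labels are shortened for readability):

```python
BASELINE = "Python PyTorch (CPU)"

# Best-of-3 encode times in milliseconds, copied from the table above.
encode_ms = {
    "Rust NdArray + Rayon (CPU)": 28_778,
    "Rust NdArray + Accelerate (CPU)": 21_092,
    "Python PyTorch (CPU)": 1_782,
    "Python PyTorch MPS (GPU)": 581,
    "Rust wgpu f32 / Metal (GPU)": 83,
    "Rust wgpu f16 / Metal (GPU)": 85,
}


def speedup(baseline_ms: float, backend_ms: float) -> float:
    """How many times faster a backend is than the baseline (>1 means faster)."""
    return baseline_ms / backend_ms


for name, ms in encode_ms.items():
    print(f"{name:32s} {speedup(encode_ms[BASELINE], ms):6.2f}x")
```

The same function gives the wgpu versus MPS comparison: `speedup(581, 83)` is 7.0.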

## Architecture details

| Parameter | Value |
|---|---|
| Model | ViT-Base |
| Embedding dim | 768 |
| Encoder depth | 12 layers |
| Predictor depth | 6 layers |
| Attention heads | 12 |
| Head dim | 64 |
| MLP ratio | 4x (hidden=3072) |
| Patch size | 16 (temporal) |
| Input size | 450 ROIs x 160 time points |
| Output | 4500 patches x 768 dims |
| Normalization | LayerNorm (eps=1e-6) |
| Activation | GELU |
| Pretraining | 300 epochs on UK Biobank |
| Loss | Smooth L1 (JEPA representation matching) |
| Optimizer | AdamW (lr=1e-3, warmup=40 epochs, cosine decay) |
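
Several rows in the table follow from the others. The Python sketch below derives them and gives a rough encoder parameter count. The count assumes a bias on every linear layer and on the patch embedding, and it excludes positional embeddings; these assumptions are not confirmed by this card, so treat the total as an estimate.

```python
EMBED_DIM = 768
NUM_HEADS = 12
MLP_RATIO = 4
DEPTH = 12
PATCH_SIZE = 16
NUM_ROIS = 450
TIME_POINTS = 160

head_dim = EMBED_DIM // NUM_HEADS            # 64
mlp_hidden = MLP_RATIO * EMBED_DIM           # 3072
qkv_rows = 3 * EMBED_DIM                     # 2304, rows of attn.qkv.weight
patches_per_roi = TIME_POINTS // PATCH_SIZE  # 10
num_patches = NUM_ROIS * patches_per_roi     # 4500


def linear_params(out_features: int, in_features: int, bias: bool = True) -> int:
    """Weights plus optional bias of one linear layer."""
    return out_features * in_features + (out_features if bias else 0)


def block_params(dim: int, hidden: int) -> int:
    """Parameters of one transformer block (attention, MLP, two LayerNorms)."""
    attn = linear_params(3 * dim, dim) + linear_params(dim, dim)
    mlp = linear_params(hidden, dim) + linear_params(dim, hidden)
    norms = 2 * (2 * dim)  # weight and bias for norm1 and norm2
    return attn + mlp + norms


encoder_params = (
    DEPTH * block_params(EMBED_DIM, mlp_hidden)
    + linear_params(EMBED_DIM, PATCH_SIZE)  # patch embedding [768, 1, 1, 16] plus bias (assumed)
    + 2 * EMBED_DIM                         # final LayerNorm
)

print(head_dim, mlp_hidden, qkv_rows, num_patches)
print(f"{encoder_params / 1e6:.1f}M parameters (estimate)")
```

Under these assumptions the estimate comes out at about 85M, in line with the ~86M usually quoted for ViT-Base.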

## Source

Original paper and code:

> Zijian Dong, Ruilin Li, Yilei Wu, et al.
> **Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking.**
> NeurIPS 2024 (Spotlight). [arXiv:2409.19407](https://arxiv.org/abs/2409.19407)

- Paper: [arxiv.org/abs/2409.19407](https://arxiv.org/abs/2409.19407)
- Original code: [github.com/hzlab/Brain-JEPA](https://github.com/hzlab/Brain-JEPA)
- Rust inference: [github.com/eugenehp/brainjepa-rs](https://github.com/eugenehp/brainjepa-rs)