---
license: mit
language:
- en
tags:
- fmri
- neuroscience
- brain
- foundation-model
- vision-transformer
- jepa
- burn
- rust
datasets:
- ukbiobank
pipeline_tag: feature-extraction
library_name: brainjepa-rs
---

# Brain-JEPA (safetensors)

Pretrained weights for **Brain-JEPA** (NeurIPS 2024, Spotlight) converted to safetensors format for use with [brainjepa-rs](https://github.com/eugenehp/brainjepa-rs).

## Model description

Brain-JEPA is a brain dynamics foundation model that maps parcellated fMRI time series (450 ROIs x T time points) to latent representations using a Vision Transformer with:

- **Brain gradient positioning** for spatial (ROI) embeddings
- **Temporal patch embedding** via 1D convolution along time
- **JEPA architecture** (Joint Embedding Predictive Architecture)

The encoder is a 12-layer ViT-Base (768-dim, 12 heads, ~86M params) pretrained on UK Biobank resting-state fMRI for 300 epochs.

## Files

| File | Description | Shape info |
|---|---|---|
| `brainjepa.safetensors` | All weights (encoder + predictor + target_encoder) | 384 tensors, ~709 MB |
| `gradient_mapping_450.csv` | Brain gradient coordinates for positional embeddings | 450 rows x 30 columns |

### Weight key structure

Keys are prefixed by component (`encoder.`, `predictor.`, `target_encoder.`):

```
encoder.patch_embed.proj.weight      [768, 1, 1, 16]
encoder.blocks.{i}.norm1.weight      [768]
encoder.blocks.{i}.attn.qkv.weight   [2304, 768]
encoder.blocks.{i}.attn.proj.weight  [768, 768]
encoder.blocks.{i}.mlp.fc1.weight    [3072, 768]
encoder.blocks.{i}.mlp.fc2.weight    [768, 3072]
encoder.norm.weight                  [768]
...
```

For inference, use `target_encoder.*` keys (EMA-smoothed weights from pretraining).

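
To check the key layout above without loading ~709 MB of tensor data, the file header can be read on its own. The sketch below relies only on the safetensors container layout (an 8-byte little-endian length followed by a JSON header) and the Python standard library. The helper names `read_safetensors_header`, `count_by_component`, and `component_shapes` are hypothetical and are not part of brainjepa-rs or the original Brain-JEPA code.

```python
import json
import struct
from collections import Counter


def read_safetensors_header(path):
    """Return the JSON header of a safetensors file without reading tensor data."""
    with open(path, "rb") as f:
        # The first 8 bytes are a little-endian u64 giving the header length.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    # "__metadata__" is an optional free-form entry, not a tensor.
    header.pop("__metadata__", None)
    return header


def count_by_component(header):
    """Count tensors per top-level prefix, e.g. 'encoder' or 'target_encoder'."""
    return Counter(name.split(".", 1)[0] for name in header)


def component_shapes(header, component="target_encoder"):
    """Map prefix-stripped key -> shape for a single component."""
    prefix = component + "."
    return {
        name[len(prefix):]: entry["shape"]
        for name, entry in header.items()
        if name.startswith(prefix)
    }
```

Calling `count_by_component(read_safetensors_header("brainjepa.safetensors"))` should report the three components named above if the file matches this description, and `component_shapes(header)` gives the `target_encoder.*` shapes with the prefix already stripped.
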
## Usage with brainjepa-rs (Rust)

```sh
# Install
git clone https://github.com/eugenehp/brainjepa-rs
cd brainjepa-rs

# Download weights from this repo
# Place brainjepa.safetensors and gradient_mapping_450.csv in data/

# Run inference (CPU)
cargo run --release --bin infer -- \
  --weights data/brainjepa.safetensors \
  --gradient data/gradient_mapping_450.csv \
  --input data/fmri_sample.safetensors

# Run inference (GPU, Metal/Vulkan)
cargo run --release --no-default-features --features wgpu --bin infer -- \
  --weights data/brainjepa.safetensors \
  --gradient data/gradient_mapping_450.csv \
  --input data/fmri_sample.safetensors
```

### Rust library

```rust
use brainjepa_rs::{BrainJepaEncoder, ModelConfig, DataConfig};

// `B` is the Burn backend type and `device` its matching device;
// both are chosen and set up by the caller.
let (encoder, _) = BrainJepaEncoder::<B>::from_weights(
    "data/brainjepa.safetensors",
    "data/gradient_mapping_450.csv",
    &ModelConfig::default(),
    &DataConfig::default(),
    &device,
)?;

let result = encoder.encode_safetensors("data/fmri.safetensors")?;
// result.embeddings: [4500, 768] float32
```

## Usage with original Python code

These weights were converted from the original PyTorch checkpoint. To load them into the original code:

```python
import torch
from safetensors.torch import load_file

tensors = load_file("brainjepa.safetensors")

# Filter for target_encoder weights and strip the prefix
# (str.removeprefix requires Python 3.9+):
state_dict = {
    k.removeprefix("target_encoder."): v
    for k, v in tensors.items()
    if k.startswith("target_encoder.")
}

# `model` is the encoder instance built with the original Brain-JEPA code.
model.load_state_dict(state_dict)
```

## Conversion

Weights were converted from the original PyTorch checkpoint using:

```sh
python scripts/convert_weights.py \
  --input jepa-ep300.pth.tar \
  --output brainjepa.safetensors
```

The conversion script strips the `module.` prefix from DDP-wrapped state dicts, converts all tensors to float32, and saves them in safetensors format.

## Benchmark

Tested on a Mac Mini M4 Pro (14 cores, 64 GB). Input: `[1, 1, 450, 160]` (single sample, ViT-Base 86M params). Each figure is the best of 3 encode runs.

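
The table that follows reports best-of-3 times and a speedup relative to PyTorch CPU. As a rough illustration of how such numbers can be produced, here is a minimal timing sketch. It assumes wall-clock timing with `time.perf_counter` and takes the minimum over the runs. The actual benchmark harness is not described in this README, and the helpers `best_of` and `speedup` are hypothetical names.

```python
import time


def best_of(fn, repeats=3):
    """Run fn() `repeats` times and return the fastest wall-clock time in ms."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append((time.perf_counter() - start) * 1000.0)
    return min(timings)


def speedup(baseline_ms, candidate_ms):
    """How many times faster the candidate is than the baseline (>1 is faster)."""
    return baseline_ms / candidate_ms


# Encode times (ms) copied from the benchmark table.
PYTORCH_CPU_MS = 1782
PYTORCH_MPS_MS = 581
WGPU_F32_MS = 83

print(f"wgpu f32 vs PyTorch CPU: {speedup(PYTORCH_CPU_MS, WGPU_F32_MS):.1f}x")  # 21.5x
print(f"wgpu f32 vs PyTorch MPS: {speedup(PYTORCH_MPS_MS, WGPU_F32_MS):.1f}x")  # 7.0x
```

Taking the minimum rather than the mean reduces the influence of one-off costs such as warm-up or shader compilation on GPU backends, which is the usual motivation for a best-of-N figure.
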
| Backend | Encode | vs PyTorch CPU |
|---|---|---|
| Rust — NdArray + Rayon (CPU) | 28,778 ms | 0.06x |
| Rust — NdArray + Accelerate (CPU) | 21,092 ms | 0.08x |
| Python — PyTorch (CPU) | 1,782 ms | 1.0x |
| Python — PyTorch MPS (GPU) | 581 ms | 3.1x |
| **Rust — wgpu f32 / Metal (GPU)** | **83 ms** | **21.5x** |
| **Rust — wgpu f16 / Metal (GPU)** | **85 ms** | **21.0x** |

The Rust wgpu GPU backends are ~7x faster than PyTorch MPS and ~21x faster than PyTorch CPU.

![benchmark](benchmark.png)

## Architecture details

| Parameter | Value |
|---|---|
| Model | ViT-Base |
| Embedding dim | 768 |
| Encoder depth | 12 layers |
| Predictor depth | 6 layers |
| Attention heads | 12 |
| Head dim | 64 |
| MLP ratio | 4x (hidden=3072) |
| Patch size | 16 (temporal) |
| Input size | 450 ROIs x 160 time points |
| Output | 4500 patches x 768 dims |
| Normalization | LayerNorm (eps=1e-6) |
| Activation | GELU |
| Pretraining | 300 epochs on UK Biobank |
| Loss | Smooth L1 (JEPA representation matching) |
| Optimizer | AdamW (lr=1e-3, warmup=40 epochs, cosine decay) |

## Source

Original paper and code:

> Zijian Dong, Ruilin Li, Yilei Wu, et al.
> **Brain-JEPA: Brain Dynamics Foundation Model with Gradient Positioning and Spatiotemporal Masking.**
> NeurIPS 2024 (Spotlight). [arXiv:2409.19407](https://arxiv.org/abs/2409.19407)

- Paper: [arxiv.org/abs/2409.19407](https://arxiv.org/abs/2409.19407)
- Original code: [github.com/hzlab/Brain-JEPA](https://github.com/hzlab/Brain-JEPA)
- Rust inference: [github.com/eugenehp/brainjepa-rs](https://github.com/eugenehp/brainjepa-rs)