DifFRACT: Diffusion Feature Reconstruction and Attribution for Circuit Tracing
This repository contains trained timestep-conditioned transcoders and sparse autoencoder (SAE) baselines for the FLUX.1[schnell] text-to-image diffusion transformer (MM-DiT). It accompanies the paper DifFRACT: Diffusion Feature Reconstruction and Attribution for Circuit Tracing (arXiv:2606.15796).
A transcoder decomposes an MLP sublayer into a sparse linear combination of interpretable features; conditioning it on the denoising timestep lets a single transcoder track how a feature behaves across the diffusion trajectory. Substituting these transcoders into a frozen Local Replacement Model yields the attribution graphs and circuit-guided interventions studied in the paper. The code to load and use these weights is in the companion repository (GitHub).
Contents
The repository contains 40 PyTorch checkpoints in two folders. Every checkpoint is a state_dict for a TemporalAwareTranscoder module; the SAE baseline shares the same architecture.
| Folder | Files | Naming | Streams | Layers |
|---|---|---|---|---|
| temporal-aware-transcoders/ | 34 | transcoder_{stream}_{layer}.pt | img, txt | 0–15 and 18 |
| temporal-aware-saes/ | 6 | sae_{stream}_{layer}.pt | img, txt | 6, 12, 18 |
The 32 transcoders for layers 0–15 (both streams) are the set analysed by the Local Replacement Model; layer 18 (and the SAEs at 6/12/18) support the sparsity–faithfulness comparison.
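The naming scheme in the table can be expanded into the full list of expected files, which is a handy sanity check after downloading. The sketch below uses only the folder names, streams, and layers stated above; the helper name `checkpoint_paths` is hypothetical and not part of the companion code.

```python
from itertools import product

STREAMS = ("img", "txt")
TRANSCODER_LAYERS = [*range(16), 18]  # layers 0-15 plus layer 18
SAE_LAYERS = [6, 12, 18]


def checkpoint_paths():
    """Return the relative paths implied by the naming scheme (hypothetical helper)."""
    transcoders = [
        f"temporal-aware-transcoders/transcoder_{stream}_{layer}.pt"
        for stream, layer in product(STREAMS, TRANSCODER_LAYERS)
    ]
    saes = [
        f"temporal-aware-saes/sae_{stream}_{layer}.pt"
        for stream, layer in product(STREAMS, SAE_LAYERS)
    ]
    return transcoders, saes


transcoders, saes = checkpoint_paths()
print(len(transcoders), len(saes), len(transcoders) + len(saes))  # 34 6 40
```

Comparing this list against the contents of a local download would flag any missing or misnamed checkpoint.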
Model architecture
Each module maps an MLP input x to its output ŷ, conditioned on the diffusion timestep t:
- a sinusoidal timestep embedding → 2-layer SiLU MLP → linear head producing FiLM (scale, shift);
- modulation x_mod = x ⊙ (1 + scale) + shift;
- a ReLU encoder z = ReLU(W_enc x_mod + b_enc) (sparse code);
- a unit-norm linear decoder ŷ = W_dec z + b_dec.
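The four steps above can be sketched as a small forward pass. This is a NumPy illustration with toy dimensions, not the repository's TemporalAwareTranscoder implementation. The class name, random initialization, embedding frequency base of 10000, placement of the SiLU activations, and the reading of "unit-norm" as one unit-norm decoder vector per feature are all assumptions made for the example.

```python
import numpy as np


def sinusoidal_embedding(t, dim):
    """Map timesteps of shape (batch,) to sin/cos features of shape (batch, dim)."""
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)  # assumed frequency base
    args = t[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)


def silu(a):
    return a / (1.0 + np.exp(-a))


class TinyTimestepTranscoder:
    """Illustrative stand-in for a timestep-conditioned transcoder (hypothetical class)."""

    def __init__(self, d_model, d_feat, time_embed_dim, seed=0):
        rng = np.random.default_rng(seed)
        init = lambda *shape: rng.normal(0.0, 0.02, size=shape)
        self.time_embed_dim = time_embed_dim
        # 2-layer SiLU MLP on the timestep embedding (layer widths are an assumption).
        self.W1, self.b1 = init(time_embed_dim, time_embed_dim), np.zeros(time_embed_dim)
        self.W2, self.b2 = init(time_embed_dim, time_embed_dim), np.zeros(time_embed_dim)
        # Linear head producing the FiLM parameters (scale, shift).
        self.W_film, self.b_film = init(time_embed_dim, 2 * d_model), np.zeros(2 * d_model)
        self.W_enc, self.b_enc = init(d_model, d_feat), np.zeros(d_feat)
        W_dec = init(d_feat, d_model)
        # Assumption: each feature's decoder vector (one row here) has unit norm.
        self.W_dec = W_dec / np.linalg.norm(W_dec, axis=1, keepdims=True)
        self.b_dec = np.zeros(d_model)

    def forward(self, x, t):
        h = sinusoidal_embedding(t, self.time_embed_dim)
        h = silu(h @ self.W1 + self.b1)
        h = silu(h @ self.W2 + self.b2)
        scale, shift = np.split(h @ self.W_film + self.b_film, 2, axis=-1)
        x_mod = x * (1.0 + scale) + shift                      # FiLM modulation
        z = np.maximum(x_mod @ self.W_enc + self.b_enc, 0.0)   # sparse code
        y_hat = z @ self.W_dec + self.b_dec                    # reconstructed MLP output
        return y_hat, z


model = TinyTimestepTranscoder(d_model=8, d_feat=8 * 16, time_embed_dim=16)
x = np.random.default_rng(1).normal(size=(4, 8))
t = np.array([0.0, 0.25, 0.5, 1.0])
y_hat, z = model.forward(x, t)
print(y_hat.shape, z.shape)  # (4, 8) (4, 128)
```

Because the timestep enters only through the FiLM scale and shift, the encoder and decoder weights are shared across the whole trajectory, which is what lets one module follow a feature over time.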
| Hyperparameter | Value |
|---|---|
| Base model | FLUX.1[schnell], MM-DiT, d_model = 3072 |
| Expansion factor | 16 (d_feat = 49152) |
| Timestep embedding dim | 256 |
| Sparsity | L1, λ_img = 3e-4, λ_txt = 5e-5 |
| Reconstruction loss | variance-normalized MSE |
| Optimizer | AdamW, lr 2e-4, weight decay 0, CosineAnnealingLR |
| Activation buffer / batch | 1e6 / 4096 |
| Inference steps / guidance / resolution | 4 / 0 / 512×512 |
| Training prompts | yvdao/midjourney-v6 |
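The dictionary size follows directly from the table, and a rough per-checkpoint parameter count can be derived from it. The sketch below counts only the encoder and decoder; the timestep-conditioning MLP is excluded because its layer widths are not listed here, and the float32 storage size is an assumption rather than a documented property of the files.

```python
d_model = 3072
expansion_factor = 16
d_feat = d_model * expansion_factor  # 49152, matching the table

encoder_params = d_model * d_feat + d_feat    # W_enc and b_enc
decoder_params = d_feat * d_model + d_model   # W_dec and b_dec
core_params = encoder_params + decoder_params  # excludes the timestep MLP

bytes_fp32 = core_params * 4  # assumption: weights stored as float32

print(f"d_feat = {d_feat}")
print(f"encoder + decoder parameters = {core_params:,}")
print(f"approx. size at float32 = {bytes_fp32 / 1e9:.2f} GB")
```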
The SAE baseline is architecturally identical but autoencodes the MLP output (input = target), so its reconstruction error is directly comparable to a transcoder's.
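The training objective combines the two loss rows of the table. The sketch below is one plausible reading, assuming "variance-normalized MSE" means the mean squared error divided by the per-dimension batch variance of the target, and that the L1 term is the mean over the batch of the summed feature activations. The function name `transcoder_loss` is hypothetical and the exact normalization in the companion code may differ.

```python
import numpy as np


def transcoder_loss(y_hat, y, z, l1_coeff):
    """Variance-normalized MSE plus an L1 penalty on the sparse code (assumed form)."""
    mse = np.mean((y_hat - y) ** 2)
    variance = np.mean((y - y.mean(axis=0, keepdims=True)) ** 2)
    recon = mse / variance
    sparsity = l1_coeff * np.mean(np.abs(z).sum(axis=-1))
    return recon + sparsity, recon, sparsity


rng = np.random.default_rng(0)
y = rng.normal(size=(64, 8))            # stand-in for MLP outputs
z = np.abs(rng.normal(size=(64, 32)))   # stand-in for non-negative feature activations

# Predicting the batch mean gives a normalized reconstruction error of 1.
baseline = np.broadcast_to(y.mean(axis=0, keepdims=True), y.shape)
_, recon_baseline, _ = transcoder_loss(baseline, y, z, l1_coeff=3e-4)

# A perfect reconstruction leaves only the sparsity term.
total_perfect, recon_perfect, sparsity = transcoder_loss(y, y, z, l1_coeff=3e-4)
print(recon_baseline, recon_perfect)
```

Under this normalization the reconstruction term is scale-free: 1 means no better than predicting the mean, 0 means exact. For a transcoder, y is the MLP output and the module reads the MLP input; for the SAE baseline, the module reads y itself. That shared target is why the two errors can be compared directly.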
Usage
Install the companion code (GitHub), then:
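from huggingface_hub import snapshot_download
from transcoder_training.transcoder import load_transcoders

# Download the checkpoint repository and get its local path.
path = snapshot_download("Artalmaz31/DifFRACT")

# Load the transcoders for layers 0-15.
transcoders = load_transcoders(
    f"{path}/temporal-aware-transcoders",
    layers=range(16),
    d_model=3072,
    expansion_factor=16,
    time_embed_dim=256,
)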
Loading an individual SAE baseline:
import torch
from transcoder_training.transcoder import TemporalAwareSAE
sae = TemporalAwareSAE(d_model=3072, expansion_factor=16, time_embed_dim=256)
sae.load_state_dict(torch.load(f"{path}/temporal-aware-saes/sae_img_12.pt", map_location="cpu"))
sae.eval()
The end-to-end pipeline (Local Replacement Model, attribution graph, intervention) is demonstrated in walkthrough.ipynb in the companion repository.
Citation
@misc{mazur2026diffractdiffusionfeaturereconstruction,
title={DifFRACT: Diffusion Feature Reconstruction and Attribution for Circuit Tracing},
author={Artyom Mazur and Nina Konovalova and Aibek Alanov},
year={2026},
eprint={2606.15796},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2606.15796},
}
- Base model: black-forest-labs/FLUX.1-schnell
- License: Apache-2.0
- Model repository: Artalmaz31/DifFRACT