DifFRACT: Diffusion Feature Reconstruction and Attribution for Circuit Tracing

Trained timestep-conditioned transcoders and SAE baselines for the FLUX.1[schnell] text-to-image diffusion transformer (MM-DiT), accompanying the paper DifFRACT: Diffusion Feature Reconstruction and Attribution for Circuit Tracing (PDF).

A transcoder decomposes an MLP sublayer into a sparse linear combination of interpretable features; conditioning it on the denoising timestep lets a single transcoder track how a feature behaves across the diffusion trajectory. Substituting these transcoders into a frozen Local Replacement Model yields the attribution graphs and circuit-guided interventions studied in the paper. The code to load and use these weights is at the companion repository (GitHub).

Contents

40 PyTorch checkpoints in two folders. Every checkpoint is a state_dict for a TemporalAwareTranscoder module (the SAE baseline shares the identical architecture).

| Folder | Files | Naming | Streams | Layers |
| --- | --- | --- | --- | --- |
| temporal-aware-transcoders/ | 34 | transcoder_{stream}_{layer}.pt | img, txt | 0-15 and 18 |
| temporal-aware-saes/ | 6 | sae_{stream}_{layer}.pt | img, txt | 6, 12, 18 |

The 32 transcoders for layers 0–15 (both streams) are the set analysed by the Local Replacement Model; layer 18 (and the SAEs at 6/12/18) support the sparsity–faithfulness comparison.
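
As a quick sanity check on the layout described above, the following sketch enumerates the checkpoint paths implied by the naming patterns and confirms the 34 + 6 = 40 count. The helper name `expected_checkpoints` is hypothetical and not part of the companion repository; only the folder names, patterns, streams and layer lists come from this card.

```python
# Folder names, naming patterns, streams and layers are taken from this card.
# The helper `expected_checkpoints` is a hypothetical illustration.
STREAMS = ("img", "txt")
TRANSCODER_LAYERS = list(range(16)) + [18]  # layers 0-15 and 18
SAE_LAYERS = [6, 12, 18]


def expected_checkpoints():
    """Return (transcoder_paths, sae_paths) relative to the repository root."""
    transcoders = [
        f"temporal-aware-transcoders/transcoder_{stream}_{layer}.pt"
        for stream in STREAMS
        for layer in TRANSCODER_LAYERS
    ]
    saes = [
        f"temporal-aware-saes/sae_{stream}_{layer}.pt"
        for stream in STREAMS
        for layer in SAE_LAYERS
    ]
    return transcoders, saes


transcoders, saes = expected_checkpoints()
print(len(transcoders), len(saes), len(transcoders) + len(saes))  # 34 6 40
```

Comparing this list against the files actually present in a downloaded snapshot is one way to detect an incomplete download.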

Model architecture

Each module maps an MLP input x to its output ŷ, conditioned on the diffusion timestep t:

  • a sinusoidal timestep embedding → 2-layer SiLU MLP → linear head producing FiLM (scale, shift);
  • modulation x_mod = x ⊙ (1 + scale) + shift;
  • a ReLU encoder z = ReLU(W_enc x_mod + b_enc) (sparse code);
  • a unit-norm linear decoder ŷ = W_dec z + b_dec.
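
The four steps above can be sketched in a few lines of NumPy at toy dimensions. This is an illustration of the math only, not the repository's `TemporalAwareTranscoder`: the parameter names, the initialisation, the sinusoidal frequency schedule, applying SiLU after both MLP layers, and normalising the decoder per feature (column) are all assumptions.

```python
import numpy as np


def sinusoidal_embedding(t, dim):
    """Sin/cos embedding of a scalar timestep t into `dim` values (assumed schedule)."""
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)
    angles = t * freqs
    return np.concatenate([np.sin(angles), np.cos(angles)])


def silu(v):
    return v / (1.0 + np.exp(-v))


def init_params(d_model, d_feat, time_dim, rng):
    """Random toy parameters; names and shapes are illustrative assumptions."""
    p = {
        "W1": rng.normal(0, 0.02, (time_dim, time_dim)), "b1": np.zeros(time_dim),
        "W2": rng.normal(0, 0.02, (time_dim, time_dim)), "b2": np.zeros(time_dim),
        "W_film": rng.normal(0, 0.02, (2 * d_model, time_dim)),
        "b_film": np.zeros(2 * d_model),
        "W_enc": rng.normal(0, 0.02, (d_feat, d_model)), "b_enc": np.zeros(d_feat),
        "W_dec": rng.normal(0, 1.0, (d_model, d_feat)), "b_dec": np.zeros(d_model),
    }
    # Unit-norm decoder: each feature's decoder column is scaled to L2 norm 1.
    p["W_dec"] /= np.linalg.norm(p["W_dec"], axis=0, keepdims=True)
    return p


def forward(p, x, t):
    """Map one MLP input vector x of shape (d_model,) to (y_hat, z) at timestep t."""
    h = sinusoidal_embedding(t, p["W1"].shape[1])
    h = silu(p["W1"] @ h + p["b1"])            # 2-layer SiLU MLP
    h = silu(p["W2"] @ h + p["b2"])
    scale, shift = np.split(p["W_film"] @ h + p["b_film"], 2)  # FiLM head
    x_mod = x * (1.0 + scale) + shift          # modulation
    z = np.maximum(0.0, p["W_enc"] @ x_mod + p["b_enc"])       # sparse code
    y_hat = p["W_dec"] @ z + p["b_dec"]        # linear decoder
    return y_hat, z


rng = np.random.default_rng(0)
params = init_params(d_model=8, d_feat=8 * 16, time_dim=16, rng=rng)
y_hat, z = forward(params, rng.normal(size=8), t=0.5)
print(y_hat.shape, z.shape)
```

The toy run keeps the expansion factor of 16 (8 inputs, 128 features); the released checkpoints use d_model = 3072 and d_feat = 49152.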

| Hyperparameter | Value |
| --- | --- |
| Base model | FLUX.1[schnell], MM-DiT, d_model = 3072 |
| Expansion factor | 16 (d_feat = 49152) |
| Timestep embedding dim | 256 |
| Sparsity | L1, λ_img = 3e-4, λ_txt = 5e-5 |
| Reconstruction loss | variance-normalized MSE |
| Optimizer | AdamW, lr 2e-4, weight decay 0, CosineAnnealingLR |
| Activation buffer / batch | 1e6 / 4096 |
| Inference steps / guidance / resolution | 4 / 0 / 512×512 |
| Training prompts | yvdao/midjourney-v6 |

The SAE baseline is architecturally identical but autoencodes the MLP output (input = target), so its reconstruction error is directly comparable to a transcoder's.
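
To make the training objective concrete, here is a hedged NumPy sketch of a variance-normalized MSE combined with an L1 penalty on the sparse code. The exact normalisation used in the paper is not stated on this card, so dividing the MSE by the batch variance of the target is an assumption, as are the function names. For a transcoder the target `y` is the MLP output; for the SAE baseline the input and the target are both the MLP output.

```python
import numpy as np


def variance_normalized_mse(y_hat, y):
    """MSE divided by the batch variance of the target (assumed definition)."""
    mse = np.mean((y_hat - y) ** 2)
    var = np.mean((y - y.mean(axis=0, keepdims=True)) ** 2)
    return mse / var


def training_loss(y_hat, y, z, l1_coeff):
    """Reconstruction term plus an L1 penalty on the sparse code z."""
    sparsity = np.abs(z).sum(axis=-1).mean()
    return variance_normalized_mse(y_hat, y) + l1_coeff * sparsity


# L1 coefficients per stream, as listed in the hyperparameters on this card.
L1_COEFF = {"img": 3e-4, "txt": 5e-5}

rng = np.random.default_rng(0)
y = rng.normal(size=(4, 8))             # toy batch of MLP outputs
z = np.abs(rng.normal(size=(4, 128)))   # toy non-negative sparse codes

# Predicting the batch mean for every sample gives a normalized MSE of 1,
# so values below 1 mean the reconstruction beats that trivial baseline.
baseline = np.broadcast_to(y.mean(axis=0), y.shape)
print(variance_normalized_mse(baseline, y))
print(training_loss(baseline, y, z, L1_COEFF["img"]))
```

Under this definition the loss is scale-free, which is what makes the error of a transcoder and an SAE trained on differently scaled activations comparable.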

Usage

Install the companion code (GitHub), then:

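from huggingface_hub import snapshot_download
from transcoder_training.transcoder import load_transcoders

# Download this repository and get the local directory that holds the checkpoints.
path = snapshot_download("Artalmaz31/DifFRACT")

# Load the transcoders for layers 0-15, the set used by the Local Replacement Model.
transcoders = load_transcoders(
    f"{path}/temporal-aware-transcoders",
    layers=range(16),
    d_model=3072,
    expansion_factor=16,
    time_embed_dim=256,
)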

Loading an individual SAE baseline:

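import torch
from transcoder_training.transcoder import TemporalAwareSAE

# `path` is the local snapshot directory of this repository,
# as returned by huggingface_hub.snapshot_download.
sae = TemporalAwareSAE(d_model=3072, expansion_factor=16, time_embed_dim=256)
sae.load_state_dict(torch.load(f"{path}/temporal-aware-saes/sae_img_12.pt", map_location="cpu"))
sae.eval()  # switch to inference mode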

The end-to-end pipeline (Local Replacement Model, attribution graph, intervention) is demonstrated in walkthrough.ipynb in the companion repository.

Citation

@misc{mazur2026diffractdiffusionfeaturereconstruction,
      title={DifFRACT: Diffusion Feature Reconstruction and Attribution for Circuit Tracing}, 
      author={Artyom Mazur and Nina Konovalova and Aibek Alanov},
      year={2026},
      eprint={2606.15796},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2606.15796}, 
}