Step 3.7 Flash: MLX Adapter (text + vision)

First MLX-native runtime adapter for stepfun-ai/Step-3.7-Flash, Step Robotics' multimodal model built on the step3p5 MoE backbone (Qwen3.6-A3B family) with a 47-layer PerceptionEncoder vision tower.

Lets you run Step 3.7 Flash on Apple Silicon through mlx-vlm (multimodal, recommended) or mlx_lm (text-only, fallback).

This is the adapter code only; no model weights are shipped here. You download or quantize Step-3.7-Flash separately (instructions below).

What this includes

  • mlx_vlm/models/step3p7/: mlx-vlm adapter (6 files), multimodal (text + vision)
  • mlx_lm/models/step3p7.py: mlx-lm adapter, text-only
  • model_dir_patches/chat_template.jinja: patched Jinja template (multi-crop image expansion + reasoning_effort=none); drop it into your model dir
  • scripts/extract_vision_weights.py: reproducer for the vision-shard pre-transpose step (needed for MLX's NHWC convention)

Architecture

Step 3.7 Flash = step3p5 backbone (45-layer Qwen3.6-A3B MoE: 12 full-attention layers + 33 sliding-window layers with window 512) + 47-layer PerceptionEncoder ViT + 2 stride-2 Conv2d downsamplers + Linear(6144→4096) projector.
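The token counts in the vision pipeline below follow from simple grid arithmetic. A minimal sketch, assuming a 14 px ViT patch (inferred from 728/52 = 504/36 = 14, not stated explicitly) and that each of the 2 stride-2 Conv2d downsamplers halves the grid exactly:

```python
PATCH_PX = 14       # assumption: inferred from 728 / 52 == 504 / 36 == 14
N_DOWNSAMPLERS = 2  # the two stride-2 Conv2d layers, assumed to halve the grid exactly


def tokens_per_view(side_px: int) -> int:
    """Number of <im_patch> tokens produced for one square view of side `side_px`."""
    grid = side_px // PATCH_PX
    for _ in range(N_DOWNSAMPLERS):
        grid //= 2  # each stride-2 downsampler halves the grid side
    return grid * grid


def image_patch_tokens(n_subpatches: int) -> int:
    """Total <im_patch> placeholders: one 728² global view plus N 504² sub-patches."""
    return tokens_per_view(728) + n_subpatches * tokens_per_view(504)


print(tokens_per_view(728), tokens_per_view(504))  # 169 81
print(image_patch_tokens(4))                       # 493
```

The count excludes the <patch_start>/<patch_end>/<im_start>/<im_end> and optional <patch_newline> markers, which are separate tokens.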

Vision pipeline:

  1. ImagePatcher (pure PIL/numpy port of upstream ImagePatcher): square-pad extreme aspect ratios, resize so max(W, H) ≤ 3024, then use either a single view (long side ≤ 728) or sliding-window 504² sub-patches.
  2. Vision encoder dual pass: the global 728² view → 169 tokens (52×52 grid → two stride-2 downsamples → 13² = 169); each 504² sub-patch → 81 tokens (36×36 grid → 9² = 81). Same ViT weights, two grid sizes via interpolated positional embeddings.
  3. Placeholder layout per image (chat template + processor): [<patch_start>{81×<im_patch>}<patch_end> + optional <patch_newline>] × N + <im_start>{169×<im_patch>}<im_end>. encode_image returns features in matching order; they are merged into the text embeddings via mx.cumsum + mx.where.
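The placeholder layout and the cumsum/where merge from step 3 can be sketched in numpy, used here as a stand-in for mlx.core (an MLX version would use mx.cumsum and mx.where the same way). image_placeholder and its newline_after parameter are hypothetical; the real chat template decides where <patch_newline> goes. The merge assumes image_features is already ordered to match the placeholders (sub-patches first, global view last), as encode_image returns them.

```python
import numpy as np


def image_placeholder(n_subpatches, newline_after=()):
    """Build one image's placeholder string in the layout described above.

    `newline_after` (hypothetical) lists sub-patch indices followed by <patch_newline>.
    """
    parts = []
    for i in range(n_subpatches):
        parts.append("<patch_start>" + "<im_patch>" * 81 + "<patch_end>")
        if i in newline_after:
            parts.append("<patch_newline>")
    parts.append("<im_start>" + "<im_patch>" * 169 + "<im_end>")
    return "".join(parts)


def merge_image_features(input_ids, text_embeds, image_features, image_token_id):
    """Scatter image features into the <im_patch> slots, in order.

    input_ids: (L,) ints; text_embeds: (L, D); image_features: (M, D), where M is the
    number of image tokens in input_ids (assumed >= 1).
    """
    mask = input_ids == image_token_id
    # Running count of image tokens: the k-th image token reads row k of image_features.
    idx = np.maximum(np.cumsum(mask) - 1, 0)
    return np.where(mask[:, None], image_features[idx], text_embeds)
```

Positions that are not image tokens keep their text embedding; positions before the first image token gather row 0, but np.where discards that value.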

Bugs caught + fixed during the port

  1. MLX nn.Upsample applies a 2-tuple scale_factor to dims 1 and 2 of the input (it assumes channels-last NHWC), not to the last two dims the way PyTorch's F.interpolate does on NCHW. So interpolating the 52×52 positional embedding to the 36×36 grid of a 504² sub-patch must be done in NHWC layout natively (no NCHW transpose). The naive port fed NCHW and silently shrank the 1536 hidden dim (and one spatial dim) instead of both 52-wide spatial dims, producing a tensor with 1,989,936 elements that couldn't reshape to (1296, 1536).

  2. An mx.load + mx.save_safetensors round trip writes zeros for lazy values. When extracting vision weights from the upstream bf16 shards, you must mx.eval(v) each tensor before storing it in the output dict; otherwise the safetensors file ends up shape-correct but full of zeros. The format=mlx metadata is also required for mlx-vlm to load the file correctly.

  3. The vision feature cache must be ignored on the multi-crop path. Engines that pre-cache vision features via model.encode_image(pixel_values) get only the global 169 tokens, but multi-crop prompts have 169 + N×81 placeholder positions, so the cached features no longer match. The naive zero-pad fallback then fills every patch slot with zeros. Fix: detect patch_pixel_values in kwargs and recompute all features when it is present.

  4. Vision weight provenance matters more than vision architecture. Step3-VL-10B and Step-3.7-Flash share the same StepRoboticsVisionEncoder architecture, but their projector weights were trained against different text-side LMs (Qwen3-8B vs the step3p5 MoE). Using Step3-VL-10B's vision weights with Step-3.7-Flash's LM gives mathematically correct vision features that the LM can't decode (the model collapses to "white" / "a man in a suit" for every image). Always pull model-vit-00001/00002.safetensors from the matching upstream model.
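Bugs 1 and 3 reduce to small pieces of logic that can be checked without MLX. A plain-Python sketch under stated assumptions: the upsampled size is floored (chosen because it reproduces the 1,989,936 figure in bug 1), and resolve_vision_features is a hypothetical helper whose encode_image argument stands in for the adapter's encoder; only the kwarg name patch_pixel_values comes from the fix described in bug 3.

```python
from math import prod


# Bug 1: which axes does a 2-tuple scale_factor touch?
def upsample_shape(shape, axes, num, den):
    """Shape after scaling `axes` by num/den, with floor rounding (assumption)."""
    out = list(shape)
    for a in axes:
        out[a] = shape[a] * num // den
    return tuple(out)


SCALED_AXES = (1, 2)  # MLX nn.Upsample scales dims 1 and 2 (channels-last input)

naive = upsample_shape((1, 1536, 52, 52), SCALED_AXES, 36, 52)  # NCHW passed as-is
fixed = upsample_shape((1, 52, 52, 1536), SCALED_AXES, 36, 52)  # NHWC, as MLX expects
print(naive, prod(naive))  # (1, 1063, 36, 52) 1989936
print(fixed, prod(fixed))  # (1, 36, 36, 1536) 1990656


# Bug 3: bypass the vision feature cache on the multi-crop path.
def resolve_vision_features(pixel_values, encode_image, cached_features=None, **kwargs):
    """Hypothetical helper: reuse cached features only on the single-view path."""
    patch_pixel_values = kwargs.get("patch_pixel_values")
    if patch_pixel_values is not None:
        # The cache holds only the 169 global tokens, but the prompt has 169 + N*81
        # placeholders, so recompute global and sub-patch features together.
        return encode_image(pixel_values, patch_pixel_values=patch_pixel_values)
    if cached_features is not None:
        return cached_features  # single view: cached global features line up
    return encode_image(pixel_values)
```

The NHWC result has exactly 1296 × 1536 elements, so the reshape to (1296, 1536) succeeds; the NCHW result does not.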

Installation

Drop the adapter into your mlx_vlm (and optionally mlx_lm) install:

# Discover the install paths from Python β€” works for pip / uv / conda /
# bundled frameworks alike.
MLX_VLM_DIR=$(python -c "import os, mlx_vlm; print(os.path.dirname(mlx_vlm.__file__))")
MLX_LM_DIR=$(python  -c "import os, mlx_lm;  print(os.path.dirname(mlx_lm.__file__))")

cp -r mlx_vlm/models/step3p7  "$MLX_VLM_DIR/models/"
cp    mlx_lm/models/step3p7.py "$MLX_LM_DIR/models/"
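After copying, you can confirm the adapter modules are visible to Python. A minimal sketch using importlib.util.find_spec; it only checks that the modules can be located, not that the prompt_utils.py patch is in place:

```python
import importlib.util


def module_available(name: str) -> bool:
    """True if `name` can be located; find_spec imports parent packages to do so."""
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:  # a parent package (or one of its own imports) is missing
        return False


for mod in ("mlx_vlm.models.step3p7", "mlx_lm.models.step3p7"):
    print(f"{mod}: {'found' if module_available(mod) else 'missing'}")
```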

You also need a 1-line patch to mlx_vlm/prompt_utils.py:

# In MESSAGE_FORMATS dict, add:
"step3p7": MessageFormat.LIST_WITH_IMAGE_FIRST,

Preparing the model weights

You will need quantized text shards + 1 vision shard with pre-transposed Conv2d kernels.

# 1. Get Step-3.7-Flash bf16 from upstream (very large)
hf download stepfun-ai/Step-3.7-Flash --local-dir Step-3.7-Flash

# 2. Quantize text shards to 4-bit (or use an existing community quant)
# e.g. https://huggingface.co/<some>/Step-3.7-Flash-4bit/

# 3. Extract vision weights with the NHWC Conv2d transpose applied
python scripts/extract_vision_weights.py \
  --src Step-3.7-Flash \
  --dst Step-3.7-Flash-4bit/model-00023-of-00023.safetensors

# 4. Patch chat_template.jinja into the model dir
cp model_dir_patches/chat_template.jinja Step-3.7-Flash-4bit/

# 5. Update model.safetensors.index.json to include the new vision shard keys.
# (Index update logic depends on your text quant; see `processing_step3.py` upstream for the canonical layout.)
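Steps 3 and 5 can be sketched as follows. This is a hypothetical reimplementation, not the shipped scripts/extract_vision_weights.py. It assumes the upstream vision shards match model-vit-*.safetensors, that every 4-D tensor in them is a PyTorch-layout Conv2d kernel, and that keys need no renaming; it applies the mx.eval and format=mlx fixes from bug 2. The index helpers read tensor names straight from the safetensors header (an 8-byte little-endian length followed by JSON), so they need no MLX.

```python
import glob
import json
import os
import struct


def to_mlx_conv_layout(w):
    """PyTorch Conv2d kernels are (out, in, kH, kW); MLX Conv2d expects (out, kH, kW, in).

    Assumption: every 4-D tensor in the vision shards is a PyTorch-layout Conv2d kernel.
    Works for numpy arrays and mx.array alike (both have .ndim and .transpose(*axes)).
    """
    return w.transpose(0, 2, 3, 1) if w.ndim == 4 else w


def extract_vision_weights(src_dir, dst_path):
    """Hypothetical sketch of step 3; the real script may also rename keys."""
    import mlx.core as mx  # lazy import so the pure helpers work without MLX

    out = {}
    for shard in sorted(glob.glob(os.path.join(src_dir, "model-vit-*.safetensors"))):
        for name, w in mx.load(shard).items():
            w = to_mlx_conv_layout(w)
            mx.eval(w)  # materialize before saving, or the shard comes out full of zeros
            out[name] = w
    mx.save_safetensors(dst_path, out, metadata={"format": "mlx"})
    return sorted(out)


def safetensors_header_keys(path):
    """Tensor names from a .safetensors header: 8-byte little-endian length, then JSON."""
    with open(path, "rb") as f:
        (n,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(n))
    return sorted(k for k in header if k != "__metadata__")


def point_keys_at_shard(index, keys, shard_name):
    """Step 5: map each new key to `shard_name` in a loaded index.json dict.

    metadata.total_size is left untouched (assumption: your loader does not check it).
    """
    weight_map = index.setdefault("weight_map", {})
    for k in keys:
        weight_map[k] = shard_name
    return index
```

For example: run extract_vision_weights("Step-3.7-Flash", "Step-3.7-Flash-4bit/model-00023-of-00023.safetensors"), load model.safetensors.index.json with json.load, pass it to point_keys_at_shard along with safetensors_header_keys(...) of the new shard and the shard's file name, then write it back. The same two index helpers apply to the MTP shard rewrite later in this README.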

Usage (via mlx-vlm)

from mlx_vlm import load, apply_chat_template, generate
from PIL import Image

model, processor = load("Step-3.7-Flash-4bit")
img = Image.open("photo.jpg")
prompt = apply_chat_template(processor, model.config, [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "What is in this image?"},
    ]}
], num_images=1)
output = generate(model, processor, prompt, image=img, max_tokens=200)
print(output)

Or call any OpenAI-compatible MLX engine with model="Step-3.7-Flash-4bit" and image content blocks.

Tested on

  • Mac Studio M3 Ultra 512 GB
  • mlx==0.30.x, mlx-vlm==0.5.0
  • Text decode: ~44 tok/s warm
  • Vision: 4-corner color/shape recognition on 2000×1500 images and sun detection on a 1920×1080 landscape, all correct
  • Long-context: 182K-token needle-in-haystack retrieval returned the correct answer in ~11 min wall time (Step 3.7 4-bit is slow but correct; for comparison, the Huihui 8-bit Opus-distill from the same model generation took ~4.5 min)

Multi-Token Prediction (MTP) speculative decoding

Step 3.7 ships 3 MTP heads (vLLM Step3p5MTP) trained against the upstream BF16 backbone. We port the MTPModule + MTPLayer classes to MLX (mlx_vlm/models/step3p7/language.py) and wire mtp_forward / make_mtp_cache / return_hidden into both LanguageModel and the outer Model. With that in place, the oMLX MTP draft/verify cycle (qwen35 PR 990-derived batch_generator patch) engages automatically when mtp_enabled: True is set in the model settings.

Status: works. 73-80% accept rate, lossless at temp=0, ~15% net speedup.

Benchmark (Step-3.7-Flash-4bit, M3 Ultra, 300 tokens, warm, 5 iterations):

  • MTP off (turboquant_kv: True): 45.6 tok/s overall; accept rate n/a
  • MTP on (per-MTP-layer lm_head + BF16 embed + 7 norm shifts): 52.6 tok/s overall; accept rate 73-80%

Per-prompt: short_code 60, physics 57, Chinese translation 49 tok/s. Greedy output is byte-identical to MTP-off (verified on 4 prompts).
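Greedy output can be byte-identical because, at temp=0, every emitted token is the backbone's own argmax. A generic sketch of greedy speculative verification (not oMLX's batch_generator code), assuming the backbone scores the k draft tokens in one forward pass and returns k + 1 argmax predictions:

```python
def greedy_verify(draft_tokens, target_argmax):
    """Greedy (temperature 0) speculative verification.

    draft_tokens: k tokens proposed by the draft (here, the MTP heads).
    target_argmax: k + 1 backbone argmax tokens, one per draft position plus one extra.
    Returns the emitted tokens: the longest draft prefix the backbone agrees with,
    then the backbone's own token at the first disagreement (or the bonus token).
    """
    assert len(target_argmax) == len(draft_tokens) + 1
    accepted = []
    for d, t in zip(draft_tokens, target_argmax):
        if d != t:
            break
        accepted.append(d)
    accepted.append(target_argmax[len(accepted)])
    return accepted
```

Each cycle emits between 1 and k + 1 tokens, so a higher accept rate means fewer backbone passes per token, while the emitted sequence matches plain greedy decoding (up to floating-point differences between batched and single-token forward passes).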

What it took to get here

The hard part was identifying which weight-loading conventions Hikari07jp's BF16 extraction expected. With the wrong conventions, the accept rate was stuck at 0-4% (a slowdown vs baseline); with the right ones, it reached 75%.

  1. Per-MTP-layer shared_head.output is NOT tied to the backbone lm_head (the biggest finding). Each of the 3 MTP layers in the upstream BF16 checkpoint ships its own shared_head.output weight (the per-layer means differ: 0.000008 / 0.000029 / 0.000034). We were initially calling the backbone's 4-bit quantized lm_head from mtp_forward, which produced ~0% useful draft tokens. Switching to the per-MTP-layer BF16 head raised the accept rate from 2-4% to 70-80%. We detected this by inspecting model.layers.{45,46,47}.transformer.shared_head.output in the BF16 source and noticing that the three weights aren't equal.
  2. The MTP-side embed_tokens is also separate from the backbone's: vLLM Step3p5AMultiTokenPredictor keeps its own VocabParallelEmbedding. We add an nn.Embedding directly on MTPModule and route the BF16 model.embed_tokens.weight from the MTP shard there. (This alone didn't move the needle, but it removes mixed-precision drift from mtp_forward's input lookup.)
  3. +1 norm shift on extract. vLLM GemmaRMSNorm.forward_native computes weight = stored + 1.0; rms_norm(x, weight, eps). MLX ZeroCenteredRMSNorm is plain mx.fast.rms_norm(x, weight, eps) (no +1 baked in, despite the class name). We pre-add 1 to the 7 Gemma norms in the MTP shard (enorm, hnorm, mtp_block.input_layernorm, mtp_block.post_attention_layernorm, mtp_block.self_attn.q_norm, mtp_block.self_attn.k_norm, shared_head_norm) at offline extract time.
  4. Pre-norm hidden state passed to MTP. LanguageModel.__call__ returns the residual stream before body.norm when called with return_hidden=True; vLLM Step3p5Model.forward() does the same. Empirically, pre-norm gives 73-80% accept; post-norm gave 1-2%.
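The norm shift (item 3) and the untied-head check (item 1) can be verified numerically. A minimal numpy sketch, assuming eps=1e-6 and matching norm keys by the suffixes listed in item 3 (the full key prefixes in the MTP shard are an assumption): after shift_gemma_norms, a plain rms_norm on the shifted weight matches Gemma's (1 + stored) scaling. heads_identical is a hypothetical helper for the comparison in item 1; you would pass it the three model.layers.{45,46,47}.transformer.shared_head.output tensors from the BF16 source.

```python
import numpy as np

# The 7 Gemma-style norms from item 3; matched by key suffix (prefixes vary per shard).
GEMMA_NORMS = (
    "enorm", "hnorm",
    "mtp_block.input_layernorm", "mtp_block.post_attention_layernorm",
    "mtp_block.self_attn.q_norm", "mtp_block.self_attn.k_norm",
    "shared_head_norm",
)


def rms_norm(x, weight, eps=1e-6):
    """Plain RMSNorm: the computation mx.fast.rms_norm performs."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def gemma_rms_norm(x, stored, eps=1e-6):
    """GemmaRMSNorm semantics as described in item 3: scale by (1 + stored weight)."""
    normed = x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return normed * (1.0 + stored)


def is_gemma_norm(key):
    return any(key == f"{n}.weight" or key.endswith(f".{n}.weight") for n in GEMMA_NORMS)


def shift_gemma_norms(weights):
    """Offline +1 shift so a plain rms_norm reproduces GemmaRMSNorm at load time."""
    return {k: (w + 1.0 if is_gemma_norm(k) else w) for k, w in weights.items()}


def heads_identical(weights, keys):
    """True if all per-MTP-layer head tensors hold exactly the same values."""
    first = weights[keys[0]]
    return all(np.array_equal(first, weights[k]) for k in keys[1:])
```

Note that the shift must be applied exactly once: running it on an already-shifted shard would scale by (2 + stored).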

The mismatched-precision hypothesis (4-bit backbone vs BF16 MTP) wasn't the bottleneck: BF16 embed_tokens didn't move the accept rate on its own, and neither did hidden-state precision. The bottleneck was using the wrong lm_head.

How to use

The MTP path activates automatically when these settings are set on Step-3.7-Flash-4bit (or your equivalent model id) in your engine's model settings file:

{
  "mtp_enabled": true,
  "turboquant_kv_enabled": false
}

mtp_enabled and turboquant_kv_enabled are mutually exclusive in oMLX (TurboQuant patches the attention path that MTP relies on), so the MTP-on results above were measured with TurboQuant off. Restart oMLX (or your engine process) after changing settings, then verify activation in the log:

[engine] Native MTP patch applied for ... (model_type=step3p7, active)
[batch_generator] MTP path activated for uid=N (model has mtp_forward, batch=1)
[batch_generator] MTP[N] finish=length tokens=200 cycles=117 accept=82/117 (70.1%) ...
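To track the accept rate across runs, the MTP finish lines can be parsed. A small sketch; the regex is derived only from the example log lines above, so adjust it if your oMLX version formats them differently:

```python
import re

# Matches e.g. "MTP[N] finish=length tokens=200 cycles=117 accept=82/117 (70.1%)".
MTP_ACCEPT = re.compile(r"MTP\[(?P<uid>\w+)\].*?\baccept=(?P<acc>\d+)/(?P<total>\d+)")


def accept_rates(lines):
    """Map request uid -> accepted/total ratio from MTP finish lines."""
    rates = {}
    for line in lines:
        m = MTP_ACCEPT.search(line)
        if m:
            rates[m["uid"]] = int(m["acc"]) / int(m["total"])
    return rates
```

For example, with open("engine.log") as f: print(accept_rates(f)). The "MTP path activated" line has no "MTP[" prefix, so it is skipped.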

To reproduce the shard rewrite from scratch:

hf download Hikari07jp/Step-3.7-Flash-MTP-draft --local-dir /tmp/mtp-src
python scripts/rewrite_mtp_shard.py \
  /tmp/mtp-src/model.safetensors \
  Step-3.7-Flash-4bit/model-00024-of-00024.safetensors
# Then update model.safetensors.index.json so the 52 new keys point
# at the new shard filename (49 layer + 3 shared_head_output entries).

Acknowledgements

  • Hikari07jp/Step-3.7-Flash-MTP-draft: extracted BF16 MTP-draft layers from upstream stepfun-ai/Step-3.7-Flash; the same weights run at ~80% accept in vLLM.
  • oMLX: patches/mlx_lm_mtp/ ports mlx-lm PR 990's MTP draft+verify batch_generator dispatch to oMLX's continuous-batching scheduler. We add step3p5 / step3p7 to its compatible list.
  • mlx-optiq/Gemma 4 spec dec post: its structurally identical 0%→3%→33% accept-rate fix narrative for Gemma 4 speculative decoding on MLX was a useful template for diagnosing this one.

To check the accept rate once MTP is enabled, run grep "MTP\[" <your-engine-log>.

The MLX_LM_MTP patch from oMLX 0.3.10+ exposes its own _is_mtp_compatible allowlist; our patch adds the step3p5 and step3p7 model_types to that list (see patches/mlx_lm_mtp_compat.diff for the one-line addition).

License

Adapter code: Apache-2.0 (see LICENSE). Upstream Step-3.7-Flash weights are licensed separately by Step Robotics; see their model card before redistributing.

Credits

Architecture port + bug catches were done end-to-end across 7 sessions of MLX debugging on Apple Silicon, with PyTorch reference parity tests used to localize the four bugs above. References that helped (none of them validate end-to-end vision on Step-3.7-Flash specifically):
