MosaicBERT-updated

An updated HuggingFace implementation of MosaicBERT (Portes et al., NeurIPS 2023) with four bugs fixed and full attn_implementation dispatch support (eager, sdpa, flash_attention_2).

This repo contains only code, no weights. Load weights from any original MosaicBERT checkpoint and use this repo as the code source (see usage below).

What changed from the original

The original mosaicml/mosaic-bert-base uses a custom Triton flash attention kernel (flash_attn_triton) that is tied to a specific GPU/Triton version and no longer works reliably with recent PyTorch. This port replaces it with the standard flash-attn library (flash_attn_varlen_qkvpacked_func) and adds SDPA support, while keeping the rest of the architecture unchanged (ALiBi, unpadding, GLU FFN, low-precision LayerNorm).

Bugs fixed vs the user-facing mosaicml/mosaic-bert-base code:

  1. attn_implementation dispatch now reads config._attn_implementation (underscore prefix, set by from_pretrained) instead of config.attn_implementation (no underscore). The latter is always None, so dispatch silently fell back to eager.
  2. extended_attention_mask is cast to hidden_states.dtype instead of a hard-coded torch.float32. The float32 cast broke bfloat16 inference.
  3. _supports_sdpa = True and _supports_flash_attn_2 = True flags added to all three model classes so HF's dispatch machinery activates correctly.
  4. alibi_slopes is cast with .float() before being passed to flash_attn_varlen_qkvpacked_func. from_pretrained(..., torch_dtype=bfloat16) calls model.to(bfloat16) on the whole module. That call converts every floating-point tensor, parameters and buffers alike. alibi_slopes is a registered buffer, so it becomes bfloat16. The self.alibi bias matrix was already cast with .to(hidden_states.dtype) before use, but alibi_slopes was not. The flash-attn CUDA kernel requires fp32 slopes regardless of model dtype, so bfloat16 slopes caused a hard error. The .float() call is a no-op when the slopes are already fp32 and prevents the crash otherwise.
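The dtype issues behind fixes 2 and 4, and the attribute lookup behind fix 1, can be reproduced on CPU with plain PyTorch. This is a minimal sketch, not the repo's code. ToyAttention and select_attn_implementation are hypothetical. The config is a SimpleNamespace rather than a transformers config. The -10000.0 mask constant is the classic BERT choice, assumed here for illustration.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Hypothetical stand-in holding an fp32 buffer, like alibi_slopes."""

    def __init__(self, n_heads: int = 12):
        super().__init__()
        self.register_buffer("alibi_slopes", torch.rand(n_heads, dtype=torch.float32))


def select_attn_implementation(config) -> str:
    # Fix 1: read the underscored attribute set by from_pretrained.
    # A config that only has `attn_implementation` (no underscore) falls back to eager.
    return getattr(config, "_attn_implementation", None) or "eager"


# Module.to(dtype) casts floating-point buffers as well as parameters.
module = ToyAttention().to(torch.bfloat16)
print(module.alibi_slopes.dtype)  # torch.bfloat16

# Fix 4: restore fp32 slopes right before a kernel call that requires fp32.
slopes_fp32 = module.alibi_slopes.float()

# When the tensor is already fp32, .float() returns the very same object (no copy).
already_fp32 = torch.ones(4)
assert already_fp32.float() is already_fp32

# Fix 2: build the additive mask in the activations' dtype instead of fp32.
hidden_states = torch.zeros(2, 5, 8, dtype=torch.bfloat16)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
extended = attention_mask[:, None, None, :].to(hidden_states.dtype)
extended = (1.0 - extended) * -10000.0  # large negative bias on padding positions

config = SimpleNamespace(_attn_implementation="sdpa")
print(select_attn_implementation(config))  # sdpa
```

The last check shows the original bug in miniature. A config carrying only the non-underscored attribute resolves to "eager" regardless of what was requested.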

Parity Verification

Hidden states and logits verified bit-for-bit identical (max abs diff = 0.00 at every layer) to the original MosaicBERT eager path (pure-PyTorch fallback) on a padded 4-sentence batch. SDPA vs eager max diff = 2.77e-05. Verified on GPU with PyTorch 2.7 / CUDA 12.9.
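The parity figures above are per-layer maximum absolute differences. The helper below computes that metric. It is hypothetical and is not the script that produced the reported numbers. How the per-layer hidden states are collected from each model is left open.

```python
import torch


def per_layer_max_abs_diff(states_a, states_b):
    """Max absolute difference for each pair of per-layer hidden-state tensors."""
    if len(states_a) != len(states_b):
        raise ValueError("both runs must return the same number of layers")
    diffs = []
    for a, b in zip(states_a, states_b):
        if a.shape != b.shape:
            raise ValueError(f"shape mismatch: {tuple(a.shape)} vs {tuple(b.shape)}")
        # Subtract in fp32 so low-precision rounding in the subtraction itself is not counted.
        diffs.append((a.float() - b.float()).abs().max().item())
    return diffs
```

A result of all zeros corresponds to the bit-for-bit claim. Values around 1e-5 are typical when two attention kernels accumulate in a different order.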

Architecture

MosaicBERT-Base has the same macro-architecture as BERT-base but with four modifications:

Modification          Detail
Attention             Flash Attention (packed QKV) via flash-attn
Positional encoding   ALiBi (no position embeddings)
FFN                   Gated Linear Units (GeGLU)
Padding               Unpadding: sequences are concatenated and processed without padding tokens

Parameter             Value
Layers                12
Attention heads       12
Embedding dimension   768
Vocabulary size       30,528 (30,522 padded up to a multiple of 64)
Parameters            ~137M (larger than BERT-base due to the GLU gating matrix)
Pretraining length    128 tokens
alibi_starting_size   1024 (pre-allocates the ALiBi bias matrix; increase for longer sequences)
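Unpadding is what enables the packed flash-attn path. The sketch below assumes the input layout that flash_attn_varlen_qkvpacked_func expects. All real tokens are concatenated into one tensor. cu_seqlens holds int32 cumulative sequence lengths with a leading 0, and max_seqlen is the longest sequence. unpad and repad are hypothetical names modeled on flash-attn's bert_padding utilities, not functions from this repo.

```python
import torch
import torch.nn.functional as F


def unpad(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    """Drop padding tokens and build varlen metadata.

    hidden_states: (batch, seqlen, dim); attention_mask: (batch, seqlen) of 0/1.
    Returns (packed, indices, cu_seqlens, max_seqlen); packed is (total_tokens, dim).
    """
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat positions of the real (non-pad) tokens in the (batch * seqlen) layout.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen = int(seqlens.max().item())
    # Cumulative lengths with a leading 0: sequence i spans packed[cu[i]:cu[i + 1]].
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    packed = hidden_states.reshape(-1, hidden_states.shape[-1])[indices]
    return packed, indices, cu_seqlens, max_seqlen


def repad(packed: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int):
    """Scatter packed tokens back into a zero-padded (batch, seqlen, dim) tensor."""
    out = packed.new_zeros(batch * seqlen, packed.shape[-1])
    out[indices] = packed
    return out.view(batch, seqlen, -1)
```

Compute scales with the number of real tokens rather than batch × padded length. This is why unpadding pays off on batches with very uneven sentence lengths.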

Usage

Load any original MosaicBERT checkpoint using this repo for the model code. Passing trust_remote_code=True with only the original repo ID is not enough. The config in mosaicml/mosaic-bert-base maps the auto classes to mosaicml's own modeling files, so this call loads the original, unfixed code:

from transformers import AutoModelForMaskedLM

# Weights AND modeling code both come from mosaicml/mosaic-bert-base (original, unfixed code)
model = AutoModelForMaskedLM.from_pretrained(
    "mosaicml/mosaic-bert-base",
    trust_remote_code=True,
)

To run the fixed code, load the config from this repo instead. Its auto_map points at this repo's modeling files. The original checkpoint ID passed to from_pretrained then supplies only the weights.

Simplest pattern: load directly via this repo:

import torch
from transformers import AutoConfig, AutoModelForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
config = AutoConfig.from_pretrained("Taykhoom/MosaicBERT-updated", trust_remote_code=True)

# Load weights from original MosaicBERT, architecture from this repo
model = AutoModelForMaskedLM.from_pretrained(
    "mosaicml/mosaic-bert-base",
    config=config,
    trust_remote_code=True,
)
model.eval()

inputs = tokenizer(["The [MASK] sat on the mat."], return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
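To turn those logits into predictions for each [MASK] position, a small helper like the one below works. It is a hypothetical sketch that only assumes logits of shape (batch, seqlen, vocab) and the input_ids produced by the tokenizer.

```python
import torch


def top_k_mask_predictions(logits, input_ids, mask_token_id, k=5):
    """Return, for every [MASK] position, the k highest-scoring token ids.

    logits: (batch, seqlen, vocab); input_ids: (batch, seqlen).
    Result: list of (batch_index, position, [token ids]) tuples.
    """
    results = []
    batch_idx, pos_idx = torch.nonzero(input_ids == mask_token_id, as_tuple=True)
    for b, p in zip(batch_idx.tolist(), pos_idx.tolist()):
        top_ids = torch.topk(logits[b, p], k).indices.tolist()
        results.append((b, p, top_ids))
    return results
```

Usage with the objects above: `for b, p, ids in top_k_mask_predictions(logits, inputs["input_ids"], tokenizer.mask_token_id): print(p, tokenizer.convert_ids_to_tokens(ids))`. The model's vocabulary is padded to 30,528 while the tokenizer has 30,522 entries. You may therefore want to slice `logits[..., :len(tokenizer)]` first so padding ids never appear among the predictions.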

Attention implementation

# SDPA (default on PyTorch >= 2.0, no extra install needed)
model = AutoModelForMaskedLM.from_pretrained(
    "mosaicml/mosaic-bert-base", config=config, trust_remote_code=True,
    attn_implementation="sdpa",
)

# Flash Attention 2 (requires: pip install flash-attn --no-build-isolation)
model = AutoModelForMaskedLM.from_pretrained(
    "mosaicml/mosaic-bert-base", config=config, trust_remote_code=True,
    attn_implementation="flash_attention_2",
)
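If the same script has to run on machines with and without flash-attn, the implementation can be chosen at runtime. pick_attn_implementation below is a hypothetical helper. It checks only that the flash_attn package is importable and that a CUDA device is present.

```python
import importlib.util

import torch


def pick_attn_implementation() -> str:
    """Hypothetical helper: prefer flash_attention_2 when it can run, else sdpa."""
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    # flash-attn kernels run on CUDA GPUs only.
    if has_flash_attn and torch.cuda.is_available():
        return "flash_attention_2"
    return "sdpa"
```

Pass the result as attn_implementation=pick_attn_implementation(). Afterwards, model.config._attn_implementation shows which path was actually selected, which is the attribute fix 1 relies on. flash-attn only accepts fp16/bf16 inputs, so pair flash_attention_2 with torch_dtype=torch.bfloat16.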

Sequence length extrapolation via ALiBi

ALiBi has no hard sequence length limit. To run on longer sequences, increase alibi_starting_size (pre-allocates the bias matrix):

config = AutoConfig.from_pretrained("Taykhoom/MosaicBERT-updated", trust_remote_code=True)
config.alibi_starting_size = 2048

model = AutoModelForMaskedLM.from_pretrained(
    "mosaicml/mosaic-bert-base", config=config, trust_remote_code=True,
)
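The sketch below shows why alibi_slopes and alibi_starting_size behave this way. It uses the standard ALiBi slope recipe, including the non-power-of-2 case needed for 12 heads. It assumes a symmetric -slope·|i − j| penalty, which suits a bidirectional encoder. The function names are illustrative, not this repo's API.

```python
import math

import torch


def alibi_slopes(n_heads: int) -> list[float]:
    """Standard ALiBi head slopes, including the non-power-of-2 case (e.g. 12 heads)."""

    def power_of_2_slopes(n: int) -> list[float]:
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * start**i for i in range(n)]

    if math.log2(n_heads).is_integer():
        return power_of_2_slopes(n_heads)
    # Take the closest power of 2, then fill the rest with every other slope of 2x that size.
    closest = 2 ** math.floor(math.log2(n_heads))
    extra = power_of_2_slopes(2 * closest)[0::2][: n_heads - closest]
    return power_of_2_slopes(closest) + extra


def alibi_bias(n_heads: int, size: int) -> torch.Tensor:
    """(n_heads, size, size) additive bias: -slope * |i - j| (symmetric)."""
    pos = torch.arange(size)
    distance = (pos[None, :] - pos[:, None]).abs()
    slopes = torch.tensor(alibi_slopes(n_heads))
    return -slopes[:, None, None] * distance[None, :, :]
```

In this construction, the top-left L × L block of a bias built at size N equals the bias built at size L. A larger pre-allocated matrix can therefore serve every shorter sequence by slicing, and nothing in the formula caps the length.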

Original MosaicBERT checkpoints

Checkpoint                     Pretraining length   Weights
mosaic-bert-base               128 tokens           HF Hub
mosaic-bert-base-seqlen-256    256 tokens           HF Hub
mosaic-bert-base-seqlen-512    512 tokens           HF Hub
mosaic-bert-base-seqlen-1024   1024 tokens          HF Hub
mosaic-bert-base-seqlen-2048   2048 tokens          HF Hub

Citation

@article{portes2023_mosaicbert,
  title   = {MosaicBERT: A Bidirectional Encoder Optimized for Fast Pretraining},
  author  = {Portes, Jacob and Trott, Alexander R. and Havens, Sam and King, Daniel and Venigalla, Abhinav and Nadeem, Moin and Sardana, Nikhil and Khudia, Daya and Frankle, Jonathan},
  journal = {Advances in Neural Information Processing Systems},
  volume  = {36},
  year    = {2023}
}

Credits

Original MosaicBERT architecture and weights by MosaicML (now Databricks). Source: GitHub. This updated implementation was authored primarily by Claude Code and reviewed manually by Taykhoom Dalal.

License

Apache 2.0, following the original repository.
