Daankular/models / Wan2GP /preprocessing /sam3 /model_builder.py
Daankular's picture
download
raw
49.7 kB
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import os
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from iopath.common.file_io import g_pathmgr
from mmgp import offload
from .model.decoder import (
DecoupledTransformerDecoderLayerv2,
SimpleRoPEAttention,
TransformerDecoder,
TransformerDecoderLayer,
TransformerDecoderLayerv2,
TransformerEncoderCrossAttention,
TransformerEncoderDecoupledCrossAttention,
)
from .model.encoder import TransformerEncoderFusion, TransformerEncoderLayer
from .model.geometry_encoders import SequenceGeometryEncoder
from .model.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead
from .model.memory import (
CXBlock,
SimpleFuser,
SimpleMaskDownSampler,
SimpleMaskEncoder,
)
from .model.model_misc import (
DotProductScoring,
MLP,
MultiheadAttentionWrapper as MultiheadAttention,
TransformerWrapper,
)
from .model.multiplex_utils import MultiplexController
from .model.necks import Sam3DualViTDetNeck, Sam3TriViTDetNeck
from .model.position_encoding import PositionEmbeddingSine
from .model.sam1_task_predictor import SAM3InteractiveImagePredictor
from .model.sam3_image import Sam3Image, Sam3ImageOnVideoMultiGPU
from .model.sam3_tracking_predictor import Sam3TrackerPredictor
from .model.sam3_video_inference import Sam3VideoInferenceWithInstanceInteractivity
from .model.sam3_video_predictor import Sam3VideoPredictorMultiGPU
from .model.text_encoder_ve import VETextEncoder
from .model.tokenizer_ve import SimpleTokenizer
from .model.video_tracking_multiplex import VideoTrackingDynamicMultiplex
from .model.vitdet import ViT
from .model.vl_combiner import SAM3VLBackbone, SAM3VLBackboneTri, TriHeadVisionOnly
from .model.device_utils import get_accelerator_device
from .sam.transformer import RoPEAttention
# Setup TensorFloat-32 for Ampere GPUs if available
def _setup_tf32() -> None:
"""Enable TensorFloat-32 for Ampere GPUs if available."""
if torch.cuda.is_available():
device_props = torch.cuda.get_device_properties(0)
if device_props.major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
class _PackageResources:
@staticmethod
def resource_filename(package, resource):
return os.fspath(Path(__file__).resolve().parent / resource)
pkg_resources = _PackageResources()
@contextmanager
def _default_dtype(dtype: torch.dtype):
previous_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
try:
yield
finally:
torch.set_default_dtype(previous_dtype)
def _device_context(device):
return torch.device(device) if device is not None else nullcontext()
def _remap_checkpoint_key(key: str) -> str:
if key.startswith("sam3_model."):
return "detector." + key[len("sam3_model.") :]
if key.startswith("sam2_predictor."):
return "tracker." + key[len("sam2_predictor.") :]
return key
def _keep_checkpoint_key(key: str, include_prefixes=None, exclude_prefixes=None) -> bool:
if include_prefixes is not None and not key.startswith(tuple(include_prefixes)):
return False
if exclude_prefixes is not None and key.startswith(tuple(exclude_prefixes)):
return False
return True
def _preprocess_sam3_state_dict(
    model: nn.Module,
    include_prefixes=None,
    exclude_prefixes=None,
    strip_prefix=None,
):
    """Build a state-dict preprocessing callback bound to ``model``.

    The returned callable:
      1. unwraps a nested ``"model"`` dict when present,
      2. renames keys with ``_remap_checkpoint_key``,
      3. drops keys rejected by ``_keep_checkpoint_key``,
      4. removes ``strip_prefix`` from the front of the keys that carry it,
      5. adds every non-meta tensor already held by ``model.state_dict()``
         whose key is still absent from the result.

    Args:
        model: Module whose existing non-meta tensors are merged in.
        include_prefixes: Forwarded to ``_keep_checkpoint_key``.
        exclude_prefixes: Forwarded to ``_keep_checkpoint_key``.
        strip_prefix: Optional prefix removed from kept keys.

    Returns:
        A function ``preprocess(state_dict, quantization_map=None,
        tied_weights_map=None)`` that returns the filtered dict.
    """

    def preprocess(state_dict, quantization_map=None, tied_weights_map=None):
        # The two map arguments are accepted but not used here.
        if "model" in state_dict and isinstance(state_dict["model"], dict):
            state_dict = state_dict["model"]

        selected = {}
        for raw_key, tensor in state_dict.items():
            mapped_key = _remap_checkpoint_key(raw_key)
            if _keep_checkpoint_key(mapped_key, include_prefixes, exclude_prefixes):
                if strip_prefix is not None and mapped_key.startswith(strip_prefix):
                    mapped_key = mapped_key[len(strip_prefix):]
                selected[mapped_key] = tensor

        # Keep whatever the model has already materialised (non-meta) so those
        # entries are part of the returned dict.
        for name, current in model.state_dict().items():
            if name in selected:
                continue
            if torch.is_tensor(current) and not current.is_meta:
                selected[name] = current
        return selected

    return preprocess
def _load_sam3_safetensors_model(
model: nn.Module,
checkpoint_path: str,
*,
include_prefixes=None,
exclude_prefixes=None,
strip_prefix=None,
) -> None:
if checkpoint_path is None:
raise FileNotFoundError("SAM3.1 requires a bf16 safetensors checkpoint.")
if Path(checkpoint_path).suffix.lower() != ".safetensors":
raise ValueError(f"SAM3.1 checkpoints must be bf16 safetensors files, got: {checkpoint_path}")
offload.load_model_data(
model,
checkpoint_path,
writable_tensors=False,
preprocess_sd=_preprocess_sam3_state_dict(
model, include_prefixes, exclude_prefixes, strip_prefix
),
default_dtype=torch.bfloat16,
)
# Rebuild the RoPE frequency buffers (``freqs_cis_real`` / ``freqs_cis_imag``)
# of every submodule of ``module`` whose ``freqs_cis_real`` is a meta tensor.
#
# Args:
#   module: root module; each entry of ``module.modules()`` is inspected.
#   device: device handed to ``.to(...)`` / ``compute_cis(...)``.
#   dtype: dtype the refreshed buffers are cast to.
#
# Two kinds of submodule are handled:
#   1. ``use_rope`` and ``use_rope_real`` both truthy and a ``_setup_rope_freqs``
#      method is present: the submodule recomputes its own buffers.
#   2. otherwise, ``use_rope_real`` truthy with ``compute_cis`` and
#      ``freqs_cis_real`` attributes: buffers are recomputed via ``compute_cis``.
#
# NOTE(review): indentation was lost in this copy of the file, so it cannot be
# told from here whether the two ``.to(...)`` assignments of case 1 run only
# when the buffer is meta or for every matching submodule — confirm against the
# original source before restructuring.
def _refresh_rope_buffers_(module: nn.Module, device: torch.device, dtype: torch.dtype) -> None:
for submodule in module.modules():
# Case 1: the submodule knows how to rebuild its own frequencies.
if getattr(submodule, "use_rope", False) and getattr(submodule, "use_rope_real", False) and hasattr(submodule, "_setup_rope_freqs"):
freqs = getattr(submodule, "freqs_cis_real", None)
if torch.is_tensor(freqs) and freqs.is_meta:
submodule._setup_rope_freqs()
submodule.freqs_cis_real = submodule.freqs_cis_real.to(device=device, dtype=dtype)
submodule.freqs_cis_imag = submodule.freqs_cis_imag.to(device=device, dtype=dtype)
# Case 2: frequencies are produced by ``compute_cis``.
elif getattr(submodule, "use_rope_real", False) and hasattr(submodule, "compute_cis") and hasattr(submodule, "freqs_cis_real"):
freqs = getattr(submodule, "freqs_cis_real", None)
# Buffers that are missing or already materialised are left alone.
if not torch.is_tensor(freqs) or not freqs.is_meta:
continue
# Side length is the truncated square root of the first dimension;
# presumably the buffer has one row per cell of a square grid — TODO confirm.
side = int(freqs.shape[0] ** 0.5)
real, imag = submodule.compute_cis(end_x=side, end_y=side, device=device)
submodule.freqs_cis_real = real.to(dtype=dtype)
submodule.freqs_cis_imag = imag.to(dtype=dtype)
def _refresh_text_encoder_buffers_(module: nn.Module, device: torch.device, dtype: torch.dtype) -> None:
for submodule in module.modules():
attn_mask = getattr(submodule, "attn_mask", None)
if torch.is_tensor(attn_mask) and attn_mask.is_meta and hasattr(submodule, "build_causal_mask"):
submodule.attn_mask = submodule.build_causal_mask().to(device=device, dtype=dtype)
def _create_position_encoding(precompute_resolution=None):
    """Build the ``PositionEmbeddingSine`` used with the visual backbone.

    Args:
        precompute_resolution: Forwarded unchanged to ``PositionEmbeddingSine``.

    Returns:
        A ``PositionEmbeddingSine`` with 256 positional features,
        normalisation enabled, no explicit scale and temperature 10000.
    """
    settings = dict(
        num_pos_feats=256,
        normalize=True,
        scale=None,
        temperature=10000,
        precompute_resolution=precompute_resolution,
    )
    return PositionEmbeddingSine(**settings)
def _create_vit_backbone(compile_mode=None, use_fa3=False, use_rope_real=True):
    """Instantiate the ``ViT`` trunk with the fixed SAM3 hyper-parameters.

    Args:
        compile_mode: Forwarded to ``ViT`` as ``compile_mode``.
        use_fa3: Forwarded to ``ViT`` as ``use_fa3``.
        use_rope_real: Forwarded to ``ViT`` as ``use_rope_real``.

    Returns:
        The constructed ``ViT`` instance.
    """
    # The fixed configuration is grouped only for readability; every entry is
    # passed to ``ViT`` as a keyword argument.
    size_config = dict(
        img_size=1008,
        pretrain_img_size=336,
        patch_size=14,
        embed_dim=1024,
        depth=32,
        num_heads=16,
        mlp_ratio=4.625,
        window_size=24,
        global_att_blocks=(7, 15, 23, 31),
        rel_pos_blocks=(),
    )
    position_config = dict(
        use_abs_pos=True,
        tile_abs_pos=True,
        use_rope=True,
        use_interp_rope=True,
        pretrain_use_cls_token=True,
        retain_cls_token=False,
    )
    layer_config = dict(
        norm_layer="LayerNorm",
        drop_path_rate=0.1,
        qkv_bias=True,
        ln_pre=True,
        ln_post=False,
        return_interm_layers=False,
        bias_patch_embed=False,
    )
    runtime_config = dict(
        compile_mode=compile_mode,
        use_fa3=use_fa3,
        use_rope_real=use_rope_real,
    )
    return ViT(**size_config, **position_config, **layer_config, **runtime_config)
def _create_vit_neck(position_encoding, vit_backbone, enable_inst_interactivity=False):
    """Wrap a ViT trunk in a ``Sam3DualViTDetNeck``.

    Args:
        position_encoding: Position-encoding module handed to the neck.
        vit_backbone: Trunk passed to the neck as ``trunk``.
        enable_inst_interactivity: Passed to the neck as ``add_sam2_neck``.

    Returns:
        The constructed ``Sam3DualViTDetNeck``.
    """
    neck = Sam3DualViTDetNeck(
        trunk=vit_backbone,
        position_encoding=position_encoding,
        add_sam2_neck=enable_inst_interactivity,
        d_model=256,
        scale_factors=[4.0, 2.0, 1.0, 0.5],
    )
    return neck
def _create_vl_backbone(vit_neck, text_encoder):
    """Combine a visual neck and a text encoder into a ``SAM3VLBackbone``.

    Args:
        vit_neck: Visual module passed as ``visual``.
        text_encoder: Text module passed as ``text``.

    Returns:
        A ``SAM3VLBackbone`` built with ``scalp=1``.
    """
    vl_backbone = SAM3VLBackbone(scalp=1, visual=vit_neck, text=text_encoder)
    return vl_backbone
def _create_transformer_encoder(use_fa3=False) -> TransformerEncoderFusion:
    """Build the six-layer ``TransformerEncoderFusion`` and its layer template.

    Args:
        use_fa3: Forwarded to both attention modules as ``use_fa3``.

    Returns:
        The constructed ``TransformerEncoderFusion``.
    """
    # Self- and cross-attention share one configuration; self-attention is
    # constructed first, as in the layer's argument order.
    attention_config = dict(
        num_heads=8,
        dropout=0.1,
        embed_dim=256,
        batch_first=True,
        use_fa3=use_fa3,
    )
    self_attn = MultiheadAttention(**attention_config)
    cross_attn = MultiheadAttention(**attention_config)

    layer_template = TransformerEncoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        pre_norm=True,
        pos_enc_at_attn=True,
        pos_enc_at_cross_attn_keys=False,
        pos_enc_at_cross_attn_queries=False,
        self_attention=self_attn,
        cross_attention=cross_attn,
    )
    return TransformerEncoderFusion(
        layer=layer_template,
        num_layers=6,
        d_model=256,
        num_feature_levels=1,
        frozen=False,
        use_act_checkpoint=True,
        add_pooled_text_to_img_feat=False,
        pool_text_with_mask=True,
    )
def _create_transformer_decoder(use_fa3=False) -> TransformerDecoder:
    """Build the six-layer ``TransformerDecoder`` and its layer template.

    Args:
        use_fa3: Forwarded to the layer's cross-attention as ``use_fa3``.

    Returns:
        The constructed ``TransformerDecoder`` (200 queries, box refinement
        and presence token enabled).
    """
    cross_attn = MultiheadAttention(
        num_heads=8,
        dropout=0.1,
        embed_dim=256,
        use_fa3=use_fa3,
    )
    layer_template = TransformerDecoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        cross_attention=cross_attn,
        n_heads=8,
        use_text_cross_attention=True,
    )
    decoder_config = dict(
        num_layers=6,
        num_queries=200,
        num_o2m_queries=0,
        d_model=256,
        resolution=1008,
        stride=14,
        boxRPB="log",
        interaction_layer=None,
        return_intermediate=True,
        box_refine=True,
        dac=True,
        dac_use_selfatt_ln=True,
        frozen=False,
        use_act_checkpoint=True,
        presence_token=True,
    )
    return TransformerDecoder(layer=layer_template, **decoder_config)
def _create_dot_product_scoring():
    """Build the ``DotProductScoring`` module together with its prompt MLP.

    Returns:
        A ``DotProductScoring`` with ``d_model=256`` and ``d_proj=256`` whose
        ``prompt_mlp`` is a two-layer residual ``MLP`` ending in a
        ``LayerNorm(256)``.
    """
    output_norm = nn.LayerNorm(256)
    scoring_mlp = MLP(
        input_dim=256,
        hidden_dim=2048,
        output_dim=256,
        num_layers=2,
        dropout=0.1,
        residual=True,
        out_norm=output_norm,
    )
    scorer = DotProductScoring(d_model=256, d_proj=256, prompt_mlp=scoring_mlp)
    return scorer
def _create_segmentation_head(compile_mode=None, use_fa3=False):
    """Build the ``UniversalSegmentationHead`` with its pixel decoder.

    Args:
        compile_mode: Forwarded to ``PixelDecoder`` as ``compile_mode``.
        use_fa3: Forwarded to the prompt cross-attention as ``use_fa3``.

    Returns:
        The constructed ``UniversalSegmentationHead``.
    """
    # The pixel decoder is created before the attention module.
    decoder = PixelDecoder(
        hidden_dim=256,
        num_upsampling_stages=3,
        interpolation_mode="nearest",
        compile_mode=compile_mode,
    )
    prompt_attention = MultiheadAttention(
        embed_dim=256,
        num_heads=8,
        dropout=0,
        use_fa3=use_fa3,
    )
    return UniversalSegmentationHead(
        hidden_dim=256,
        upsampling_stages=3,
        aux_masks=False,
        presence_head=False,
        dot_product_scorer=None,
        act_ckpt=True,
        cross_attend_prompt=prompt_attention,
        pixel_decoder=decoder,
    )
def _create_geometry_encoder():
    """Build the ``SequenceGeometryEncoder`` and the modules it is given.

    Returns:
        A three-layer ``SequenceGeometryEncoder`` using a default
        ``_create_position_encoding()`` and a pre-norm
        ``TransformerEncoderLayer`` template.
    """
    position_encoding = _create_position_encoding()

    # NOTE: this CXBlock is constructed but never handed to any other module
    # in this function; it is kept so construction order stays unchanged.
    cx_block = CXBlock(
        dim=256,
        kernel_size=7,
        padding=3,
        layer_scale_init_value=1.0e-06,
        use_dwconv=True,
    )

    # Both attention modules use the same settings; self-attention first.
    attention_config = dict(num_heads=8, dropout=0.1, embed_dim=256, batch_first=False)
    self_attn = MultiheadAttention(**attention_config)
    cross_attn = MultiheadAttention(**attention_config)

    layer_template = TransformerEncoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        pre_norm=True,
        pos_enc_at_attn=False,
        pos_enc_at_cross_attn_queries=False,
        pos_enc_at_cross_attn_keys=True,
        self_attention=self_attn,
        cross_attention=cross_attn,
    )

    point_and_box_config = dict(
        encode_boxes_as_points=False,
        points_direct_project=True,
        points_pool=True,
        points_pos_enc=True,
        boxes_direct_project=True,
        boxes_pool=True,
        boxes_pos_enc=True,
    )
    return SequenceGeometryEncoder(
        pos_enc=position_encoding,
        d_model=256,
        num_layers=3,
        layer=layer_template,
        use_act_ckpt=True,
        add_cls=True,
        add_post_encode_proj=True,
        **point_and_box_config,
    )
def _create_sam3_model(
    backbone,
    transformer,
    input_geometry_encoder,
    segmentation_head,
    dot_prod_scoring,
    inst_interactive_predictor,
    eval_mode,
):
    """Assemble a ``Sam3Image`` from already-built components.

    Args:
        backbone: Passed to ``Sam3Image`` as ``backbone``.
        transformer: Passed as ``transformer``.
        input_geometry_encoder: Passed as ``input_geometry_encoder``.
        segmentation_head: Passed as ``segmentation_head``.
        dot_prod_scoring: Passed as ``dot_prod_scoring``.
        inst_interactive_predictor: Passed as ``inst_interactive_predictor``.
        eval_mode: When falsy, a ``BinaryHungarianMatcherV2`` is imported and
            built; otherwise the ``matcher`` argument is ``None``.

    Returns:
        The constructed ``Sam3Image``.
    """
    if eval_mode:
        matcher = None
    else:
        # Imported lazily: the training package is only touched when needed.
        from .train.matcher import BinaryHungarianMatcherV2

        matcher = BinaryHungarianMatcherV2(
            focal=True,
            cost_class=2.0,
            cost_bbox=5.0,
            cost_giou=2.0,
            alpha=0.25,
            gamma=2,
            stable=False,
        )

    return Sam3Image(
        backbone=backbone,
        transformer=transformer,
        input_geometry_encoder=input_geometry_encoder,
        segmentation_head=segmentation_head,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=dot_prod_scoring,
        use_instance_query=False,
        multimask_output=True,
        inst_interactive_predictor=inst_interactive_predictor,
        matcher=matcher,
    )
def _create_tracker_maskmem_backbone():
    """Build the ``SimpleMaskEncoder`` used as the tracker's memory encoder.

    Returns:
        A ``SimpleMaskEncoder`` with ``out_dim=64`` composed of a sine
        position encoding, a mask down-sampler and a two-layer ``SimpleFuser``.
    """
    # Components are created in the order: position encoding, down-sampler,
    # CXBlock, fuser.
    sine_encoding = PositionEmbeddingSine(
        num_pos_feats=64,
        normalize=True,
        scale=None,
        temperature=10000,
        precompute_resolution=1008,
    )
    downsampler = SimpleMaskDownSampler(
        kernel_size=3,
        stride=2,
        padding=1,
        interpol_size=[1152, 1152],
    )
    fuser_layer = CXBlock(
        dim=256,
        kernel_size=7,
        padding=3,
        layer_scale_init_value=1.0e-06,
        use_dwconv=True,
    )
    mask_fuser = SimpleFuser(layer=fuser_layer, num_layers=2)
    return SimpleMaskEncoder(
        out_dim=64,
        position_encoding=sine_encoding,
        mask_downsampler=downsampler,
        fuser=mask_fuser,
    )
def _create_tracker_transformer():
    """Build the tracker's ``TransformerWrapper`` (encoder only, no decoder).

    Returns:
        A ``TransformerWrapper`` whose encoder is a four-layer
        ``TransformerEncoderCrossAttention`` built from RoPE self- and
        cross-attention; ``decoder`` is ``None``.
    """
    # Scalar settings common to both RoPE attention modules. ``feat_sizes`` is
    # spelled out per call so each module receives its own list object.
    rope_common = dict(
        embedding_dim=256,
        num_heads=1,
        downsample_rate=1,
        dropout=0.1,
        rope_theta=10000.0,
        use_fa3=False,
        use_rope_real=True,
    )
    rope_self_attn = RoPEAttention(feat_sizes=[72, 72], **rope_common)
    rope_cross_attn = RoPEAttention(
        feat_sizes=[72, 72],
        kv_in_dim=64,
        rope_k_repeat=True,
        **rope_common,
    )

    layer_template = TransformerDecoderLayerv2(
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        pre_norm=True,
        cross_attention_first=False,
        pos_enc_at_attn=False,
        pos_enc_at_cross_attn_keys=True,
        pos_enc_at_cross_attn_queries=False,
        self_attention=rope_self_attn,
        cross_attention=rope_cross_attn,
    )
    memory_encoder = TransformerEncoderCrossAttention(
        layer=layer_template,
        num_layers=4,
        d_model=256,
        batch_first=True,
        frozen=False,
        pos_enc_at_input=True,
        remove_cross_attention_layers=[],
        use_act_checkpoint=False,
    )
    return TransformerWrapper(encoder=memory_encoder, decoder=None, d_model=256)
def build_tracker(
    apply_temporal_disambiguation: bool, with_backbone: bool = False, compile_mode=None
) -> Sam3TrackerPredictor:
    """Build the SAM3 tracker module used for video tracking.

    Args:
        apply_temporal_disambiguation: Passed to the predictor as
            ``use_memory_selection``.
        with_backbone: When true, a vision backbone is created and wrapped in
            a text-less ``SAM3VLBackbone``; otherwise ``backbone`` is ``None``.
        compile_mode: Forwarded to ``_create_vision_backbone`` when a backbone
            is requested.

    Returns:
        Sam3TrackerPredictor: The constructed tracker.
    """
    memory_encoder = _create_tracker_maskmem_backbone()
    memory_transformer = _create_tracker_transformer()

    tracker_backbone = None
    if with_backbone:
        visual = _create_vision_backbone(compile_mode=compile_mode)
        tracker_backbone = SAM3VLBackbone(scalp=1, visual=visual, text=None)

    multimask_config = dict(
        multimask_output_in_sam=True,
        multimask_output_for_tracking=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
    )
    eval_config = dict(
        forward_backbone_per_frame_for_eval=True,
        trim_past_non_cond_mem_for_eval=False,
        offload_output_to_cpu_for_eval=False,
    )
    memory_config = dict(
        always_start_from_first_ann_frame=False,
        non_overlap_masks_for_mem_enc=False,
        non_overlap_masks_for_output=False,
        max_cond_frames_in_attn=4,
        clear_non_cond_mem_around_input=True,
        fill_hole_area=0,
        use_memory_selection=apply_temporal_disambiguation,
    )
    decoder_extra_args = {
        "dynamic_multimask_via_stability": True,
        "dynamic_multimask_stability_delta": 0.05,
        "dynamic_multimask_stability_thresh": 0.98,
    }
    return Sam3TrackerPredictor(
        image_size=1008,
        num_maskmem=7,
        backbone=tracker_backbone,
        backbone_stride=14,
        transformer=memory_transformer,
        maskmem_backbone=memory_encoder,
        sam_mask_decoder_extra_args=decoder_extra_args,
        **multimask_config,
        **eval_config,
        **memory_config,
    )
def _create_text_encoder(bpe_path: str, init_device="cpu") -> VETextEncoder:
    """Build the ``VETextEncoder`` together with its ``SimpleTokenizer``.

    Args:
        bpe_path: Path handed to ``SimpleTokenizer`` as ``bpe_path``.
        init_device: Device under which construction happens; ``None``
            disables the device context (see ``_device_context``).

    Returns:
        The constructed ``VETextEncoder``.
    """
    encoder_config = dict(d_model=256, width=1024, heads=16, layers=24)
    with _device_context(init_device):
        bpe_tokenizer = SimpleTokenizer(bpe_path=bpe_path)
        return VETextEncoder(tokenizer=bpe_tokenizer, **encoder_config)
def _create_vision_backbone(
    compile_mode=None, enable_inst_interactivity=True
) -> Sam3DualViTDetNeck:
    """Build the visual backbone: a ViT trunk wrapped in its neck.

    Args:
        compile_mode: Forwarded to ``_create_vit_backbone``.
        enable_inst_interactivity: Forwarded to ``_create_vit_neck``.

    Returns:
        The ``Sam3DualViTDetNeck`` that wraps the trunk.
    """
    # Position encoding first, then the trunk, then the neck around both.
    sine_encoding = _create_position_encoding(precompute_resolution=1008)
    trunk = _create_vit_backbone(compile_mode=compile_mode)
    return _create_vit_neck(
        sine_encoding,
        trunk,
        enable_inst_interactivity=enable_inst_interactivity,
    )
def _create_sam3_transformer(
    has_presence_token: bool = True, use_fa3: bool = False
) -> TransformerWrapper:
    """Build the SAM3 encoder/decoder pair inside a ``TransformerWrapper``.

    Args:
        has_presence_token: Accepted for interface compatibility; it is not
            forwarded to the encoder or decoder factories here.
        use_fa3: Forwarded to both factories as ``use_fa3``.

    Returns:
        A ``TransformerWrapper`` with ``d_model=256``.
    """
    fusion_encoder = _create_transformer_encoder(use_fa3=use_fa3)
    query_decoder = _create_transformer_decoder(use_fa3=use_fa3)
    wrapper = TransformerWrapper(encoder=fusion_encoder, decoder=query_decoder, d_model=256)
    return wrapper
def _load_checkpoint(model, checkpoint_path):
    """Load a SAM3 checkpoint file into an image model (non-strict).

    The checkpoint is read with ``torch.load(..., weights_only=True)``. A
    nested ``"model"`` dict is unwrapped when present. Keys containing
    ``detector`` are kept with ``"detector."`` removed; when the model has an
    ``inst_interactive_predictor``, keys containing ``tracker`` are added
    with ``"tracker."`` rewritten to ``"inst_interactive_predictor.model."``.

    Args:
        model: Module exposing ``inst_interactive_predictor`` and
            ``load_state_dict``.
        checkpoint_path: Path opened through ``g_pathmgr``.

    Side effects:
        Prints a report when the load leaves missing or unexpected keys.
    """
    with g_pathmgr.open(checkpoint_path, "rb") as f:
        ckpt = torch.load(f, map_location="cpu", weights_only=True)
    if "model" in ckpt and isinstance(ckpt["model"], dict):
        ckpt = ckpt["model"]

    sam3_image_ckpt = {
        k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k
    }
    if model.inst_interactive_predictor is not None:
        sam3_image_ckpt.update(
            {
                k.replace("tracker.", "inst_interactive_predictor.model."): v
                for k, v in ckpt.items()
                if "tracker" in k
            }
        )

    # Loading is non-strict, so both kinds of mismatch can occur. The message
    # promises "missing and/or unexpected keys", so report both lists instead
    # of discarding the unexpected ones.
    missing_keys, unexpected_keys = model.load_state_dict(sam3_image_ckpt, strict=False)
    if missing_keys or unexpected_keys:
        print(
            f"loaded {checkpoint_path} and found "
            f"missing and/or unexpected keys:\n{missing_keys=}\n{unexpected_keys=}"
        )
def _setup_device_and_mode(model, device, eval_mode):
"""Setup model device and evaluation mode."""
device = torch.device(device)
if device.type != "cpu":
model = model.to(device=device)
if eval_mode:
model.eval()
return model
def build_sam3_image_model(
    bpe_path=None,
    device=None,
    eval_mode=True,
    checkpoint_path=None,
    load_from_HF=True,
    enable_segmentation=True,
    enable_inst_interactivity=False,
    compile=False,
):
    """Build the SAM3 image model.

    Args:
        bpe_path: Path to the BPE tokenizer vocabulary; defaults to the
            bundled ``assets/bpe_simple_vocab_16e6.txt.gz``.
        device: Device to place the model on; defaults to
            ``get_accelerator_device()``.
        eval_mode: Whether to put the model in evaluation mode. When false,
            a matcher is also built (see ``_create_sam3_model``).
        checkpoint_path: Optional path to a model checkpoint.
        load_from_HF: When true and ``checkpoint_path`` is ``None``, the
            checkpoint is fetched with ``download_ckpt_from_hf``.
        enable_segmentation: Whether to build the segmentation head.
        enable_inst_interactivity: Whether to build the instance-interactive
            predictor (SAM 1 task).
        compile: When true, components are built with compile mode
            ``"default"``; otherwise with ``None``.

    Returns:
        The SAM3 image model.
    """
    _setup_tf32()
    if device is None:
        device = get_accelerator_device()
    if bpe_path is None:
        bpe_path = pkg_resources.resource_filename(
            "sam3", "assets/bpe_simple_vocab_16e6.txt.gz"
        )

    compile_mode = None
    if compile:
        compile_mode = "default"

    # Components are created in a fixed order: vision, text, combined
    # backbone, transformer, scoring, segmentation head, geometry encoder.
    vision_encoder = _create_vision_backbone(
        compile_mode=compile_mode, enable_inst_interactivity=enable_inst_interactivity
    )
    text_encoder = _create_text_encoder(bpe_path)
    backbone = _create_vl_backbone(vision_encoder, text_encoder)
    transformer = _create_sam3_transformer()
    dot_prod_scoring = _create_dot_product_scoring()

    segmentation_head = None
    if enable_segmentation:
        segmentation_head = _create_segmentation_head(compile_mode=compile_mode)

    input_geometry_encoder = _create_geometry_encoder()

    inst_predictor = None
    if enable_inst_interactivity:
        tracker_base = build_tracker(apply_temporal_disambiguation=False)
        inst_predictor = SAM3InteractiveImagePredictor(tracker_base)

    model = _create_sam3_model(
        backbone=backbone,
        transformer=transformer,
        input_geometry_encoder=input_geometry_encoder,
        segmentation_head=segmentation_head,
        dot_prod_scoring=dot_prod_scoring,
        inst_interactive_predictor=inst_predictor,
        eval_mode=eval_mode,
    )

    # Resolve the checkpoint (download when allowed), then load it if any.
    if checkpoint_path is None and load_from_HF:
        checkpoint_path = download_ckpt_from_hf(version="sam3")
    if checkpoint_path is not None:
        _load_checkpoint(model, checkpoint_path)

    return _setup_device_and_mode(model, device, eval_mode)
def download_ckpt_from_hf(version="sam3"):
    """Download the SAM3 checkpoint from the HuggingFace Hub.

    Args:
        version: ``"sam3.1"`` is rejected; every other value downloads
            ``sam3.pt`` from the ``facebook/sam3`` repository.

    Returns:
        Local path of the downloaded checkpoint file.

    Raises:
        FileNotFoundError: ``version`` is ``"sam3.1"``.
    """
    if version == "sam3.1":
        raise FileNotFoundError("SAM3.1 uses the local sam3.1_multiplex_bf16.safetensors checkpoint.")

    hub_repo = "facebook/sam3"
    # The config file is fetched as well; its path is not used.
    hf_hub_download(repo_id=hub_repo, filename="config.json")
    return hf_hub_download(repo_id=hub_repo, filename="sam3.pt")
def build_sam3_video_model(
    checkpoint_path: Optional[str] = None,
    load_from_HF=True,
    bpe_path: Optional[str] = None,
    has_presence_token: bool = True,
    geo_encoder_use_img_cross_attn: bool = True,
    strict_state_dict_loading: bool = True,
    apply_temporal_disambiguation: bool = True,
    device=None,
    compile=False,
) -> Sam3VideoInferenceWithInstanceInteractivity:
    """Build the SAM3 dense tracking (video) model.

    Args:
        checkpoint_path: Optional path to a checkpoint file.
        load_from_HF: When true and ``checkpoint_path`` is ``None``, the
            checkpoint is fetched with ``download_ckpt_from_hf``.
        bpe_path: Path to the BPE tokenizer file; defaults to the bundled
            ``assets/bpe_simple_vocab_16e6.txt.gz``.
        has_presence_token: Forwarded to ``_create_sam3_transformer`` and used
            as the detector's ``supervise_joint_box_scores``.
        geo_encoder_use_img_cross_attn: Accepted but not referenced in this
            function.
        strict_state_dict_loading: ``strict`` flag for ``load_state_dict``.
        apply_temporal_disambiguation: Selects the hot-start/reconditioning
            settings and is forwarded to ``build_tracker``.
        device: Target device; defaults to ``get_accelerator_device()``.
        compile: Forwarded to the video model as ``compile_model``.

    Returns:
        Sam3VideoInferenceWithInstanceInteractivity: The instantiated model.
    """
    _setup_tf32()
    if device is None:
        device = get_accelerator_device()
    if bpe_path is None:
        bpe_path = pkg_resources.resource_filename(
            "sam3", "assets/bpe_simple_vocab_16e6.txt.gz"
        )

    # Tracker first, then the detector components in their fixed order.
    tracker = build_tracker(apply_temporal_disambiguation=apply_temporal_disambiguation)

    visual_neck = _create_vision_backbone()
    text_encoder = _create_text_encoder(bpe_path)
    backbone = SAM3VLBackbone(scalp=1, visual=visual_neck, text=text_encoder)
    transformer = _create_sam3_transformer(has_presence_token=has_presence_token)
    segmentation_head = _create_segmentation_head()
    input_geometry_encoder = _create_geometry_encoder()

    scoring_norm = nn.LayerNorm(256)
    scoring_mlp = MLP(
        input_dim=256,
        hidden_dim=2048,
        output_dim=256,
        num_layers=2,
        dropout=0.1,
        residual=True,
        out_norm=scoring_norm,
    )
    main_dot_prod_scoring = DotProductScoring(
        d_model=256, d_proj=256, prompt_mlp=scoring_mlp
    )

    detector = Sam3ImageOnVideoMultiGPU(
        num_feature_levels=1,
        backbone=backbone,
        transformer=transformer,
        segmentation_head=segmentation_head,
        semantic_segmentation_head=None,
        input_geometry_encoder=input_geometry_encoder,
        use_early_fusion=True,
        use_dot_prod_scoring=True,
        dot_prod_scoring=main_dot_prod_scoring,
        supervise_joint_box_scores=has_presence_token,
    )

    # Only four settings differ between the two configurations; the second is
    # a version without any heuristics for ablation studies.
    if apply_temporal_disambiguation:
        heuristic_config = dict(
            hotstart_delay=15,
            hotstart_unmatch_thresh=8,
            hotstart_dup_thresh=8,
            recondition_every_nth_frame=16,
        )
    else:
        heuristic_config = dict(
            hotstart_delay=0,
            hotstart_unmatch_thresh=0,
            hotstart_dup_thresh=0,
            recondition_every_nth_frame=0,
        )

    model = Sam3VideoInferenceWithInstanceInteractivity(
        detector=detector,
        tracker=tracker,
        score_threshold_detection=0.5,
        assoc_iou_thresh=0.1,
        det_nms_thresh=0.1,
        new_det_thresh=0.7,
        suppress_unmatched_only_within_hotstart=True,
        min_trk_keep_alive=-1,
        max_trk_keep_alive=30,
        init_trk_keep_alive=30,
        suppress_overlapping_based_on_recent_occlusion_threshold=0.7,
        suppress_det_close_to_boundary=False,
        fill_hole_area=16,
        masklet_confirmation_enable=False,
        decrease_trk_keep_alive_for_empty_masklets=False,
        image_size=1008,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        compile_model=compile,
        **heuristic_config,
    )

    if checkpoint_path is None and load_from_HF:
        checkpoint_path = download_ckpt_from_hf(version="sam3")
    if checkpoint_path is not None:
        with g_pathmgr.open(checkpoint_path, "rb") as f:
            ckpt = torch.load(f, map_location="cpu", weights_only=True)
        if "model" in ckpt and isinstance(ckpt["model"], dict):
            ckpt = ckpt["model"]
        missing_keys, unexpected_keys = model.load_state_dict(
            ckpt, strict=strict_state_dict_loading
        )
        if missing_keys:
            print(f"Missing keys: {missing_keys}")
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys}")

    model.to(device=device)
    return model
def build_sam3_video_predictor(*model_args, gpus_to_use=None, **model_kwargs):
    """Construct a ``Sam3VideoPredictorMultiGPU``.

    Args:
        *model_args: Positional arguments forwarded unchanged.
        gpus_to_use: Forwarded as the ``gpus_to_use`` keyword.
        **model_kwargs: Keyword arguments forwarded unchanged.

    Returns:
        The constructed ``Sam3VideoPredictorMultiGPU``.
    """
    predictor = Sam3VideoPredictorMultiGPU(
        *model_args, gpus_to_use=gpus_to_use, **model_kwargs
    )
    return predictor
def _create_multiplex_maskmem_backbone(multiplex_count=16):
    """Build the ``SimpleMaskEncoder`` used as the multiplex memory encoder.

    Args:
        multiplex_count: Forwarded to ``SimpleMaskDownSampler`` as
            ``multiplex_count``.

    Returns:
        A ``SimpleMaskEncoder`` with ``out_dim=256``.
    """
    # Components are created in the order: position encoding, down-sampler,
    # CXBlock, fuser.
    sine_encoding = PositionEmbeddingSine(
        num_pos_feats=256,
        normalize=True,
        scale=None,
        temperature=10000,
        precompute_resolution=1008,
    )
    downsampler_config = dict(
        kernel_size=3,
        stride=2,
        padding=1,
        interpol_size=[1152, 1152],
        multiplex_count=multiplex_count,
        starting_out_chan=4,
        input_channel_multiplier=2,
    )
    downsampler = SimpleMaskDownSampler(**downsampler_config)
    fuser_layer = CXBlock(
        dim=256,
        kernel_size=7,
        padding=3,
        layer_scale_init_value=1.0e-06,
        use_dwconv=True,
    )
    mask_fuser = SimpleFuser(layer=fuser_layer, num_layers=2)
    return SimpleMaskEncoder(
        out_dim=256,
        position_encoding=sine_encoding,
        mask_downsampler=downsampler,
        fuser=mask_fuser,
    )
def _create_multiplex_transformer(use_fa3=False, use_rope_real=True):
    """Build the decoupled transformer used for multiplex memory attention.

    Args:
        use_fa3: Forwarded to both RoPE attention modules.
        use_rope_real: Forwarded to both RoPE attention modules.

    Returns:
        A ``TransformerWrapper`` whose encoder is a four-layer
        ``TransformerEncoderDecoupledCrossAttention``; ``decoder`` is ``None``.
    """
    # Scalar settings shared by both attention modules. ``feat_sizes`` is
    # written per call so each module receives its own list object.
    rope_common = dict(
        d_model=256,
        num_heads=8,
        dropout_p=0.1,
        rope_theta=10000.0,
        use_fa3=use_fa3,
        use_rope_real=use_rope_real,
    )
    rope_self_attn = SimpleRoPEAttention(feat_sizes=[72, 72], **rope_common)
    rope_cross_attn = SimpleRoPEAttention(
        feat_sizes=[72, 72], rope_k_repeat=True, **rope_common
    )

    layer_template = DecoupledTransformerDecoderLayerv2(
        activation="gelu",
        d_model=256,
        num_heads=8,
        dropout=0.1,
        dim_feedforward=2048,
        pre_norm=True,
        pos_enc_at_attn=False,
        pos_enc_at_cross_attn_keys=True,
        pos_enc_at_cross_attn_queries=False,
        self_attention_rope=rope_self_attn,
        cross_attention_rope=rope_cross_attn,
    )
    memory_encoder = TransformerEncoderDecoupledCrossAttention(
        layer=layer_template,
        num_layers=4,
        d_model=256,
        batch_first=True,
        frozen=False,
        pos_enc_at_input=True,
        use_image_in_output=False,
        use_act_checkpoint=False,
    )
    return TransformerWrapper(encoder=memory_encoder, decoder=None, d_model=256)
def _create_multiplex_tri_backbone(
    compile_mode=None, use_fa3=False, use_rope_real=True, init_device="cpu"
):
    """Build the ``Sam3TriViTDetNeck`` vision backbone for the multiplex model.

    Args:
        compile_mode: Forwarded to ``_create_vit_backbone``.
        use_fa3: Forwarded to ``_create_vit_backbone``.
        use_rope_real: Forwarded to ``_create_vit_backbone``.
        init_device: Device under which construction happens; ``None``
            disables the device context (see ``_device_context``).

    Returns:
        The constructed ``Sam3TriViTDetNeck``.
    """
    trunk_options = dict(
        compile_mode=compile_mode, use_fa3=use_fa3, use_rope_real=use_rope_real
    )
    with _device_context(init_device):
        sine_encoding = _create_position_encoding(precompute_resolution=1008)
        trunk = _create_vit_backbone(**trunk_options)
        return Sam3TriViTDetNeck(
            trunk=trunk,
            position_encoding=sine_encoding,
            d_model=256,
            scale_factors=[4.0, 2.0, 1.0],
        )
def build_sam3_multiplex_video_model(
checkpoint_path: Optional[str] = None,
load_from_HF=True,
multiplex_count: int = 16,
use_fa3: bool = False,
use_rope_real: bool = True,
trim_past_non_cond_mem_for_eval: bool = False,
strict_state_dict_loading: bool = True,
device=None,
init_device="cpu",
move_to_device: bool = True,
compile=False,
):
"""
Build SAM3 multiplex video tracking model.
No weights are loaded by this function: both checkpoint routes are rejected
(see Raises), so it only succeeds with ``load_from_HF=False`` and
``checkpoint_path=None``.
Args:
checkpoint_path: Optional path to checkpoint file (any non-None value raises)
load_from_HF: If true while checkpoint_path is None, raises
multiplex_count: Number of objects per multiplex bucket
use_fa3: Whether to use FlashAttention 3
use_rope_real: Whether to use real-valued RoPE (for compile compat)
trim_past_non_cond_mem_for_eval: Forwarded to the model under the same name
strict_state_dict_loading: Whether to use strict state dict loading
(not referenced in the body of this function)
device: Device to place model on; resolved with get_accelerator_device()
when None and move_to_device is true
init_device: Device used while constructing; "meta" builds under
init_empty_weights with bfloat16 as the default dtype
move_to_device: Whether to call model.to(device=device) at the end
compile: Whether to compile model components
Returns:
VideoTrackingDynamicMultiplex: The instantiated multiplex tracking model
Raises:
FileNotFoundError: load_from_HF is true and checkpoint_path is None
ValueError: checkpoint_path is not None
"""
_setup_tf32()
# Neither a hub download nor a standalone checkpoint is supported here.
if load_from_HF and checkpoint_path is None:
raise FileNotFoundError(
"SAM3.1 uses the local sam3.1_multiplex_bf16.safetensors checkpoint."
)
if checkpoint_path is not None:
raise ValueError(
"Standalone SAM3.1 tracker checkpoint loading is not supported; "
"use build_sam3_multiplex_video_predictor for safetensor loading."
)
if move_to_device and device is None:
device = get_accelerator_device()
# With init_device == "meta": empty weights + bfloat16 default dtype and no
# explicit device context; otherwise only the device context is active.
construct_on_meta = init_device == "meta"
empty_context = init_empty_weights(include_buffers=True) if construct_on_meta else nullcontext()
dtype_context = _default_dtype(torch.bfloat16) if construct_on_meta else nullcontext()
device_context = _device_context(None if construct_on_meta else init_device)
# NOTE(review): indentation was lost in this copy of the file, so the exact
# extent of the ``with`` block below cannot be read off here; presumably it
# covers construction up to and including the model — confirm against the
# original source.
with empty_context, dtype_context, device_context:
# Build multiplex-specific components
maskmem_backbone = _create_multiplex_maskmem_backbone(
multiplex_count=multiplex_count
)
transformer = _create_multiplex_transformer(
use_fa3=use_fa3, use_rope_real=use_rope_real
)
# init_device=None: the helper adds no device context of its own.
tri_neck = _create_multiplex_tri_backbone(
compile_mode="max-autotune" if compile else None,
use_fa3=use_fa3,
use_rope_real=use_rope_real,
init_device=None,
)
backbone = TriHeadVisionOnly(
visual=tri_neck,
n_features=256,
scalp=0,
)
multiplex_controller = MultiplexController(
multiplex_count=multiplex_count,
eval_multiplex_count=multiplex_count,
)
# Build the multiplex model (use demo class for init_state and other demo methods)
from .model.video_tracking_multiplex_demo import Sam3VideoTrackingMultiplexDemo
model = Sam3VideoTrackingMultiplexDemo(
backbone=backbone,
transformer=transformer,
maskmem_backbone=maskmem_backbone,
multiplex_controller=multiplex_controller,
image_size=1008,
backbone_stride=14,
num_maskmem=7,
# Multiplex-specific settings
use_high_res_features_in_sam=True,
use_obj_ptrs_in_encoder=True,
max_obj_ptrs_in_encoder=16,
add_tpos_enc_to_obj_ptrs=True,
proj_tpos_enc_in_obj_ptrs=True,
use_mlp_for_obj_ptr_proj=True,
pred_obj_scores=True,
pred_obj_scores_mlp=True,
fixed_no_obj_ptr=True,
use_no_obj_ptr=True,
use_linear_no_obj_ptr=True,
no_obj_embed_spatial=True,
sincos_tpos_enc=True,
# Multimask settings
multimask_output_in_sam=True,
multimask_output_for_tracking=True,
multimask_min_pt_num=0,
multimask_max_pt_num=1,
use_multimask_token_for_obj_ptr=True,
num_multimask_outputs=3,
# Memory encoder settings
apply_sigmoid_to_mask_logits_for_mem_enc=True,
sigmoid_scale_for_mem_enc=2.0,
sigmoid_bias_for_mem_enc=-1.0,
non_overlap_masks_for_mem_enc=False,
# Suppression/conditional embeddings
add_output_suppression_embeddings=True,
add_object_conditional_embeddings=False,
condition_as_mask_input=True,
condition_as_mask_input_fg=1.0,
condition_as_mask_input_bg=0.0,
# Memory settings
use_maskmem_tpos_v2=True,
save_image_features=True,
randomness_fix=True,
# Interaction settings
use_mask_input_as_output_without_sam=True,
directly_add_no_mem_embed=True,
iou_prediction_use_sigmoid=False,
forward_backbone_per_frame_for_eval=True,
offload_output_to_cpu_for_eval=False,
trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
max_cond_frames_in_attn=4,
# Dynamic multiplex settings
is_dynamic_model=True,
# SAM mask decoder extra args
sam_mask_decoder_extra_args={
"dynamic_multimask_via_stability": True,
"dynamic_multimask_stability_delta": 0.05,
"dynamic_multimask_stability_thresh": 0.98,
},
compile_all_components=compile,
use_memory_selection=False,
)
# Optional final placement; skipped when the caller handles device moves.
if move_to_device:
model.to(device=device)
return model
def build_sam3_text_encoder(
    checkpoint_path: Optional[str] = None,
    bpe_path: Optional[str] = None,
):
    """Build the SAM3 text encoder and load its weights from a checkpoint.

    The encoder is constructed with empty weights under a bfloat16 default
    dtype, its meta ``attn_mask`` buffers are rebuilt on CPU, and only the
    checkpoint keys under ``detector.backbone.language_backbone.`` are loaded
    (with that prefix stripped).

    Args:
        checkpoint_path: Path to the safetensors checkpoint (required).
        bpe_path: Path to the BPE vocabulary; defaults to the bundled
            ``assets/bpe_simple_vocab_16e6.txt.gz``.

    Returns:
        The text encoder, after ``.eval()``.

    Raises:
        FileNotFoundError: ``checkpoint_path`` is ``None``.
    """
    _setup_tf32()
    if bpe_path is None:
        bpe_path = pkg_resources.resource_filename(
            "sam3", "assets/bpe_simple_vocab_16e6.txt.gz"
        )
    if checkpoint_path is None:
        raise FileNotFoundError(
            "SAM3.1 text encoding requires sam3.1_multiplex_bf16.safetensors."
        )

    cpu_device = torch.device("cpu")
    language_prefix = "detector.backbone.language_backbone."

    with init_empty_weights(include_buffers=True), _default_dtype(torch.bfloat16):
        encoder = _create_text_encoder(bpe_path, init_device=None)

    # Buffers were created empty above; rebuild the masks before loading.
    _refresh_text_encoder_buffers_(encoder, cpu_device, torch.bfloat16)
    _load_sam3_safetensors_model(
        encoder,
        checkpoint_path,
        include_prefixes=(language_prefix,),
        strip_prefix=language_prefix,
    )
    return encoder.eval()
def build_sam3_multiplex_video_predictor(
    checkpoint_path: Optional[str] = None,
    bpe_path: Optional[str] = None,
    max_num_objects: int = 16,
    multiplex_count: int = 16,
    use_fa3: bool = True,
    use_rope_real: bool = True,
    compile: bool = False,
    warm_up: bool = False,
    session_expiration_sec: int = 1200,
    default_output_prob_thresh: float = 0.5,
    async_loading_frames: bool = True,
    include_text_encoder: bool = True,
    postprocess_batch_size: int = 16,
    use_batched_grounding: bool = True,
    batched_grounding_batch_size: int = 16,
    trim_past_non_cond_mem_for_eval: bool = False,
    fill_hole_area: int = 0,
    manual_model_loading: bool = False,
):
    """
    Build a fully-initialized Sam3MultiplexVideoPredictor.

    This is the recommended entry point for SAM 3.1 multiplex video tracking.
    It builds the full model stack (tracker + detector + demo model), loads
    the checkpoint, and wraps everything in Sam3MultiplexVideoPredictor with
    handle_request / handle_stream_request API.

    Args:
        checkpoint_path: Path to the merged multiplex checkpoint. Required;
            this builder has no download fallback.
        bpe_path: Path to the BPE tokenizer vocabulary. When None and the text
            encoder is included, the vocabulary bundled with the ``sam3``
            package is used.
        max_num_objects: Maximum number of tracked objects
        multiplex_count: Number of objects per multiplex bucket
        use_fa3: Whether to use FlashAttention 3
        use_rope_real: Whether to use real-valued RoPE (for compile compat)
        compile: Whether to enable torch.compile on model components
        warm_up: Whether to run warm-up compilation (requires compile=True)
        session_expiration_sec: Session expiration timeout in seconds
        default_output_prob_thresh: Default probability threshold for output masks
        async_loading_frames: Whether to load frames asynchronously
        include_text_encoder: Whether to build the text encoder and load its
            weights. When False the detector backbone gets ``text=None`` and
            the ``detector.backbone.language_backbone.`` checkpoint entries
            are skipped.
        postprocess_batch_size: Forwarded to the tracking/demo model.
        use_batched_grounding: Forwarded to the tracking/demo model.
        batched_grounding_batch_size: Forwarded to the tracking/demo model.
        trim_past_non_cond_mem_for_eval: Forwarded to the tracker builder.
        fill_hole_area: Accepted for interface compatibility but not
            forwarded: hole filling is kept disabled (0) in both the tracker
            wrapper and the demo model.
        manual_model_loading: Forwarded to Sam3MultiplexVideoPredictor.

    Returns:
        Sam3MultiplexVideoPredictor: The fully-initialized predictor

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is None.
    """
    # Fail fast: without a checkpoint the model cannot be used, so do not
    # spend time assembling the whole stack first.
    if checkpoint_path is None:
        raise FileNotFoundError("SAM3.1 requires sam3.1_multiplex_bf16.safetensors.")

    if include_text_encoder and bpe_path is None:
        # The vocabulary is only consumed by the text encoder. Resolve it with
        # the stdlib: `pkg_resources` is not imported by this module and is a
        # deprecated setuptools API.
        from importlib import resources

        bpe_path = str(
            resources.files("sam3") / "assets" / "bpe_simple_vocab_16e6.txt.gz"
        )

    from .model.sam3_multiplex_base import Sam3MultiplexPredictorWrapper
    from .model.sam3_multiplex_detector import Sam3MultiplexDetector
    from .model.sam3_multiplex_tracking import (
        Sam3MultiplexTrackingWithInteractivity,
    )
    from .model.sam3_multiplex_video_predictor import Sam3MultiplexVideoPredictor

    target_device = torch.device("cpu")
    with init_empty_weights(include_buffers=True), _default_dtype(torch.bfloat16):
        # Build tracker
        tracker_model = build_sam3_multiplex_video_model(
            checkpoint_path=None,
            load_from_HF=False,
            multiplex_count=multiplex_count,
            use_fa3=use_fa3,
            use_rope_real=use_rope_real,
            trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
            compile=False,
            strict_state_dict_loading=False,
            init_device=None,
            move_to_device=False,
        )
        # The tracker's own backbone is dropped; the detector built below
        # carries the backbone used by the assembled model.
        del tracker_model.backbone
        tracker_model.backbone = None
        sam2_predictor = Sam3MultiplexPredictorWrapper(
            model=tracker_model,
            per_obj_inference=False,
            # Keep tracker fill disabled; MatAnyone applies hole filling to final binary masks.
            fill_hole_area=0,
            is_multiplex=True,
            is_multiplex_dynamic=True,
        )

        # Build detector
        tri_neck = _create_multiplex_tri_backbone(
            compile_mode=None,
            use_fa3=use_fa3,
            use_rope_real=use_rope_real,
            init_device=None,
        )
        text_encoder = (
            _create_text_encoder(bpe_path, init_device=None)
            if include_text_encoder
            else None
        )
        backbone = SAM3VLBackboneTri(scalp=0, visual=tri_neck, text=text_encoder)
        transformer = _create_sam3_transformer(use_fa3=use_fa3)
        segmentation_head = _create_segmentation_head(use_fa3=use_fa3)
        geometry_encoder = _create_geometry_encoder()
        dot_prod_scoring = _create_dot_product_scoring()
        detector = Sam3MultiplexDetector(
            num_feature_levels=1,
            backbone=backbone,
            transformer=transformer,
            segmentation_head=segmentation_head,
            semantic_segmentation_head=None,
            input_geometry_encoder=geometry_encoder,
            use_early_fusion=True,
            use_dot_prod_scoring=True,
            dot_prod_scoring=dot_prod_scoring,
            supervise_joint_box_scores=True,
            is_multiplex=True,
        )

        # Assemble demo model
        demo_model = Sam3MultiplexTrackingWithInteractivity(
            tracker=sam2_predictor,
            detector=detector,
            score_threshold_detection=0.4,
            det_nms_thresh=0.1,
            det_nms_use_iom=True,
            assoc_iou_thresh=0.1,
            new_det_thresh=0.65,
            hotstart_delay=15,
            hotstart_unmatch_thresh=8,
            hotstart_dup_thresh=8,
            suppress_unmatched_only_within_hotstart=False,
            suppress_overlapping_based_on_recent_occlusion_threshold=0.7,
            suppress_det_close_to_boundary=True,
            # Keep tracker fill disabled; MatAnyone applies hole filling to final binary masks.
            fill_hole_area=0,
            recondition_every_nth_frame=16,
            use_iom_recondition=True,
            iom_thresh_recondition=0.5,
            masklet_confirmation_enable=True,
            reconstruction_bbox_iou_thresh=-1,
            reconstruction_bbox_det_score=0.8,
            max_num_objects=max_num_objects,
            postprocess_batch_size=postprocess_batch_size,
            use_batched_grounding=use_batched_grounding,
            batched_grounding_batch_size=batched_grounding_batch_size,
            max_num_kboxes=0,
            sprinkle_removal_area=0,
            is_multiplex=True,
            image_size=1008,
            image_mean=(0.5, 0.5, 0.5),
            image_std=(0.5, 0.5, 0.5),
            compile_model=compile,
        )

    # Load checkpoint on CPU; CUDA residency is handled just in time by the predictor.
    # When the text encoder is not built, its checkpoint entries are skipped.
    exclude_prefixes = (
        None if include_text_encoder else ("detector.backbone.language_backbone.",)
    )
    _refresh_rope_buffers_(demo_model, target_device, torch.bfloat16)
    if include_text_encoder:
        _refresh_text_encoder_buffers_(demo_model, target_device, torch.bfloat16)
    _load_sam3_safetensors_model(
        demo_model, checkpoint_path, exclude_prefixes=exclude_prefixes
    )
    demo_model.eval()

    # Wrap in predictor
    predictor = Sam3MultiplexVideoPredictor(
        model=demo_model,
        session_expiration_sec=session_expiration_sec,
        default_output_prob_thresh=default_output_prob_thresh,
        async_loading_frames=async_loading_frames,
        warm_up=warm_up,
        manual_model_loading=manual_model_loading,
    )
    return predictor
def build_sam3_predictor(
    checkpoint_path: Optional[str] = None,
    bpe_path: Optional[str] = None,
    version: str = "sam3.1",  # "sam3" or "sam3.1"
    compile: bool = False,
    warm_up: bool = False,
    # SAM 3.1 specific
    max_num_objects: int = 16,
    multiplex_count: int = 16,
    # Common
    use_fa3: bool = True,
    use_rope_real: bool = True,
    async_loading_frames: bool = True,
    include_text_encoder: bool = True,
    trim_past_non_cond_mem_for_eval: bool = False,
    **kwargs,
):
    """
    Single entry point that dispatches to the SAM 3 or SAM 3.1 video predictor builder.

    Args:
        checkpoint_path: Path to model checkpoint. The SAM 3.1 builder raises
            FileNotFoundError when this is None.
        bpe_path: Path to BPE tokenizer vocabulary
        version: "sam3" selects build_sam3_video_predictor, "sam3.1" selects
            build_sam3_multiplex_video_predictor
        compile: Enable torch.compile (forwarded to both builders)
        warm_up: Run warm-up compilation passes (SAM 3.1 only)
        max_num_objects: Maximum tracked objects (SAM 3.1 only)
        multiplex_count: Objects per multiplex bucket (SAM 3.1 only)
        use_fa3: Use Flash Attention 3 (SAM 3.1 only)
        use_rope_real: Use real-valued RoPE (SAM 3.1 only)
        async_loading_frames: Load video frames asynchronously (both builders)
        include_text_encoder: Build and load the text encoder (SAM 3.1 only)
        trim_past_non_cond_mem_for_eval: Forwarded to the SAM 3.1 builder only
        **kwargs: Additional arguments passed to the underlying builder

    Returns:
        A predictor with handle_request() and handle_stream_request() API.
        Both versions support: start_session, add_prompt, propagate_in_video,
        remove_object, reset_session, close_session.

    Raises:
        ValueError: If ``version`` is neither "sam3" nor "sam3.1".

    Example:
        # SAM 3.1 with a local checkpoint:
        predictor = build_sam3_predictor(checkpoint_path="path/to/sam3.1_multiplex_bf16.safetensors", version="sam3.1")
        # SAM 3:
        predictor = build_sam3_predictor(version="sam3")
        # Both use the same API:
        response = predictor.handle_request({"type": "start_session", "resource_path": video_dir})
        session_id = response["session_id"]
        predictor.handle_request({"type": "add_prompt", "session_id": session_id, "frame_index": 0, "text": "person"})
        for out in predictor.handle_stream_request({"type": "propagate_in_video", "session_id": session_id}):
            masks = out["out_binary_masks"]
    """
    # Guard clause: reject anything that is not a known version up front.
    if version not in ("sam3.1", "sam3"):
        raise ValueError(f"Unknown version: {version!r}. Use 'sam3' or 'sam3.1'.")

    # Arguments understood by both underlying builders.
    shared_args = dict(
        checkpoint_path=checkpoint_path,
        bpe_path=bpe_path,
        compile=compile,
        async_loading_frames=async_loading_frames,
    )

    if version == "sam3":
        return build_sam3_video_predictor(**shared_args, **kwargs)

    # version == "sam3.1": add the multiplex-specific options.
    multiplex_args = dict(
        max_num_objects=max_num_objects,
        multiplex_count=multiplex_count,
        use_fa3=use_fa3,
        use_rope_real=use_rope_real,
        warm_up=warm_up,
        include_text_encoder=include_text_encoder,
        trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
    )
    return build_sam3_multiplex_video_predictor(
        **shared_args, **multiplex_args, **kwargs
    )

Xet Storage Details

Size:
49.7 kB
·
Xet hash:
a15e26b06e3d93636ff17bc155ac2fe9891c053820bd217d19985e5a696be348

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.