| # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved | |
| # pyre-unsafe | |
| import os | |
| from contextlib import contextmanager, nullcontext | |
| from pathlib import Path | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| from accelerate import init_empty_weights | |
| from huggingface_hub import hf_hub_download | |
| from iopath.common.file_io import g_pathmgr | |
| from mmgp import offload | |
| from .model.decoder import ( | |
| DecoupledTransformerDecoderLayerv2, | |
| SimpleRoPEAttention, | |
| TransformerDecoder, | |
| TransformerDecoderLayer, | |
| TransformerDecoderLayerv2, | |
| TransformerEncoderCrossAttention, | |
| TransformerEncoderDecoupledCrossAttention, | |
| ) | |
| from .model.encoder import TransformerEncoderFusion, TransformerEncoderLayer | |
| from .model.geometry_encoders import SequenceGeometryEncoder | |
| from .model.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead | |
| from .model.memory import ( | |
| CXBlock, | |
| SimpleFuser, | |
| SimpleMaskDownSampler, | |
| SimpleMaskEncoder, | |
| ) | |
| from .model.model_misc import ( | |
| DotProductScoring, | |
| MLP, | |
| MultiheadAttentionWrapper as MultiheadAttention, | |
| TransformerWrapper, | |
| ) | |
| from .model.multiplex_utils import MultiplexController | |
| from .model.necks import Sam3DualViTDetNeck, Sam3TriViTDetNeck | |
| from .model.position_encoding import PositionEmbeddingSine | |
| from .model.sam1_task_predictor import SAM3InteractiveImagePredictor | |
| from .model.sam3_image import Sam3Image, Sam3ImageOnVideoMultiGPU | |
| from .model.sam3_tracking_predictor import Sam3TrackerPredictor | |
| from .model.sam3_video_inference import Sam3VideoInferenceWithInstanceInteractivity | |
| from .model.sam3_video_predictor import Sam3VideoPredictorMultiGPU | |
| from .model.text_encoder_ve import VETextEncoder | |
| from .model.tokenizer_ve import SimpleTokenizer | |
| from .model.video_tracking_multiplex import VideoTrackingDynamicMultiplex | |
| from .model.vitdet import ViT | |
| from .model.vl_combiner import SAM3VLBackbone, SAM3VLBackboneTri, TriHeadVisionOnly | |
| from .model.device_utils import get_accelerator_device | |
| from .sam.transformer import RoPEAttention | |
| # Setup TensorFloat-32 for Ampere GPUs if available | |
| def _setup_tf32() -> None: | |
| """Enable TensorFloat-32 for Ampere GPUs if available.""" | |
| if torch.cuda.is_available(): | |
| device_props = torch.cuda.get_device_properties(0) | |
| if device_props.major >= 8: | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
class _PackageResources:
    """Minimal stand-in for ``pkg_resources.resource_filename``.

    Resources are resolved relative to the directory that contains this
    module. The ``package`` argument is accepted only for call-site
    compatibility and is ignored.
    """

    # Must be a staticmethod: callers use ``pkg_resources.resource_filename(pkg, res)``
    # on the instance below. As a plain method, the instance was bound to
    # ``package`` and the two-argument call raised TypeError.
    @staticmethod
    def resource_filename(package, resource):
        """Return the absolute filesystem path of ``resource`` next to this module."""
        return os.fspath(Path(__file__).resolve().parent / resource)


pkg_resources = _PackageResources()
| def _default_dtype(dtype: torch.dtype): | |
| previous_dtype = torch.get_default_dtype() | |
| torch.set_default_dtype(dtype) | |
| try: | |
| yield | |
| finally: | |
| torch.set_default_dtype(previous_dtype) | |
| def _device_context(device): | |
| return torch.device(device) if device is not None else nullcontext() | |
| def _remap_checkpoint_key(key: str) -> str: | |
| if key.startswith("sam3_model."): | |
| return "detector." + key[len("sam3_model.") :] | |
| if key.startswith("sam2_predictor."): | |
| return "tracker." + key[len("sam2_predictor.") :] | |
| return key | |
| def _keep_checkpoint_key(key: str, include_prefixes=None, exclude_prefixes=None) -> bool: | |
| if include_prefixes is not None and not key.startswith(tuple(include_prefixes)): | |
| return False | |
| if exclude_prefixes is not None and key.startswith(tuple(exclude_prefixes)): | |
| return False | |
| return True | |
def _preprocess_sam3_state_dict(
    model: nn.Module,
    include_prefixes=None,
    exclude_prefixes=None,
    strip_prefix=None,
):
    """Build the ``preprocess_sd`` callback passed to ``offload.load_model_data``.

    The returned callable:
      1. unwraps a nested ``{"model": {...}}`` checkpoint;
      2. remaps legacy prefixes with ``_remap_checkpoint_key``;
      3. drops keys rejected by ``_keep_checkpoint_key``;
      4. strips ``strip_prefix`` from keys that start with it;
      5. fills any key of ``model.state_dict()`` still missing with the
         model's own tensor, provided that tensor is not on the meta device.
    """

    # quantization_map / tied_weights_map are accepted but unused; presumably
    # they are required by mmgp's callback signature (NOTE(review): confirm).
    def preprocess(state_dict, quantization_map=None, tied_weights_map=None):
        source = state_dict
        if "model" in source and isinstance(source["model"], dict):
            source = source["model"]

        kept = {}
        for raw_key, tensor in source.items():
            key = _remap_checkpoint_key(raw_key)
            if not _keep_checkpoint_key(key, include_prefixes, exclude_prefixes):
                continue
            if strip_prefix is not None and key.startswith(strip_prefix):
                key = key[len(strip_prefix):]
            kept[key] = tensor

        fallbacks = {
            name: value
            for name, value in model.state_dict().items()
            if name not in kept and torch.is_tensor(value) and not value.is_meta
        }
        kept.update(fallbacks)
        return kept

    return preprocess
| def _load_sam3_safetensors_model( | |
| model: nn.Module, | |
| checkpoint_path: str, | |
| *, | |
| include_prefixes=None, | |
| exclude_prefixes=None, | |
| strip_prefix=None, | |
| ) -> None: | |
| if checkpoint_path is None: | |
| raise FileNotFoundError("SAM3.1 requires a bf16 safetensors checkpoint.") | |
| if Path(checkpoint_path).suffix.lower() != ".safetensors": | |
| raise ValueError(f"SAM3.1 checkpoints must be bf16 safetensors files, got: {checkpoint_path}") | |
| offload.load_model_data( | |
| model, | |
| checkpoint_path, | |
| writable_tensors=False, | |
| preprocess_sd=_preprocess_sam3_state_dict( | |
| model, include_prefixes, exclude_prefixes, strip_prefix | |
| ), | |
| default_dtype=torch.bfloat16, | |
| ) | |
| def _refresh_rope_buffers_(module: nn.Module, device: torch.device, dtype: torch.dtype) -> None: | |
| for submodule in module.modules(): | |
| if getattr(submodule, "use_rope", False) and getattr(submodule, "use_rope_real", False) and hasattr(submodule, "_setup_rope_freqs"): | |
| freqs = getattr(submodule, "freqs_cis_real", None) | |
| if torch.is_tensor(freqs) and freqs.is_meta: | |
| submodule._setup_rope_freqs() | |
| submodule.freqs_cis_real = submodule.freqs_cis_real.to(device=device, dtype=dtype) | |
| submodule.freqs_cis_imag = submodule.freqs_cis_imag.to(device=device, dtype=dtype) | |
| elif getattr(submodule, "use_rope_real", False) and hasattr(submodule, "compute_cis") and hasattr(submodule, "freqs_cis_real"): | |
| freqs = getattr(submodule, "freqs_cis_real", None) | |
| if not torch.is_tensor(freqs) or not freqs.is_meta: | |
| continue | |
| side = int(freqs.shape[0] ** 0.5) | |
| real, imag = submodule.compute_cis(end_x=side, end_y=side, device=device) | |
| submodule.freqs_cis_real = real.to(dtype=dtype) | |
| submodule.freqs_cis_imag = imag.to(dtype=dtype) | |
| def _refresh_text_encoder_buffers_(module: nn.Module, device: torch.device, dtype: torch.dtype) -> None: | |
| for submodule in module.modules(): | |
| attn_mask = getattr(submodule, "attn_mask", None) | |
| if torch.is_tensor(attn_mask) and attn_mask.is_meta and hasattr(submodule, "build_causal_mask"): | |
| submodule.attn_mask = submodule.build_causal_mask().to(device=device, dtype=dtype) | |
def _create_position_encoding(precompute_resolution=None):
    """Return the 256-feature normalized sine position encoding used by the backbones.

    Args:
        precompute_resolution: Forwarded to ``PositionEmbeddingSine`` (None by default).
    """
    settings = dict(
        num_pos_feats=256,
        normalize=True,
        scale=None,
        temperature=10000,
        precompute_resolution=precompute_resolution,
    )
    return PositionEmbeddingSine(**settings)
def _create_vit_backbone(compile_mode=None, use_fa3=False, use_rope_real=True):
    """Build the 32-block ViT trunk (1008px input, patch 14, width 1024).

    Args:
        compile_mode: Forwarded to ``ViT``.
        use_fa3: Forwarded to ``ViT``.
        use_rope_real: Forwarded to ``ViT``.
    """
    architecture = dict(
        img_size=1008,
        pretrain_img_size=336,
        patch_size=14,
        embed_dim=1024,
        depth=32,
        num_heads=16,
        mlp_ratio=4.625,
        norm_layer="LayerNorm",
        drop_path_rate=0.1,
        qkv_bias=True,
        bias_patch_embed=False,
    )
    positional = dict(
        use_abs_pos=True,
        tile_abs_pos=True,
        use_rope=True,
        use_interp_rope=True,
        use_rope_real=use_rope_real,
    )
    # Windowed attention everywhere except the listed global blocks.
    attention = dict(
        global_att_blocks=(7, 15, 23, 31),
        rel_pos_blocks=(),
        window_size=24,
        use_fa3=use_fa3,
    )
    tokens_and_norms = dict(
        pretrain_use_cls_token=True,
        retain_cls_token=False,
        ln_pre=True,
        ln_post=False,
        return_interm_layers=False,
    )
    return ViT(
        **architecture,
        **positional,
        **attention,
        **tokens_and_norms,
        compile_mode=compile_mode,
    )
def _create_vit_neck(position_encoding, vit_backbone, enable_inst_interactivity=False):
    """Wrap the ViT trunk in the dual ViTDet neck (4 pyramid levels, d_model 256).

    Args:
        position_encoding: Positional encoding module for the neck outputs.
        vit_backbone: The ViT trunk.
        enable_inst_interactivity: Forwarded as ``add_sam2_neck``.
    """
    pyramid_scales = [4.0, 2.0, 1.0, 0.5]
    return Sam3DualViTDetNeck(
        trunk=vit_backbone,
        position_encoding=position_encoding,
        d_model=256,
        scale_factors=pyramid_scales,
        add_sam2_neck=enable_inst_interactivity,
    )
def _create_vl_backbone(vit_neck, text_encoder):
    """Combine the visual neck and text encoder into a ``SAM3VLBackbone`` (scalp=1)."""
    vl_backbone = SAM3VLBackbone(scalp=1, text=text_encoder, visual=vit_neck)
    return vl_backbone
def _create_transformer_encoder(use_fa3=False) -> TransformerEncoderFusion:
    """Build the 6-layer pre-norm fusion encoder of the SAM3 detector.

    Each layer has 8-head self- and cross-attention (d_model 256,
    ``batch_first=True``).

    Args:
        use_fa3: Forwarded to both attention modules.
    """

    def make_attention():
        return MultiheadAttention(
            num_heads=8,
            dropout=0.1,
            embed_dim=256,
            batch_first=True,
            use_fa3=use_fa3,
        )

    # Self-attention is created before cross-attention, as in the original order.
    self_attn = make_attention()
    cross_attn = make_attention()
    layer = TransformerEncoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=True,
        pos_enc_at_cross_attn_keys=False,
        pos_enc_at_cross_attn_queries=False,
        pre_norm=True,
        self_attention=self_attn,
        cross_attention=cross_attn,
    )
    return TransformerEncoderFusion(
        layer=layer,
        num_layers=6,
        d_model=256,
        num_feature_levels=1,
        frozen=False,
        use_act_checkpoint=True,
        add_pooled_text_to_img_feat=False,
        pool_text_with_mask=True,
    )
def _create_transformer_decoder(use_fa3=False) -> TransformerDecoder:
    """Build the 6-layer DETR-style decoder (200 queries, box refinement, presence token).

    Args:
        use_fa3: Forwarded to the layer's cross-attention.
    """
    text_cross_attn = MultiheadAttention(
        num_heads=8,
        dropout=0.1,
        embed_dim=256,
        use_fa3=use_fa3,
    )
    layer = TransformerDecoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        cross_attention=text_cross_attn,
        n_heads=8,
        use_text_cross_attention=True,
    )
    decoder_settings = dict(
        num_layers=6,
        num_queries=200,
        return_intermediate=True,
        box_refine=True,
        num_o2m_queries=0,
        dac=True,
        boxRPB="log",
        d_model=256,
        frozen=False,
        interaction_layer=None,
        dac_use_selfatt_ln=True,
        resolution=1008,
        stride=14,
        use_act_checkpoint=True,
        presence_token=True,
    )
    return TransformerDecoder(layer=layer, **decoder_settings)
def _create_dot_product_scoring():
    """Build the dot-product scorer with a 2-layer residual prompt MLP (256 -> 2048 -> 256)."""
    width = 256
    prompt_projector = MLP(
        input_dim=width,
        hidden_dim=2048,
        output_dim=width,
        num_layers=2,
        dropout=0.1,
        residual=True,
        out_norm=nn.LayerNorm(width),
    )
    return DotProductScoring(d_model=width, d_proj=width, prompt_mlp=prompt_projector)
def _create_segmentation_head(compile_mode=None, use_fa3=False):
    """Build the universal segmentation head and its 3-stage pixel decoder.

    Args:
        compile_mode: Forwarded to ``PixelDecoder``.
        use_fa3: Forwarded to the prompt cross-attention.

    Returns:
        A ``UniversalSegmentationHead`` without aux masks, presence head or
        dot-product scorer.
    """
    stages = 3
    width = 256
    pixel_decoder = PixelDecoder(
        num_upsampling_stages=stages,
        interpolation_mode="nearest",
        hidden_dim=width,
        compile_mode=compile_mode,
    )
    prompt_attention = MultiheadAttention(
        num_heads=8,
        dropout=0,
        embed_dim=width,
        use_fa3=use_fa3,
    )
    return UniversalSegmentationHead(
        hidden_dim=width,
        upsampling_stages=stages,
        aux_masks=False,
        presence_head=False,
        dot_product_scorer=None,
        act_ckpt=True,
        cross_attend_prompt=prompt_attention,
        pixel_decoder=pixel_decoder,
    )
def _create_geometry_encoder():
    """Build the sequence geometry encoder for point and box prompts.

    Uses its own sine position encoding (no precomputed resolution) and a
    3-layer stack of a pre-norm ``TransformerEncoderLayer`` whose 8-head
    self- and cross-attention are built with ``batch_first=False``.

    Returns:
        A ``SequenceGeometryEncoder`` (d_model 256) with a CLS token and a
        post-encode projection.
    """
    geo_pos_enc = _create_position_encoding()
    # Only the attention modules, layer and encoder are needed. No fuser block
    # is built here: the encoder takes no such argument.
    geo_layer = TransformerEncoderLayer(
        activation="relu",
        d_model=256,
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=False,
        pre_norm=True,
        self_attention=MultiheadAttention(
            num_heads=8,
            dropout=0.1,
            embed_dim=256,
            batch_first=False,
        ),
        pos_enc_at_cross_attn_queries=False,
        pos_enc_at_cross_attn_keys=True,
        cross_attention=MultiheadAttention(
            num_heads=8,
            dropout=0.1,
            embed_dim=256,
            batch_first=False,
        ),
    )
    return SequenceGeometryEncoder(
        pos_enc=geo_pos_enc,
        encode_boxes_as_points=False,
        points_direct_project=True,
        points_pool=True,
        points_pos_enc=True,
        boxes_direct_project=True,
        boxes_pool=True,
        boxes_pos_enc=True,
        d_model=256,
        num_layers=3,
        layer=geo_layer,
        use_act_ckpt=True,
        add_cls=True,
        add_post_encode_proj=True,
    )
def _create_sam3_model(
    backbone,
    transformer,
    input_geometry_encoder,
    segmentation_head,
    dot_prod_scoring,
    inst_interactive_predictor,
    eval_mode,
):
    """Assemble a ``Sam3Image`` from pre-built components.

    When ``eval_mode`` is falsy, a ``BinaryHungarianMatcherV2`` is imported
    lazily from the training package and attached as ``matcher``. Otherwise
    ``matcher`` is None.
    """
    if eval_mode:
        matcher = None
    else:
        from .train.matcher import BinaryHungarianMatcherV2

        matcher = BinaryHungarianMatcherV2(
            focal=True,
            cost_class=2.0,
            cost_bbox=5.0,
            cost_giou=2.0,
            alpha=0.25,
            gamma=2,
            stable=False,
        )
    return Sam3Image(
        backbone=backbone,
        transformer=transformer,
        input_geometry_encoder=input_geometry_encoder,
        segmentation_head=segmentation_head,
        num_feature_levels=1,
        o2m_mask_predict=True,
        dot_prod_scoring=dot_prod_scoring,
        use_instance_query=False,
        multimask_output=True,
        inst_interactive_predictor=inst_interactive_predictor,
        matcher=matcher,
    )
def _create_tracker_maskmem_backbone():
    """Build the tracker's mask-memory encoder (64-dim output).

    Components are created in order: position encoding, mask downsampler,
    a CXBlock-based fuser (2 layers), then the ``SimpleMaskEncoder`` itself.
    """
    memory_pos_enc = PositionEmbeddingSine(
        num_pos_feats=64,
        normalize=True,
        scale=None,
        temperature=10000,
        precompute_resolution=1008,
    )
    downsampler = SimpleMaskDownSampler(
        kernel_size=3,
        stride=2,
        padding=1,
        interpol_size=[1152, 1152],
    )
    fuser_block = CXBlock(
        dim=256,
        kernel_size=7,
        padding=3,
        layer_scale_init_value=1.0e-06,
        use_dwconv=True,
    )
    return SimpleMaskEncoder(
        out_dim=64,
        position_encoding=memory_pos_enc,
        mask_downsampler=downsampler,
        fuser=SimpleFuser(layer=fuser_block, num_layers=2),
    )
def _create_tracker_transformer():
    """Build the tracker's memory-attention transformer (encoder only).

    Four ``TransformerDecoderLayerv2`` layers, each with single-head RoPE
    self-attention and RoPE cross-attention whose key/value input dimension
    is 64 (``kv_in_dim=64``), wrapped in a decoder-less ``TransformerWrapper``.
    """

    def rope_attention(**extra):
        # A fresh feat_sizes list per module, as the original constructs two lists.
        return RoPEAttention(
            embedding_dim=256,
            num_heads=1,
            downsample_rate=1,
            dropout=0.1,
            rope_theta=10000.0,
            feat_sizes=[72, 72],
            use_fa3=False,
            use_rope_real=True,
            **extra,
        )

    self_attention = rope_attention()
    cross_attention = rope_attention(kv_in_dim=64, rope_k_repeat=True)
    layer = TransformerDecoderLayerv2(
        cross_attention_first=False,
        activation="relu",
        dim_feedforward=2048,
        dropout=0.1,
        pos_enc_at_attn=False,
        pre_norm=True,
        self_attention=self_attention,
        d_model=256,
        pos_enc_at_cross_attn_keys=True,
        pos_enc_at_cross_attn_queries=False,
        cross_attention=cross_attention,
    )
    memory_encoder = TransformerEncoderCrossAttention(
        remove_cross_attention_layers=[],
        batch_first=True,
        d_model=256,
        frozen=False,
        pos_enc_at_input=True,
        layer=layer,
        num_layers=4,
        use_act_checkpoint=False,
    )
    return TransformerWrapper(encoder=memory_encoder, decoder=None, d_model=256)
def build_tracker(
    apply_temporal_disambiguation: bool, with_backbone: bool = False, compile_mode=None
) -> Sam3TrackerPredictor:
    """Build the SAM3 tracker used for video propagation.

    Args:
        apply_temporal_disambiguation: Forwarded as ``use_memory_selection``.
        with_backbone: If True, attach a vision-only ``SAM3VLBackbone``
            (``text=None``). Otherwise the tracker has no backbone.
        compile_mode: Forwarded to the vision backbone when one is built.

    Returns:
        Sam3TrackerPredictor: the wrapped SAM3 tracker module.
    """
    memory_encoder = _create_tracker_maskmem_backbone()
    memory_attention = _create_tracker_transformer()

    image_backbone = None
    if with_backbone:
        visual = _create_vision_backbone(compile_mode=compile_mode)
        image_backbone = SAM3VLBackbone(scalp=1, visual=visual, text=None)

    decoder_extra_args = {
        "dynamic_multimask_via_stability": True,
        "dynamic_multimask_stability_delta": 0.05,
        "dynamic_multimask_stability_thresh": 0.98,
    }
    return Sam3TrackerPredictor(
        image_size=1008,
        num_maskmem=7,
        backbone=image_backbone,
        backbone_stride=14,
        transformer=memory_attention,
        maskmem_backbone=memory_encoder,
        multimask_output_in_sam=True,
        forward_backbone_per_frame_for_eval=True,
        trim_past_non_cond_mem_for_eval=False,
        multimask_output_for_tracking=True,
        multimask_min_pt_num=0,
        multimask_max_pt_num=1,
        always_start_from_first_ann_frame=False,
        non_overlap_masks_for_mem_enc=False,
        non_overlap_masks_for_output=False,
        max_cond_frames_in_attn=4,
        offload_output_to_cpu_for_eval=False,
        sam_mask_decoder_extra_args=decoder_extra_args,
        clear_non_cond_mem_around_input=True,
        fill_hole_area=0,
        use_memory_selection=apply_temporal_disambiguation,
    )
def _create_text_encoder(bpe_path: str, init_device="cpu") -> VETextEncoder:
    """Build the SAM3 text encoder (24 layers, width 1024, 16 heads, d_model 256).

    Args:
        bpe_path: Path to the BPE vocabulary used by ``SimpleTokenizer``.
        init_device: Device made the default while constructing the modules
            (None disables the device context).
    """
    with _device_context(init_device):
        bpe_tokenizer = SimpleTokenizer(bpe_path=bpe_path)
        encoder = VETextEncoder(
            tokenizer=bpe_tokenizer,
            d_model=256,
            width=1024,
            heads=16,
            layers=24,
        )
    return encoder
def _create_vision_backbone(
    compile_mode=None, enable_inst_interactivity=True
) -> Sam3DualViTDetNeck:
    """Build the SAM3 visual backbone: ViT trunk wrapped in the dual ViTDet neck.

    Args:
        compile_mode: Forwarded to the ViT trunk.
        enable_inst_interactivity: Forwarded to ``_create_vit_neck``.
    """
    pos_enc = _create_position_encoding(precompute_resolution=1008)
    trunk = _create_vit_backbone(compile_mode=compile_mode)
    return _create_vit_neck(
        pos_enc,
        trunk,
        enable_inst_interactivity=enable_inst_interactivity,
    )
def _create_sam3_transformer(
    has_presence_token: bool = True, use_fa3: bool = False
) -> TransformerWrapper:
    """Build the detector transformer (fusion encoder + DETR-style decoder).

    Args:
        has_presence_token: Accepted for API compatibility; not used by this function.
        use_fa3: Forwarded to both the encoder and decoder builders.
    """
    # Keyword arguments evaluate left to right: encoder first, then decoder.
    return TransformerWrapper(
        encoder=_create_transformer_encoder(use_fa3=use_fa3),
        decoder=_create_transformer_decoder(use_fa3=use_fa3),
        d_model=256,
    )
def _load_checkpoint(model, checkpoint_path):
    """Load a SAM3 (``torch.load`` format) checkpoint into a ``Sam3Image`` model.

    Keys are first passed through ``_remap_checkpoint_key``, so legacy
    ``sam3_model.``/``sam2_predictor.`` prefixes are accepted. Then:

    * ``detector.<name>`` is loaded as ``<name>``;
    * ``tracker.<name>`` is loaded as ``inst_interactive_predictor.model.<name>``,
      only when ``model.inst_interactive_predictor`` is not None;
    * all other keys are ignored.

    Loading is non-strict. Missing and unexpected keys are printed.
    """
    detector_prefix = "detector."
    tracker_prefix = "tracker."

    with g_pathmgr.open(checkpoint_path, "rb") as f:
        ckpt = torch.load(f, map_location="cpu", weights_only=True)
    if "model" in ckpt and isinstance(ckpt["model"], dict):
        ckpt = ckpt["model"]

    load_tracker = model.inst_interactive_predictor is not None
    sam3_image_ckpt = {}
    for raw_key, value in ckpt.items():
        key = _remap_checkpoint_key(raw_key)
        # Prefix matching (not substring search) so keys that merely contain
        # "detector"/"tracker" are neither selected nor mangled.
        if key.startswith(detector_prefix):
            sam3_image_ckpt[key[len(detector_prefix):]] = value
        elif load_tracker and key.startswith(tracker_prefix):
            new_key = "inst_interactive_predictor.model." + key[len(tracker_prefix):]
            sam3_image_ckpt[new_key] = value

    missing_keys, unexpected_keys = model.load_state_dict(sam3_image_ckpt, strict=False)
    if missing_keys or unexpected_keys:
        print(
            f"loaded {checkpoint_path} and found "
            f"missing and/or unexpected keys:\n{missing_keys=}\n{unexpected_keys=}"
        )
| def _setup_device_and_mode(model, device, eval_mode): | |
| """Setup model device and evaluation mode.""" | |
| device = torch.device(device) | |
| if device.type != "cpu": | |
| model = model.to(device=device) | |
| if eval_mode: | |
| model.eval() | |
| return model | |
def build_sam3_image_model(
    bpe_path=None,
    device=None,
    eval_mode=True,
    checkpoint_path=None,
    load_from_HF=True,
    enable_segmentation=True,
    enable_inst_interactivity=False,
    compile=False,
):
    """Build the SAM3 image model.

    Args:
        bpe_path: Path to the BPE tokenizer vocabulary. Defaults to the
            packaged ``assets/bpe_simple_vocab_16e6.txt.gz``.
        device: Device to place the model on. Defaults to
            ``get_accelerator_device()``.
        eval_mode: If True, the model is put in eval mode and no training
            matcher is built.
        checkpoint_path: Optional path to a checkpoint whose ``detector.``
            (and, with instance interactivity, ``tracker.``) keys are loaded.
        load_from_HF: If True and ``checkpoint_path`` is None, download
            ``sam3.pt`` from the ``facebook/sam3`` Hugging Face repo.
        enable_segmentation: Whether to build the segmentation head.
        enable_inst_interactivity: Whether to build the SAM-1-style
            interactive predictor. This also enables ``add_sam2_neck`` on the
            vision neck.
        compile: If True, the vision backbone and segmentation head are built
            with ``compile_mode="default"``.

    Returns:
        A ``Sam3Image`` model.
    """
    _setup_tf32()
    if device is None:
        device = get_accelerator_device()
    if bpe_path is None:
        bpe_path = pkg_resources.resource_filename(
            "sam3", "assets/bpe_simple_vocab_16e6.txt.gz"
        )

    compile_mode = "default" if compile else None
    vision_encoder = _create_vision_backbone(
        compile_mode=compile_mode, enable_inst_interactivity=enable_inst_interactivity
    )
    text_encoder = _create_text_encoder(bpe_path)
    backbone = _create_vl_backbone(vision_encoder, text_encoder)
    transformer = _create_sam3_transformer()
    dot_prod_scoring = _create_dot_product_scoring()
    segmentation_head = (
        _create_segmentation_head(compile_mode=compile_mode)
        if enable_segmentation
        else None
    )
    input_geometry_encoder = _create_geometry_encoder()

    if enable_inst_interactivity:
        sam3_pvs_base = build_tracker(apply_temporal_disambiguation=False)
        inst_predictor = SAM3InteractiveImagePredictor(sam3_pvs_base)
    else:
        inst_predictor = None

    model = _create_sam3_model(
        backbone,
        transformer,
        input_geometry_encoder,
        segmentation_head,
        dot_prod_scoring,
        inst_predictor,
        eval_mode,
    )

    if load_from_HF and checkpoint_path is None:
        checkpoint_path = download_ckpt_from_hf(version="sam3")
    if checkpoint_path is not None:
        _load_checkpoint(model, checkpoint_path)

    return _setup_device_and_mode(model, device, eval_mode)
def download_ckpt_from_hf(version="sam3"):
    """Download the SAM3 checkpoint from the Hugging Face Hub and return its local path.

    Args:
        version: ``"sam3"`` (default) or ``"sam3.1"``. Any value other than
            ``"sam3.1"`` downloads ``config.json`` and ``sam3.pt`` from
            ``facebook/sam3``.

    Raises:
        FileNotFoundError: for ``"sam3.1"``, which is loaded from a local
            safetensors file instead.
    """
    if version == "sam3.1":
        raise FileNotFoundError("SAM3.1 uses the local sam3.1_multiplex_bf16.safetensors checkpoint.")
    repo = "facebook/sam3"
    # The config is fetched too, although its local path is not used here.
    hf_hub_download(repo_id=repo, filename="config.json")
    return hf_hub_download(repo_id=repo, filename="sam3.pt")
def build_sam3_video_model(
    checkpoint_path: Optional[str] = None,
    load_from_HF=True,
    bpe_path: Optional[str] = None,
    has_presence_token: bool = True,
    geo_encoder_use_img_cross_attn: bool = True,
    strict_state_dict_loading: bool = True,
    apply_temporal_disambiguation: bool = True,
    device=None,
    compile=False,
) -> Sam3VideoInferenceWithInstanceInteractivity:
    """Build the SAM3 dense video tracking model (detector + tracker).

    Args:
        checkpoint_path: Optional path to a ``torch.load`` checkpoint.
        load_from_HF: If True and ``checkpoint_path`` is None, download
            ``sam3.pt`` from the Hugging Face Hub.
        bpe_path: Path to the BPE tokenizer file. Defaults to the packaged
            ``assets/bpe_simple_vocab_16e6.txt.gz``.
        has_presence_token: Forwarded to the transformer builder and used as
            the detector's ``supervise_joint_box_scores``.
        geo_encoder_use_img_cross_attn: Currently unused.
        strict_state_dict_loading: Passed as ``strict`` to ``load_state_dict``.
        apply_temporal_disambiguation: Enables the tracker's memory selection
            and the hot-start / periodic reconditioning heuristics. When False
            these heuristics are zeroed (ablation setting).
        device: Target device. Defaults to ``get_accelerator_device()``.
        compile: Forwarded as ``compile_model``.

    Returns:
        Sam3VideoInferenceWithInstanceInteractivity: the dense tracking model.
    """
    _setup_tf32()
    if device is None:
        device = get_accelerator_device()
    if bpe_path is None:
        bpe_path = pkg_resources.resource_filename(
            "sam3", "assets/bpe_simple_vocab_16e6.txt.gz"
        )

    tracker = build_tracker(apply_temporal_disambiguation=apply_temporal_disambiguation)

    # Detector components (construction order unchanged).
    visual_neck = _create_vision_backbone()
    text_encoder = _create_text_encoder(bpe_path)
    backbone = SAM3VLBackbone(scalp=1, visual=visual_neck, text=text_encoder)
    transformer = _create_sam3_transformer(has_presence_token=has_presence_token)
    segmentation_head: UniversalSegmentationHead = _create_segmentation_head()
    input_geometry_encoder = _create_geometry_encoder()
    # Same scorer configuration as the image model; reuse the shared builder.
    main_dot_prod_scoring = _create_dot_product_scoring()

    detector = Sam3ImageOnVideoMultiGPU(
        num_feature_levels=1,
        backbone=backbone,
        transformer=transformer,
        segmentation_head=segmentation_head,
        semantic_segmentation_head=None,
        input_geometry_encoder=input_geometry_encoder,
        use_early_fusion=True,
        use_dot_prod_scoring=True,
        dot_prod_scoring=main_dot_prod_scoring,
        supervise_joint_box_scores=has_presence_token,
    )

    # The two configurations differ only in these heuristic settings.
    if apply_temporal_disambiguation:
        heuristics = dict(
            hotstart_delay=15,
            hotstart_unmatch_thresh=8,
            hotstart_dup_thresh=8,
            recondition_every_nth_frame=16,
        )
    else:
        # a version without any heuristics for ablation studies
        heuristics = dict(
            hotstart_delay=0,
            hotstart_unmatch_thresh=0,
            hotstart_dup_thresh=0,
            recondition_every_nth_frame=0,
        )
    model = Sam3VideoInferenceWithInstanceInteractivity(
        detector=detector,
        tracker=tracker,
        score_threshold_detection=0.5,
        assoc_iou_thresh=0.1,
        det_nms_thresh=0.1,
        new_det_thresh=0.7,
        suppress_unmatched_only_within_hotstart=True,
        min_trk_keep_alive=-1,
        max_trk_keep_alive=30,
        init_trk_keep_alive=30,
        suppress_overlapping_based_on_recent_occlusion_threshold=0.7,
        suppress_det_close_to_boundary=False,
        fill_hole_area=16,
        masklet_confirmation_enable=False,
        decrease_trk_keep_alive_for_empty_masklets=False,
        image_size=1008,
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
        compile_model=compile,
        **heuristics,
    )

    if load_from_HF and checkpoint_path is None:
        checkpoint_path = download_ckpt_from_hf(version="sam3")
    if checkpoint_path is not None:
        with g_pathmgr.open(checkpoint_path, "rb") as f:
            ckpt = torch.load(f, map_location="cpu", weights_only=True)
        if "model" in ckpt and isinstance(ckpt["model"], dict):
            ckpt = ckpt["model"]
        missing_keys, unexpected_keys = model.load_state_dict(
            ckpt, strict=strict_state_dict_loading
        )
        if missing_keys:
            print(f"Missing keys: {missing_keys}")
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys}")

    model.to(device=device)
    return model
def build_sam3_video_predictor(*model_args, gpus_to_use=None, **model_kwargs):
    """Construct a ``Sam3VideoPredictorMultiGPU``, forwarding all arguments unchanged."""
    predictor = Sam3VideoPredictorMultiGPU(
        *model_args, gpus_to_use=gpus_to_use, **model_kwargs
    )
    return predictor
def _create_multiplex_maskmem_backbone(multiplex_count=16):
    """Build the multiplex mask-memory encoder (256-dim output).

    Args:
        multiplex_count: Forwarded to ``SimpleMaskDownSampler``.
    """
    memory_pos_enc = PositionEmbeddingSine(
        num_pos_feats=256,
        normalize=True,
        scale=None,
        temperature=10000,
        precompute_resolution=1008,
    )
    downsampler = SimpleMaskDownSampler(
        kernel_size=3,
        stride=2,
        padding=1,
        interpol_size=[1152, 1152],
        multiplex_count=multiplex_count,
        starting_out_chan=4,
        input_channel_multiplier=2,
    )
    fuser_block = CXBlock(
        dim=256,
        kernel_size=7,
        padding=3,
        layer_scale_init_value=1.0e-06,
        use_dwconv=True,
    )
    return SimpleMaskEncoder(
        out_dim=256,
        position_encoding=memory_pos_enc,
        mask_downsampler=downsampler,
        fuser=SimpleFuser(layer=fuser_block, num_layers=2),
    )
def _create_multiplex_transformer(use_fa3=False, use_rope_real=True):
    """Build the decoupled memory-attention transformer for the multiplex tracker.

    Four ``DecoupledTransformerDecoderLayerv2`` layers (8 heads, GELU) with
    RoPE self- and cross-attention, wrapped in a decoder-less
    ``TransformerWrapper``.

    Args:
        use_fa3: Forwarded to both RoPE attention modules.
        use_rope_real: Forwarded to both RoPE attention modules.
    """

    def rope_attention(**extra):
        # A fresh feat_sizes list per module, as the original constructs two lists.
        return SimpleRoPEAttention(
            d_model=256,
            num_heads=8,
            dropout_p=0.1,
            rope_theta=10000.0,
            feat_sizes=[72, 72],
            use_fa3=use_fa3,
            use_rope_real=use_rope_real,
            **extra,
        )

    self_rope = rope_attention()
    cross_rope = rope_attention(rope_k_repeat=True)
    layer = DecoupledTransformerDecoderLayerv2(
        activation="gelu",
        d_model=256,
        num_heads=8,
        dropout=0.1,
        dim_feedforward=2048,
        pos_enc_at_attn=False,
        pre_norm=True,
        pos_enc_at_cross_attn_keys=True,
        pos_enc_at_cross_attn_queries=False,
        self_attention_rope=self_rope,
        cross_attention_rope=cross_rope,
    )
    memory_encoder = TransformerEncoderDecoupledCrossAttention(
        d_model=256,
        frozen=False,
        pos_enc_at_input=True,
        use_image_in_output=False,
        layer=layer,
        num_layers=4,
        use_act_checkpoint=False,
        batch_first=True,
    )
    return TransformerWrapper(encoder=memory_encoder, decoder=None, d_model=256)
def _create_multiplex_tri_backbone(
    compile_mode=None, use_fa3=False, use_rope_real=True, init_device="cpu"
):
    """Build the tri-head ViTDet vision backbone used by the multiplex model.

    Args:
        compile_mode: Forwarded to the ViT trunk.
        use_fa3: Forwarded to the ViT trunk.
        use_rope_real: Forwarded to the ViT trunk.
        init_device: Device made the default while constructing the modules
            (None disables the device context).
    """
    trunk_options = dict(
        compile_mode=compile_mode, use_fa3=use_fa3, use_rope_real=use_rope_real
    )
    with _device_context(init_device):
        pos_enc = _create_position_encoding(precompute_resolution=1008)
        trunk = _create_vit_backbone(**trunk_options)
        neck = Sam3TriViTDetNeck(
            trunk=trunk,
            position_encoding=pos_enc,
            d_model=256,
            scale_factors=[4.0, 2.0, 1.0],
        )
    return neck
def build_sam3_multiplex_video_model(
    checkpoint_path: Optional[str] = None,
    load_from_HF=True,
    multiplex_count: int = 16,
    use_fa3: bool = False,
    use_rope_real: bool = True,
    trim_past_non_cond_mem_for_eval: bool = False,
    strict_state_dict_loading: bool = True,
    device=None,
    init_device="cpu",
    move_to_device: bool = True,
    compile=False,
):
    """
    Build the SAM3 multiplex video tracking model (module graph only, no weights).

    This function never loads a checkpoint. Two call shapes are rejected:
    - a non-None ``checkpoint_path``;
    - ``load_from_HF=True`` with no checkpoint (the default arguments).

    The only accepted shape is therefore ``load_from_HF=False,
    checkpoint_path=None``. That is how
    ``build_sam3_multiplex_video_predictor`` calls it before loading the
    safetensors weights into the assembled model.

    Args:
        checkpoint_path: Must be None; any other value raises ``ValueError``.
        load_from_HF: Must be False. If True (the default) while
            ``checkpoint_path`` is None, ``FileNotFoundError`` is raised.
        multiplex_count: Number of objects per multiplex bucket. Passed to
            the mask-memory backbone builder, and used as both
            ``multiplex_count`` and ``eval_multiplex_count`` of the
            ``MultiplexController``.
        use_fa3: Whether to use FlashAttention 3. Forwarded to the
            memory-transformer and vision-backbone builders.
        use_rope_real: Whether to use real-valued RoPE (for compile compat).
        trim_past_non_cond_mem_for_eval: Forwarded to the model constructor.
        strict_state_dict_loading: Not read by this function; kept for
            interface compatibility.
        device: Target for ``model.to`` when ``move_to_device`` is True.
            When None, ``get_accelerator_device()`` is used.
        init_device: Device used while constructing modules.
            - ``"meta"``: builds under
              ``init_empty_weights(include_buffers=True)`` and
              ``_default_dtype(torch.bfloat16)``.
            - Any other value: passed to ``_device_context``.
        move_to_device: Whether to call ``model.to(device=device)`` before
            returning.
            NOTE(review): combined with ``init_device="meta"`` this moves
            empty-weight tensors, which PyTorch generally cannot copy.
            Confirm that combination is never used.
        compile: If True, the vision trunk is built with
            ``compile_mode="max-autotune"`` and ``compile_all_components=True``
            is passed to the model.

    Returns:
        Sam3VideoTrackingMultiplexDemo: The instantiated multiplex tracking
        model. This is presumably a ``VideoTrackingDynamicMultiplex``
        subclass, but that cannot be confirmed from this file.

    Raises:
        FileNotFoundError: If ``load_from_HF`` is True and ``checkpoint_path``
            is None.
        ValueError: If ``checkpoint_path`` is not None.
    """
    _setup_tf32()
    if load_from_HF and checkpoint_path is None:
        raise FileNotFoundError(
            "SAM3.1 uses the local sam3.1_multiplex_bf16.safetensors checkpoint."
        )
    if checkpoint_path is not None:
        raise ValueError(
            "Standalone SAM3.1 tracker checkpoint loading is not supported; "
            "use build_sam3_multiplex_video_predictor for safetensor loading."
        )
    if move_to_device and device is None:
        device = get_accelerator_device()
    # With init_device == "meta":
    # - modules are built under accelerate's init_empty_weights (buffers
    #   included) and _default_dtype(bfloat16);
    # - _device_context receives None instead of "meta".
    # Otherwise only _device_context(init_device) is active.
    construct_on_meta = init_device == "meta"
    empty_context = init_empty_weights(include_buffers=True) if construct_on_meta else nullcontext()
    dtype_context = _default_dtype(torch.bfloat16) if construct_on_meta else nullcontext()
    device_context = _device_context(None if construct_on_meta else init_device)
    with empty_context, dtype_context, device_context:
        # Build multiplex-specific components
        maskmem_backbone = _create_multiplex_maskmem_backbone(
            multiplex_count=multiplex_count
        )
        transformer = _create_multiplex_transformer(
            use_fa3=use_fa3, use_rope_real=use_rope_real
        )
        # init_device=None here: presumably so placement is governed by the
        # contexts entered above rather than a nested device context.
        tri_neck = _create_multiplex_tri_backbone(
            compile_mode="max-autotune" if compile else None,
            use_fa3=use_fa3,
            use_rope_real=use_rope_real,
            init_device=None,
        )
        backbone = TriHeadVisionOnly(
            visual=tri_neck,
            n_features=256,
            scalp=0,
        )
        multiplex_controller = MultiplexController(
            multiplex_count=multiplex_count,
            eval_multiplex_count=multiplex_count,
        )
        # Build the multiplex model (use demo class for init_state and other demo methods)
        from .model.video_tracking_multiplex_demo import Sam3VideoTrackingMultiplexDemo
        model = Sam3VideoTrackingMultiplexDemo(
            backbone=backbone,
            transformer=transformer,
            maskmem_backbone=maskmem_backbone,
            multiplex_controller=multiplex_controller,
            image_size=1008,
            backbone_stride=14,
            num_maskmem=7,
            # Multiplex-specific settings
            use_high_res_features_in_sam=True,
            use_obj_ptrs_in_encoder=True,
            max_obj_ptrs_in_encoder=16,
            add_tpos_enc_to_obj_ptrs=True,
            proj_tpos_enc_in_obj_ptrs=True,
            use_mlp_for_obj_ptr_proj=True,
            pred_obj_scores=True,
            pred_obj_scores_mlp=True,
            fixed_no_obj_ptr=True,
            use_no_obj_ptr=True,
            use_linear_no_obj_ptr=True,
            no_obj_embed_spatial=True,
            sincos_tpos_enc=True,
            # Multimask settings
            multimask_output_in_sam=True,
            multimask_output_for_tracking=True,
            multimask_min_pt_num=0,
            multimask_max_pt_num=1,
            use_multimask_token_for_obj_ptr=True,
            num_multimask_outputs=3,
            # Memory encoder settings
            apply_sigmoid_to_mask_logits_for_mem_enc=True,
            sigmoid_scale_for_mem_enc=2.0,
            sigmoid_bias_for_mem_enc=-1.0,
            non_overlap_masks_for_mem_enc=False,
            # Suppression/conditional embeddings
            add_output_suppression_embeddings=True,
            add_object_conditional_embeddings=False,
            condition_as_mask_input=True,
            condition_as_mask_input_fg=1.0,
            condition_as_mask_input_bg=0.0,
            # Memory settings
            use_maskmem_tpos_v2=True,
            save_image_features=True,
            randomness_fix=True,
            # Interaction settings
            use_mask_input_as_output_without_sam=True,
            directly_add_no_mem_embed=True,
            iou_prediction_use_sigmoid=False,
            forward_backbone_per_frame_for_eval=True,
            offload_output_to_cpu_for_eval=False,
            trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
            max_cond_frames_in_attn=4,
            # Dynamic multiplex settings
            is_dynamic_model=True,
            # SAM mask decoder extra args
            sam_mask_decoder_extra_args={
                "dynamic_multimask_via_stability": True,
                "dynamic_multimask_stability_delta": 0.05,
                "dynamic_multimask_stability_thresh": 0.98,
            },
            compile_all_components=compile,
            use_memory_selection=False,
        )
    if move_to_device:
        model.to(device=device)
    return model
def build_sam3_text_encoder(
    checkpoint_path: Optional[str] = None,
    bpe_path: Optional[str] = None,
):
    """Build the SAM3.1 text encoder and load its weights from the merged checkpoint.

    Steps:
    1. Construct the encoder under ``init_empty_weights(include_buffers=True)``
       and ``_default_dtype(torch.bfloat16)``.
    2. Refresh its buffers on CPU via ``_refresh_text_encoder_buffers_``.
    3. Load only the checkpoint tensors under
       ``detector.backbone.language_backbone.``, with that prefix stripped.

    Args:
        checkpoint_path: Path to ``sam3.1_multiplex_bf16.safetensors``.
            Required.
        bpe_path: Path to the BPE vocabulary. When None, the lookup order is:
            1. ``assets/bpe_simple_vocab_16e6.txt.gz`` next to this module,
               if that file exists;
            2. otherwise the same resource inside the importable ``sam3``
               package.

    Returns:
        The text encoder, switched to eval mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is None.
    """
    # Fail fast: validate before touching global TF32 state or resolving
    # package resources.
    if checkpoint_path is None:
        raise FileNotFoundError(
            "SAM3.1 text encoding requires sam3.1_multiplex_bf16.safetensors."
        )
    _setup_tf32()
    if bpe_path is None:
        # `pkg_resources` is deprecated in setuptools and is not imported in
        # this module's header, so resolve the vocabulary with
        # pathlib / importlib.resources instead.
        local_bpe = (
            Path(__file__).resolve().parent / "assets" / "bpe_simple_vocab_16e6.txt.gz"
        )
        if local_bpe.is_file():
            bpe_path = str(local_bpe)
        else:
            import importlib.resources

            bpe_path = str(
                importlib.resources.files("sam3")
                / "assets"
                / "bpe_simple_vocab_16e6.txt.gz"
            )
    target_device = torch.device("cpu")
    with init_empty_weights(include_buffers=True), _default_dtype(torch.bfloat16):
        text_encoder = _create_text_encoder(bpe_path, init_device=None)
    # Buffers are refreshed outside the empty-weights context so they are
    # materialized on target_device rather than created empty.
    _refresh_text_encoder_buffers_(text_encoder, target_device, torch.bfloat16)
    prefix = "detector.backbone.language_backbone."
    _load_sam3_safetensors_model(
        text_encoder,
        checkpoint_path,
        include_prefixes=(prefix,),
        strip_prefix=prefix,
    )
    return text_encoder.eval()
def build_sam3_multiplex_video_predictor(
    checkpoint_path: Optional[str] = None,
    bpe_path: Optional[str] = None,
    max_num_objects: int = 16,
    multiplex_count: int = 16,
    use_fa3: bool = True,
    use_rope_real: bool = True,
    compile: bool = False,
    warm_up: bool = False,
    session_expiration_sec: int = 1200,
    default_output_prob_thresh: float = 0.5,
    async_loading_frames: bool = True,
    include_text_encoder: bool = True,
    postprocess_batch_size: int = 16,
    use_batched_grounding: bool = True,
    batched_grounding_batch_size: int = 16,
    trim_past_non_cond_mem_for_eval: bool = False,
    fill_hole_area: int = 0,
    manual_model_loading: bool = False,
):
    """
    Build a fully-initialized Sam3MultiplexVideoPredictor.

    This is the recommended entry point for SAM 3.1 multiplex video tracking.
    Steps:
    1. Build the full model stack (tracker + detector + demo model) under
       ``init_empty_weights``.
    2. Load the safetensors checkpoint on CPU.
    3. Wrap everything in Sam3MultiplexVideoPredictor, which exposes the
       handle_request / handle_stream_request API.

    Args:
        checkpoint_path: Path to the merged multiplex checkpoint
            (``sam3.1_multiplex_bf16.safetensors``). Required.
        bpe_path: Path to the BPE tokenizer vocabulary. Only used when
            ``include_text_encoder`` is True. When None, the lookup order is:
            1. ``assets/bpe_simple_vocab_16e6.txt.gz`` next to this module,
               if present;
            2. otherwise the copy in the importable ``sam3`` package.
        max_num_objects: Maximum number of tracked objects.
        multiplex_count: Number of objects per multiplex bucket.
        use_fa3: Whether to use FlashAttention 3.
        use_rope_real: Whether to use real-valued RoPE (for compile compat).
        compile: Passed as ``compile_model`` to the demo model. The tracker
            itself is always built with ``compile=False``.
        warm_up: Whether to run warm-up compilation (requires compile=True).
        session_expiration_sec: Session expiration timeout in seconds.
        default_output_prob_thresh: Default probability threshold for output
            masks.
        async_loading_frames: Whether to load frames asynchronously.
        include_text_encoder: Whether to build and load the text encoder.
            When False:
            - the detector backbone gets ``text=None``;
            - checkpoint keys under ``detector.backbone.language_backbone.``
              are excluded from loading.
        postprocess_batch_size: Forwarded to the demo model.
        use_batched_grounding: Forwarded to the demo model.
        batched_grounding_batch_size: Forwarded to the demo model.
        trim_past_non_cond_mem_for_eval: Forwarded to the tracker model.
        fill_hole_area: Accepted for interface compatibility but ignored.
            Tracker and demo model are always built with ``fill_hole_area=0``.
        manual_model_loading: Forwarded to Sam3MultiplexVideoPredictor.

    Returns:
        Sam3MultiplexVideoPredictor: The fully-initialized predictor.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is None. This is checked
            before any model construction.
    """
    # Validate before constructing anything: building the tracker + detector
    # graph is expensive and pointless without weights to load into it.
    if checkpoint_path is None:
        raise FileNotFoundError("SAM3.1 requires sam3.1_multiplex_bf16.safetensors.")
    if bpe_path is None and include_text_encoder:
        # The vocabulary is only consumed by the text encoder. `pkg_resources`
        # is deprecated in setuptools and is not imported in this module's
        # header, so resolve it with pathlib / importlib.resources instead.
        local_bpe = (
            Path(__file__).resolve().parent / "assets" / "bpe_simple_vocab_16e6.txt.gz"
        )
        if local_bpe.is_file():
            bpe_path = str(local_bpe)
        else:
            import importlib.resources

            bpe_path = str(
                importlib.resources.files("sam3")
                / "assets"
                / "bpe_simple_vocab_16e6.txt.gz"
            )
    from .model.sam3_multiplex_base import Sam3MultiplexPredictorWrapper
    from .model.sam3_multiplex_detector import Sam3MultiplexDetector
    from .model.sam3_multiplex_tracking import (
        Sam3MultiplexTrackingWithInteractivity,
    )
    from .model.sam3_multiplex_video_predictor import Sam3MultiplexVideoPredictor

    target_device = torch.device("cpu")
    with init_empty_weights(include_buffers=True), _default_dtype(torch.bfloat16):
        # Build tracker
        tracker_model = build_sam3_multiplex_video_model(
            checkpoint_path=None,
            load_from_HF=False,
            multiplex_count=multiplex_count,
            use_fa3=use_fa3,
            use_rope_real=use_rope_real,
            trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
            compile=False,
            strict_state_dict_loading=False,
            init_device=None,
            move_to_device=False,
        )
        # Discard the tracker's own vision backbone; only the detector below
        # is given one.
        del tracker_model.backbone
        tracker_model.backbone = None
        sam2_predictor = Sam3MultiplexPredictorWrapper(
            model=tracker_model,
            per_obj_inference=False,
            # Keep tracker fill disabled; MatAnyone applies hole filling to final binary masks.
            fill_hole_area=0,
            is_multiplex=True,
            is_multiplex_dynamic=True,
        )
        # Build detector
        tri_neck = _create_multiplex_tri_backbone(
            compile_mode=None, use_fa3=use_fa3, use_rope_real=use_rope_real, init_device=None
        )
        text_encoder = _create_text_encoder(bpe_path, init_device=None) if include_text_encoder else None
        backbone = SAM3VLBackboneTri(scalp=0, visual=tri_neck, text=text_encoder)
        transformer = _create_sam3_transformer(use_fa3=use_fa3)
        segmentation_head = _create_segmentation_head(use_fa3=use_fa3)
        geometry_encoder = _create_geometry_encoder()
        dot_prod_scoring = _create_dot_product_scoring()
        detector = Sam3MultiplexDetector(
            num_feature_levels=1,
            backbone=backbone,
            transformer=transformer,
            segmentation_head=segmentation_head,
            semantic_segmentation_head=None,
            input_geometry_encoder=geometry_encoder,
            use_early_fusion=True,
            use_dot_prod_scoring=True,
            dot_prod_scoring=dot_prod_scoring,
            supervise_joint_box_scores=True,
            is_multiplex=True,
        )
        # Assemble demo model
        demo_model = Sam3MultiplexTrackingWithInteractivity(
            tracker=sam2_predictor,
            detector=detector,
            score_threshold_detection=0.4,
            det_nms_thresh=0.1,
            det_nms_use_iom=True,
            assoc_iou_thresh=0.1,
            new_det_thresh=0.65,
            hotstart_delay=15,
            hotstart_unmatch_thresh=8,
            hotstart_dup_thresh=8,
            suppress_unmatched_only_within_hotstart=False,
            suppress_overlapping_based_on_recent_occlusion_threshold=0.7,
            suppress_det_close_to_boundary=True,
            # Keep tracker fill disabled; MatAnyone applies hole filling to final binary masks.
            fill_hole_area=0,
            recondition_every_nth_frame=16,
            use_iom_recondition=True,
            iom_thresh_recondition=0.5,
            masklet_confirmation_enable=True,
            reconstruction_bbox_iou_thresh=-1,
            reconstruction_bbox_det_score=0.8,
            max_num_objects=max_num_objects,
            postprocess_batch_size=postprocess_batch_size,
            use_batched_grounding=use_batched_grounding,
            batched_grounding_batch_size=batched_grounding_batch_size,
            max_num_kboxes=0,
            sprinkle_removal_area=0,
            is_multiplex=True,
            image_size=1008,
            image_mean=(0.5, 0.5, 0.5),
            image_std=(0.5, 0.5, 0.5),
            compile_model=compile,
        )
    # Load checkpoint on CPU; CUDA residency is handled just in time by the predictor.
    # Buffers are refreshed outside the empty-weights context so they are
    # materialized on target_device.
    exclude_prefixes = (
        ("detector.backbone.language_backbone.",) if not include_text_encoder else None
    )
    _refresh_rope_buffers_(demo_model, target_device, torch.bfloat16)
    if include_text_encoder:
        _refresh_text_encoder_buffers_(demo_model, target_device, torch.bfloat16)
    _load_sam3_safetensors_model(
        demo_model, checkpoint_path, exclude_prefixes=exclude_prefixes
    )
    demo_model.eval()
    # Wrap in predictor
    predictor = Sam3MultiplexVideoPredictor(
        model=demo_model,
        session_expiration_sec=session_expiration_sec,
        default_output_prob_thresh=default_output_prob_thresh,
        async_loading_frames=async_loading_frames,
        warm_up=warm_up,
        manual_model_loading=manual_model_loading,
    )
    return predictor
def build_sam3_predictor(
    checkpoint_path: Optional[str] = None,
    bpe_path: Optional[str] = None,
    version: str = "sam3.1",  # "sam3" or "sam3.1"
    compile: bool = False,
    warm_up: bool = False,
    # SAM 3.1 specific
    max_num_objects: int = 16,
    multiplex_count: int = 16,
    # Common
    use_fa3: bool = True,
    use_rope_real: bool = True,
    async_loading_frames: bool = True,
    include_text_encoder: bool = True,
    trim_past_non_cond_mem_for_eval: bool = False,
    **kwargs,
):
    """
    Build a SAM3 video predictor.

    Args:
        checkpoint_path: Path to the model checkpoint. For ``version="sam3.1"``
            this is required: the SAM 3.1 builder raises
            ``FileNotFoundError`` when it is None.
        bpe_path: Path to BPE tokenizer vocabulary.
        version: Model version, either "sam3" (base) or "sam3.1" (multiplex).
        compile: Enable torch.compile for ~2x speedup (SAM 3.1 only currently).
        warm_up: Run warm-up compilation passes (forwarded for SAM 3.1 only).
        max_num_objects: Maximum tracked objects (SAM 3.1 only).
        multiplex_count: Objects per multiplex bucket (SAM 3.1 only).
        use_fa3: Use Flash Attention 3 (forwarded for SAM 3.1 only).
        use_rope_real: Use real-valued RoPE (forwarded for SAM 3.1 only).
        async_loading_frames: Load video frames asynchronously.
        include_text_encoder: Build/load the text encoder (SAM 3.1 only).
        trim_past_non_cond_mem_for_eval: Forwarded to the SAM 3.1 tracker
            only.
        **kwargs: Additional arguments passed to the underlying builder.

    For ``version="sam3"`` only the following are forwarded:
    ``checkpoint_path``, ``bpe_path``, ``compile``,
    ``async_loading_frames`` and ``**kwargs``. All other arguments are
    ignored.

    Returns:
        A predictor with handle_request() and handle_stream_request() API.
        Both versions support: start_session, add_prompt, propagate_in_video,
        remove_object, reset_session, close_session.

    Raises:
        ValueError: If ``version`` is neither "sam3" nor "sam3.1".

    Example:
        # SAM 3.1 (requires a local checkpoint):
        predictor = build_sam3_predictor(
            checkpoint_path="path/to/sam3.1_multiplex_bf16.safetensors",
            version="sam3.1",
            compile=True,
        )
        # SAM 3 (checkpoint handling is delegated to build_sam3_video_predictor):
        predictor = build_sam3_predictor(version="sam3")
        # Both use the same API:
        response = predictor.handle_request({"type": "start_session", "resource_path": video_dir})
        session_id = response["session_id"]
        predictor.handle_request({"type": "add_prompt", "session_id": session_id, "frame_index": 0, "text": "person"})
        for out in predictor.handle_stream_request({"type": "propagate_in_video", "session_id": session_id}):
            masks = out["out_binary_masks"]
    """
    if version == "sam3.1":
        return build_sam3_multiplex_video_predictor(
            checkpoint_path=checkpoint_path,
            bpe_path=bpe_path,
            max_num_objects=max_num_objects,
            multiplex_count=multiplex_count,
            use_fa3=use_fa3,
            use_rope_real=use_rope_real,
            compile=compile,
            warm_up=warm_up,
            async_loading_frames=async_loading_frames,
            include_text_encoder=include_text_encoder,
            trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
            **kwargs,
        )
    if version == "sam3":
        return build_sam3_video_predictor(
            checkpoint_path=checkpoint_path,
            bpe_path=bpe_path,
            compile=compile,
            async_loading_frames=async_loading_frames,
            **kwargs,
        )
    raise ValueError(f"Unknown version: {version!r}. Use 'sam3' or 'sam3.1'.")
Xet Storage Details
- Size: 49.7 kB
- Xet hash: a15e26b06e3d93636ff17bc155ac2fe9891c053820bd217d19985e5a696be348
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.