| import importlib | |
| import importlib.util | |
| import os | |
| from pathlib import Path | |
| from typing import Iterable | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from shared.utils import files_locator as fl | |
| from .logger import get_logger | |
| from .model.device_utils import accelerator_autocast, empty_accelerator_cache, get_accelerator_device, is_accelerator_device | |
# Module-level logger used for the SAM3 progress messages emitted below.
logger = get_logger(__name__)

# Directory containing this module; probed last when locating bundled assets.
_PACKAGE_ROOT: Path = Path(__file__).resolve().parent

# Folder and file names probed when locating the SAM3.1 checkpoint and the
# BPE vocabulary.
_SAM3_FOLDER: str = "sam3"
_SAM3_CHECKPOINT_NAME: str = "sam3.1_multiplex_bf16.safetensors"
_SAM3_BPE_NAME: str = "bpe_simple_vocab_16e6.txt.gz"

# Default value of ``run_sam3_video(keep_video_frames_on_cuda=...)``.
KEEP_VIDEO_FRAMES_ON_CUDA: bool = True

# Text encoder optionally kept between calls, together with the
# ``(checkpoint_path, bpe_path)`` pair it was built from.
_TEXT_ENCODER_CACHE = None
_TEXT_ENCODER_CACHE_KEY = None
def _cleanup():
    """Run a garbage-collection pass, then call ``empty_accelerator_cache()``."""
    from gc import collect as _collect_garbage

    _collect_garbage()
    empty_accelerator_cache()
def _load_model_builder():
    """Import and return the sibling ``model_builder`` module.

    Raises:
        FileNotFoundError: if the ``model_builder`` module itself is missing.
        ModuleNotFoundError: re-raised unchanged when the missing module is
            some other dependency imported while loading ``model_builder``.
    """
    relative_name = ".model_builder"
    try:
        builder = importlib.import_module(relative_name, package=__package__)
    except ModuleNotFoundError as exc:
        # Only translate the error when the builder module is the one missing;
        # a missing transitive dependency must surface as-is.
        if exc.name == importlib.util.resolve_name(relative_name, __package__):
            raise FileNotFoundError("SAM3.1 code was not found under preprocessing/sam3.")
        raise
    return builder
def _checkpoint_path():
    """Locate the SAM3.1 checkpoint and return ``(path, version)``.

    The relative names are tried through ``fl.locate_file`` in order; the
    first hit wins. If none is found, the file next to this module is used.
    The version component is always ``"sam3.1"``.

    Raises:
        FileNotFoundError: if the checkpoint is found nowhere.
    """
    version = "sam3.1"
    search_names = (
        os.path.join(_SAM3_FOLDER, _SAM3_CHECKPOINT_NAME),
        os.path.join("sam3.1", _SAM3_CHECKPOINT_NAME),
        _SAM3_CHECKPOINT_NAME,
    )
    for name in search_names:
        located = fl.locate_file(name, error_if_none=False)
        if located is None:
            continue
        return located, version

    # Fall back to a checkpoint shipped alongside this module.
    bundled = _PACKAGE_ROOT / _SAM3_CHECKPOINT_NAME
    if not bundled.is_file():
        raise FileNotFoundError("SAM3.1 bf16 safetensors checkpoint was not found by files_locator as sam3/sam3.1_multiplex_bf16.safetensors, sam3.1/sam3.1_multiplex_bf16.safetensors, or sam3.1_multiplex_bf16.safetensors, nor under preprocessing/sam3.")
    return os.fspath(bundled), version
def _bpe_path():
    """Locate the SAM3 BPE vocabulary file and return its path.

    The relative names are tried through ``fl.locate_file`` in order; the
    first hit wins. If none is found, the copy under this module's ``assets``
    directory is used.

    Raises:
        FileNotFoundError: if the vocabulary is found nowhere.
    """
    search_names = (
        os.path.join(_SAM3_FOLDER, _SAM3_BPE_NAME),
        os.path.join("sam3.1", _SAM3_BPE_NAME),
        _SAM3_BPE_NAME,
    )
    # Generator keeps the lookups lazy: probing stops at the first hit.
    lookups = (fl.locate_file(name, error_if_none=False) for name in search_names)
    located = next((hit for hit in lookups if hit is not None), None)
    if located is not None:
        return located

    bundled = _PACKAGE_ROOT / "assets" / _SAM3_BPE_NAME
    if not bundled.is_file():
        raise FileNotFoundError("SAM3 BPE vocabulary was not found by files_locator as sam3/bpe_simple_vocab_16e6.txt.gz, sam3.1/bpe_simple_vocab_16e6.txt.gz, or bpe_simple_vocab_16e6.txt.gz, nor under preprocessing/sam3/assets.")
    return os.fspath(bundled)
def _autocast_context():
    """Return the object produced by ``accelerator_autocast()``.

    Callers in this module use the result as a ``with`` context manager.
    """
    context = accelerator_autocast()
    return context
| def _bf16_prompt_payload(value): | |
| if torch.is_tensor(value): | |
| return value.to(dtype=torch.bfloat16) if value.is_floating_point() else value | |
| if isinstance(value, dict): | |
| return {key: _bf16_prompt_payload(item) for key, item in value.items()} | |
| if isinstance(value, list): | |
| return [_bf16_prompt_payload(item) for item in value] | |
| if isinstance(value, tuple): | |
| return tuple(_bf16_prompt_payload(item) for item in value) | |
| return value | |
| def _format_keywords_for_log(keywords: list[str]): | |
| return ", ".join(f"'{keyword}'" for keyword in keywords) | |
| def _to_numpy(value): | |
| if torch.is_tensor(value): | |
| return value.detach().cpu().numpy() | |
| return np.asarray(value) | |
| def _sam3_outputs_to_binary_mask(outputs, height: int, width: int): | |
| if outputs is None or "out_binary_masks" not in outputs: | |
| return np.zeros((height, width), dtype=np.bool_) | |
| masks = _to_numpy(outputs["out_binary_masks"]) | |
| if masks.size == 0: | |
| return np.zeros((height, width), dtype=np.bool_) | |
| if masks.ndim == 2: | |
| masks = masks[None, :, :] | |
| elif masks.ndim == 4 and masks.shape[1] == 1: | |
| masks = masks[:, 0] | |
| elif masks.ndim > 3: | |
| masks = masks.reshape((-1, *masks.shape[-2:])) | |
| if masks.shape[-2:] != (height, width): | |
| masks = np.stack([np.asarray(Image.fromarray(mask.astype(np.uint8)).resize((width, height), resample=Image.Resampling.NEAREST)) for mask in masks], axis=0) | |
| return masks.astype(bool).any(axis=0) | |
def resolve_sam3_grounding_batch_size(batch_size=None) -> int:
    """Return the grounding batch size to use.

    A positive *batch_size* (after ``int()`` conversion) is returned as-is.
    Otherwise the value is chosen automatically: 2 without CUDA, and with
    CUDA 4 when device 0 reports at least 8 GiB of total memory, else 2.
    """
    requested = 0 if batch_size is None else int(batch_size)
    if requested > 0:
        return requested
    if torch.cuda.is_available():
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        return 4 if total_bytes / (1024 ** 3) >= 8 else 2
    return 2
def _encode_text_outputs(text_encoder, captions: list[str], device: torch.device):
    """Encode each caption separately and concatenate the results on the CPU.

    When *device* is an accelerator device the encoder is first moved there
    in bfloat16. Each caption is encoded on its own under inference mode and
    autocast, and its three outputs are detached and copied to the CPU.

    Returns:
        A dict with ``"language_features"`` and ``"language_embeds"``
        (concatenated along dim 1) and ``"language_mask"`` (concatenated
        along dim 0).
    """
    if is_accelerator_device(device):
        text_encoder.to(device=device, dtype=torch.bfloat16)

    collected = {"language_features": [], "language_mask": [], "language_embeds": []}
    for caption in captions:
        with torch.inference_mode(), _autocast_context():
            attention_mask, memory, embedding = text_encoder([caption], device=device)
        collected["language_mask"].append(attention_mask.detach().cpu())
        collected["language_features"].append(memory.detach().cpu())
        collected["language_embeds"].append(embedding.detach().cpu())
        # Drop the device-side tensors before encoding the next caption.
        del attention_mask, memory, embedding
        _cleanup()

    return {
        "language_features": torch.cat(collected["language_features"], dim=1),
        "language_mask": torch.cat(collected["language_mask"], dim=0),
        "language_embeds": torch.cat(collected["language_embeds"], dim=1),
    }
def _encode_keyword_prompts(model_builder, checkpoint_path: str, bpe_path: str, keywords: list[str], keep_text_encoder_loaded: bool = False):
    """Encode every keyword with the SAM3 text encoder.

    Each keyword is encoded together with the fixed captions ``"visual"`` and
    ``"geometric"``. With *keep_text_encoder_loaded* the encoder is reused
    from (or stored in) the module-level cache, keyed by
    ``(checkpoint_path, bpe_path)``, and moved to the CPU afterwards;
    otherwise it is built fresh and dropped once encoding finishes.

    Returns:
        A dict mapping each keyword to the dict produced by
        ``_encode_text_outputs``.
    """
    global _TEXT_ENCODER_CACHE, _TEXT_ENCODER_CACHE_KEY
    encoder = None
    target_device = get_accelerator_device()
    requested_key = (checkpoint_path, bpe_path)
    results = {}
    try:
        cache_hit = keep_text_encoder_loaded and _TEXT_ENCODER_CACHE is not None and _TEXT_ENCODER_CACHE_KEY == requested_key
        if cache_hit:
            encoder = _TEXT_ENCODER_CACHE
        else:
            encoder = model_builder.build_sam3_text_encoder(checkpoint_path=checkpoint_path, bpe_path=bpe_path)
            if keep_text_encoder_loaded:
                _TEXT_ENCODER_CACHE, _TEXT_ENCODER_CACHE_KEY = encoder, requested_key
        for keyword in keywords:
            results[keyword] = _encode_text_outputs(encoder, [keyword, "visual", "geometric"], target_device)
    finally:
        # Runs on failure too, so a half-used encoder never stays on the accelerator.
        if encoder is not None:
            if keep_text_encoder_loaded:
                encoder.to("cpu")
            else:
                del encoder
        _cleanup()
    return results
def encode_sam3_keyword_prompts(keywords: Iterable[str], keep_text_encoder_loaded: bool = False):
    """Pre-encode text prompts for the given keywords.

    Keywords are converted with ``str()`` and stripped; blank ones are
    dropped. An empty dict is returned, without loading anything, when no
    keyword remains. Otherwise the model code, checkpoint and BPE vocabulary
    are located and the keywords are encoded.

    Returns:
        A dict mapping each cleaned keyword to its encoded text outputs.
    """
    cleaned = []
    for raw in keywords:
        text = str(raw).strip()
        if text:
            cleaned.append(text)
    if not cleaned:
        return {}

    builder = _load_model_builder()
    checkpoint_path, _version = _checkpoint_path()
    vocab_path = _bpe_path()
    return _encode_keyword_prompts(builder, checkpoint_path, vocab_path, cleaned, keep_text_encoder_loaded=keep_text_encoder_loaded)
def clear_sam3_text_encoder_cache():
    """Forget the cached text encoder and its cache key, then run ``_cleanup()``."""
    global _TEXT_ENCODER_CACHE, _TEXT_ENCODER_CACHE_KEY
    # Rebinding to None releases this module's reference to the encoder.
    _TEXT_ENCODER_CACHE = None
    _TEXT_ENCODER_CACHE_KEY = None
    _cleanup()
def fill_sam3_binary_mask_holes(mask: np.ndarray, fill_hole_area: int):
    """Fill holes in a binary mask via ``fill_holes_in_mask_scores``.

    *fill_hole_area* is converted with ``int()`` and clamped to be
    non-negative. When it is 0, or the mask has no true pixel, the mask is
    returned as a bool array and the tracker utility is never imported.
    Otherwise the mask is mapped to scores of -1/+1, passed to
    ``fill_holes_in_mask_scores`` with ``max_area=fill_hole_area``, and the
    result is thresholded at ``> 0``.
    """
    max_area = int(fill_hole_area)
    if max_area < 0:
        max_area = 0
    nothing_to_fill = max_area == 0 or not np.any(mask)
    if nothing_to_fill:
        return mask.astype(np.bool_, copy=False)

    from .model.sam3_tracker_utils import fill_holes_in_mask_scores

    as_float = mask.astype(np.float32, copy=False)
    # Add two leading singleton dims and map {0, 1} to {-1, +1}.
    scores = torch.from_numpy(as_float)[None, None] * 2 - 1
    filled = fill_holes_in_mask_scores(scores, max_area=max_area, fill_holes=True, remove_sprinkles=False)
    return filled[0, 0].numpy() > 0
def _load_predictor(
    model_builder=None,
    checkpoint_path=None,
    bpe_path=None,
    version=None,
    include_text_encoder=True,
    batched_grounding_batch_size=None,
    postprocess_batch_size=1,
    use_batched_grounding=True,
    trim_past_non_cond_mem_for_eval=True,
    fill_hole_area: int = 0,
    manual_model_loading: bool = False,
):
    """Build a SAM3 predictor via ``model_builder.build_sam3_predictor``.

    Missing inputs are resolved here: the builder module is imported, the
    checkpoint and BPE vocabulary are located, and the grounding batch size
    goes through ``resolve_sam3_grounding_batch_size``. The remaining
    options are forwarded unchanged; ``use_fa3``, ``compile`` and ``warm_up``
    are fixed to ``False`` and ``use_rope_real`` to ``True``.

    NOTE(review): *checkpoint_path* and *version* are only honoured when
    both are given; if either is ``None`` both are replaced by the result of
    ``_checkpoint_path()``.
    """
    builder = model_builder or _load_model_builder()
    if checkpoint_path is None or version is None:
        checkpoint_path, version = _checkpoint_path()
    vocab_path = bpe_path or _bpe_path()
    grounding_batch_size = resolve_sam3_grounding_batch_size(batched_grounding_batch_size)
    return builder.build_sam3_predictor(
        checkpoint_path=checkpoint_path,
        bpe_path=vocab_path,
        version=version,
        use_fa3=False,
        use_rope_real=True,
        compile=False,
        warm_up=False,
        include_text_encoder=include_text_encoder,
        postprocess_batch_size=postprocess_batch_size,
        use_batched_grounding=use_batched_grounding,
        batched_grounding_batch_size=grounding_batch_size,
        trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
        fill_hole_area=fill_hole_area,
        manual_model_loading=manual_model_loading,
    )
def load_sam3_mask_predictor(
    *,
    include_text_encoder: bool = True,
    postprocess_batch_size: int = 1,
    use_batched_grounding: bool = True,
    batched_grounding_batch_size=None,
    trim_past_non_cond_mem_for_eval: bool = True,
    fill_hole_area: int = 0,
    manual_model_loading: bool = False,
):
    """Locate the SAM3.1 code and assets, then build and return a predictor.

    All keyword options are forwarded unchanged to ``_load_predictor``.

    Raises:
        FileNotFoundError: if the model code, checkpoint or BPE vocabulary
            cannot be found.
    """
    builder = _load_model_builder()
    checkpoint, version = _checkpoint_path()
    vocab = _bpe_path()
    return _load_predictor(
        model_builder=builder,
        checkpoint_path=checkpoint,
        bpe_path=vocab,
        version=version,
        include_text_encoder=include_text_encoder,
        batched_grounding_batch_size=batched_grounding_batch_size,
        postprocess_batch_size=postprocess_batch_size,
        use_batched_grounding=use_batched_grounding,
        trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
        fill_hole_area=fill_hole_area,
        manual_model_loading=manual_model_loading,
    )
def run_sam3_video(
    video: np.ndarray,
    keywords: Iterable[str],
    *,
    include_text_encoder: bool = False,
    preencode_text: bool = True,
    batched_grounding_batch_size=None,
    postprocess_batch_size: int = 1,
    use_batched_grounding: bool = True,
    trim_past_non_cond_mem_for_eval: bool = True,
    keep_video_frames_on_cuda: bool = KEEP_VIDEO_FRAMES_ON_CUDA,
    cache_frame_outputs: bool = False,
    fill_hole_area: int = 0,
    progress_callback=None,
):
    """Segment every keyword in a video and return the union mask per frame.

    For each keyword a text prompt is added on frame 0 and propagated forward
    through the video; the binary masks of all keywords are OR-ed together.

    Args:
        video: Array unpacked as ``(num_frames, height, width, _)``; each
            ``video[i]`` is converted with ``PIL.Image.fromarray``.
        keywords: Prompts; converted with ``str()``, stripped, blanks dropped.
        include_text_encoder: Forwarded to the predictor builder; forced on
            when no pre-encoded prompts are available.
        preencode_text: When true (and the version is ``"sam3.1"``), keywords
            are encoded before the predictor is loaded and passed along as
            ``preencoded_text_outputs``.
        batched_grounding_batch_size, postprocess_batch_size,
        use_batched_grounding, trim_past_non_cond_mem_for_eval: Forwarded to
            ``_load_predictor``.
        keep_video_frames_on_cuda: The session is started with
            ``offload_video_to_cpu`` set to the negation of this flag.
        cache_frame_outputs: Forwarded in the ``start_session`` request.
        fill_hole_area: When positive, holes are filled in each frame of the
            final union mask (the predictor itself is built with 0).
        progress_callback: Optional ``callback(done, total)`` where *total*
            is ``len(keywords) * num_frames``.

    Returns:
        A bool array of shape ``(num_frames, height, width)``. With no usable
        keyword it is all-false with shape ``video.shape[:3]`` and nothing is
        loaded.
    """
    keywords = [str(keyword).strip() for keyword in keywords if str(keyword).strip()]
    if not keywords:
        return np.zeros(video.shape[:3], dtype=np.bool_)

    model_builder = _load_model_builder()
    checkpoint_path, version = _checkpoint_path()
    bpe_path = _bpe_path()
    # Validate the layout before any model is loaded, so a malformed video
    # fails cheaply instead of after the predictor has been built.
    num_frames, height, width, _ = video.shape
    _cleanup()

    preencoded_prompts = None
    if version == "sam3.1" and preencode_text:
        logger.info("SAM3 encoding keywords before propagation: %s", _format_keywords_for_log(keywords))
        preencoded_prompts = _encode_keyword_prompts(model_builder, checkpoint_path, bpe_path, keywords)

    video_predictor = _load_predictor(
        model_builder,
        checkpoint_path,
        bpe_path,
        version,
        include_text_encoder=include_text_encoder or preencoded_prompts is None,
        batched_grounding_batch_size=batched_grounding_batch_size,
        postprocess_batch_size=postprocess_batch_size,
        use_batched_grounding=use_batched_grounding,
        trim_past_non_cond_mem_for_eval=trim_past_non_cond_mem_for_eval,
        fill_hole_area=0,
    )
    dynamic_mask = np.zeros((num_frames, height, width), dtype=np.bool_)
    session_id = None
    try:
        # Frame conversion and session start are inside the try so that a
        # failure here still shuts the already-loaded predictor down.
        video_pil = [Image.fromarray(video[i]) for i in range(num_frames)]
        response = video_predictor.handle_request({"type": "start_session", "resource_path": video_pil, "offload_video_to_cpu": not keep_video_frames_on_cuda, "cache_frame_outputs": cache_frame_outputs})
        session_id = response["session_id"]

        total_progress_steps = len(keywords) * num_frames
        for keyword_index, keyword in enumerate(keywords):
            # Each keyword owns a contiguous range of num_frames progress steps.
            progress_base = keyword_index * num_frames
            logger.info("SAM3 keyword currently being processed: '%s'", keyword)
            request = {"type": "add_prompt", "session_id": session_id, "frame_index": 0, "text": keyword}
            if preencoded_prompts is not None:
                request["preencoded_text_outputs"] = _bf16_prompt_payload(preencoded_prompts[keyword])

            # NOTE(review): autocast is assumed to cover both the prompt
            # request and the propagation stream below — confirm the scope.
            with _autocast_context():
                result = video_predictor.handle_request(request)
                dynamic_mask[0] |= _sam3_outputs_to_binary_mask(result.get("outputs") if isinstance(result, dict) else None, height, width)
                if progress_callback is not None:
                    progress_callback(progress_base, total_progress_steps)

                # Set once the model reports progress itself; the per-frame
                # fallback reporting below is then skipped.
                internal_progress_seen = False

                def model_progress_callback(done, total):
                    nonlocal internal_progress_seen
                    internal_progress_seen = True
                    progress_callback(min(progress_base + int(done), total_progress_steps), total_progress_steps)

                stream_request = {
                    "type": "propagate_in_video",
                    "session_id": session_id,
                    "propagation_direction": "forward",
                    "start_frame_index": 0,
                    "max_frame_num_to_track": num_frames,
                }
                if progress_callback is not None:
                    stream_request["progress_callback"] = model_progress_callback

                propagated_frames = 0
                for result in video_predictor.handle_stream_request(stream_request):
                    propagated_frames += 1
                    if progress_callback is not None and not internal_progress_seen:
                        progress_callback(min(progress_base + propagated_frames, total_progress_steps), total_progress_steps)
                    outputs = result["outputs"]
                    dynamic_mask[result["frame_index"]] |= _sam3_outputs_to_binary_mask(outputs, height, width)
    finally:
        try:
            if session_id is not None:
                video_predictor.handle_request({"type": "close_session", "session_id": session_id})
        finally:
            # Release the predictor even when closing the session fails.
            video_predictor.shutdown()
            del video_predictor
            _cleanup()

    if fill_hole_area > 0:
        dynamic_mask = np.stack([fill_sam3_binary_mask_holes(mask, fill_hole_area) for mask in dynamic_mask], axis=0)
    return dynamic_mask
Xet Storage Details
- Size:
- 14.4 kB
- Xet hash:
- e36905fb1003a6ce262228c84c2b328d5e4a1c2a2dd2a9c53044a626123a4512
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.