| import os | |
| import re | |
| import time | |
| from pathlib import Path | |
| from typing import Iterable | |
| import numpy as np | |
| import torch | |
| import imageio.v2 as imageio | |
| from PIL import Image, ImageOps | |
| from shared.utils.audio_video import _get_codec_params | |
| from shared.utils.utils import get_resampled_video_transparent, get_video_info_details, has_image_file_extension, rgb_bw_to_rgba_mask, sanitize_file_name | |
| from shared.utils.virtual_media import get_virtual_image, strip_virtual_media_suffix | |
# Identifier and display name for this process (presumably used by the host
# application to register it — confirm against the caller).
PROCESS_ID = "magic_mask"
PROCESS_NAME = "Magic Mask"

# Model assets reported by query_download_def().
DOWNLOAD_REPO_ID = "DeepBeepMeep/Wan2.1"
DOWNLOAD_FOLDER = "sam3"
DOWNLOAD_FILES = [
    "sam3.1_multiplex_bf16.safetensors",
    "bpe_simple_vocab_16e6.txt.gz",
]

# fill_hole_area forwarded to SAM3 when hole filling is requested (0 otherwise).
DEFAULT_FILL_HOLE_AREA = 2
# postprocess_batch_size forwarded to SAM3.
DEFAULT_POSTPROCESS_BATCH_SIZE = 1
# Default directory in which mask videos are written.
OUTPUT_DIR = "mask_outputs"
| def parse_keywords(keyword_text: str | Iterable[str]) -> list[str]: | |
| if isinstance(keyword_text, str): | |
| candidates = re.split(r"[\n,;]+", keyword_text) | |
| else: | |
| candidates = keyword_text | |
| return [str(keyword).strip() for keyword in candidates if str(keyword).strip()] | |
def query_download_def():
    """Describe the model files Magic Mask needs.

    Returns:
        A dict with the repository id under ``"repoId"``, a one-element list
        holding the source folder under ``"sourceFolderList"``, and a
        one-element list holding a fresh copy of the file names under
        ``"fileList"``.
    """
    download_def = {}
    download_def["repoId"] = DOWNLOAD_REPO_ID
    download_def["sourceFolderList"] = [DOWNLOAD_FOLDER]
    # Copy the file list so callers cannot mutate the module-level constant.
    download_def["fileList"] = [[*DOWNLOAD_FILES]]
    return download_def
def _fill_hole_area(no_hole):
    """Return the hole-fill area for the given flag.

    A truthy ``no_hole`` selects ``DEFAULT_FILL_HOLE_AREA``; anything falsy
    yields ``0``.
    """
    if no_hole:
        return DEFAULT_FILL_HOLE_AREA
    return 0
def _open_image(image):
    """Turn a supported image input into an RGB PIL image.

    Args:
        image: A dict (its ``"path"``, ``"name"`` or ``"orig_name"`` entry is
            used, first truthy one wins), a string path, a NumPy array, or a
            PIL image.

    Returns:
        The image with its EXIF orientation applied, converted to RGB.

    Raises:
        ValueError: If the input cannot be resolved to a PIL image.
    """
    if isinstance(image, dict):
        image = image.get("path") or image.get("name") or image.get("orig_name")

    if isinstance(image, str):
        # Prefer a registered virtual image; otherwise open the file itself.
        resolved = get_virtual_image(image)
        if resolved is None:
            resolved = Image.open(strip_virtual_media_suffix(image))
        image = resolved

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    if isinstance(image, Image.Image):
        return ImageOps.exif_transpose(image).convert("RGB")
    raise ValueError("Magic Mask needs a control image.")
| def _media_path(path): | |
| if isinstance(path, dict): | |
| path = path.get("path") or path.get("name") or path.get("orig_name") | |
| return path | |
def _video_to_numpy(video_path):
    """Load a control video (or still image) as an array of RGB frames.

    Args:
        video_path: A path, or a dict payload accepted by ``_media_path``.

    Returns:
        A ``(frames, fps, width, height)`` tuple. ``frames`` is a 4-D uint8
        array whose last axis holds the colour channels. A still image yields
        a single frame with ``fps`` equal to ``1``.

    Raises:
        ValueError: If no path is given or no frames could be read.
    """
    source = _media_path(video_path)
    if not source:
        raise ValueError("Magic Mask needs a control video.")

    # A still image is handled as a one-frame clip.
    if isinstance(source, str) and has_image_file_extension(source):
        still = _open_image(source)
        width, height = still.size
        single_frame = np.asarray(still, dtype=np.uint8)[None]
        return single_frame, 1, width, height

    info = get_video_info_details(source)
    fps = info.get("fps_float") or info.get("fps") or 1
    width = int(info.get("display_width") or info.get("width") or 0)
    height = int(info.get("display_height") or info.get("height") or 0)
    frame_count = int(info.get("frame_count") or 1)

    frames = get_resampled_video_transparent(source, 0, frame_count, fps, bridge="torch")
    # The reader may hand back a torch tensor, an object exposing asnumpy(),
    # or something array-like; normalise all of them to a NumPy array.
    if torch.is_tensor(frames):
        frames = frames.detach().cpu().numpy()
    elif hasattr(frames, "asnumpy"):
        frames = frames.asnumpy()
    else:
        frames = np.asarray(frames)

    if frames.ndim != 4 or frames.shape[0] == 0:
        raise ValueError("Magic Mask could not read any control video frames.")

    # Force exactly three channels: drop extras, replicate a single one.
    channels = frames.shape[-1]
    if channels > 3:
        frames = frames[..., :3]
    elif channels == 1:
        frames = np.repeat(frames, 3, axis=-1)

    # Resize to the reported display size when the decoded frames differ.
    needs_resize = width > 0 and height > 0 and frames.shape[1:3] != (height, width)
    if needs_resize:
        resized = []
        for frame in frames:
            pil_frame = Image.fromarray(frame).resize((width, height), resample=Image.Resampling.LANCZOS)
            resized.append(np.asarray(pil_frame))
        frames = np.stack(resized, axis=0)

    return frames.astype(np.uint8, copy=False), fps, width, height
def _run_sam3(video: np.ndarray, keywords: list[str], batch_size, no_hole, progress_callback=None) -> np.ndarray:
    """Run the SAM3 video preprocessor on ``video`` for the given keywords.

    The preprocessor is imported lazily inside the function and executed
    under ``torch.inference_mode()``.

    Args:
        video: Frames to segment.
        keywords: Text prompts forwarded to SAM3.
        batch_size: Forwarded as ``batched_grounding_batch_size``.
        no_hole: Truthy to request hole filling (see ``_fill_hole_area``).
        progress_callback: Optional callback forwarded unchanged.

    Returns:
        Whatever ``run_sam3_video`` returns (annotated as a NumPy array).
    """
    from preprocessing.sam3.preprocessor import run_sam3_video

    with torch.inference_mode():
        sam3_options = {
            "batched_grounding_batch_size": batch_size,
            "postprocess_batch_size": DEFAULT_POSTPROCESS_BATCH_SIZE,
            "use_batched_grounding": True,
            "fill_hole_area": _fill_hole_area(no_hole),
            "progress_callback": progress_callback,
        }
        return run_sam3_video(video, keywords, **sam3_options)
def prepare_image_mask_input(image) -> tuple[Image.Image, np.ndarray]:
    """Load an image and wrap it as a one-frame uint8 batch.

    Returns:
        The RGB PIL image produced by ``_open_image`` and the same pixels as
        a uint8 array with an extra leading axis of length one.
    """
    pil_image = _open_image(image)
    frame_batch = np.asarray(pil_image, dtype=np.uint8)[None]
    return pil_image, frame_batch
def prepare_video_mask_input(video_path) -> tuple[str, np.ndarray, int]:
    """Resolve a control video input and load its frames.

    Args:
        video_path: A path, or a dict payload accepted by ``_media_path``.

    Returns:
        The resolved path, the frames from ``_video_to_numpy`` and the fps it
        reported.
        NOTE(review): the fps is read from video metadata and may be a float
        despite the ``int`` annotation — confirm.

    Raises:
        ValueError: If no path could be resolved.
    """
    resolved_path = _media_path(video_path)
    if not resolved_path:
        raise ValueError("Magic Mask needs a control video.")
    frames, fps, _width, _height = _video_to_numpy(resolved_path)
    return resolved_path, frames, fps
def generate_keyword_masks(video: np.ndarray, keyword_text: str | Iterable[str], *, batch_size=None, no_hole=True, progress_callback=None) -> np.ndarray:
    """Compute masks for ``video`` from free-form keyword input.

    Args:
        video: Frames to segment; its first three dimensions define the shape
            of the empty mask returned when no keywords are given.
        keyword_text: Keywords as accepted by ``parse_keywords``.
        batch_size: Forwarded to ``_run_sam3``.
        no_hole: Forwarded to ``_run_sam3``.
        progress_callback: Forwarded to ``_run_sam3``.

    Returns:
        The SAM3 result, or an all-``False`` boolean array when the keyword
        input contains no usable keyword.
    """
    keywords = parse_keywords(keyword_text)
    if keywords:
        return _run_sam3(video, keywords, batch_size, no_hole, progress_callback=progress_callback)
    # Nothing to look for: every pixel of every frame is unmasked.
    return np.zeros(video.shape[:3], dtype=np.bool_)
| def merge_keyword_masks(current_mask: np.ndarray | None, keyword_mask: np.ndarray) -> np.ndarray: | |
| keyword_mask = keyword_mask.astype(bool, copy=False) | |
| return keyword_mask.copy() if current_mask is None else (current_mask | keyword_mask) | |
def finalize_masks(mask: np.ndarray, *, negative_mask=False) -> np.ndarray:
    """Optionally invert a mask.

    Args:
        mask: The mask to finalize.
        negative_mask: When truthy, the ``~`` operator is applied to ``mask``.

    Returns:
        ``~mask`` if ``negative_mask`` is truthy, otherwise ``mask`` itself.
    """
    return ~mask if negative_mask else mask
def mask_to_image(mask: np.ndarray) -> Image.Image:
    """Render a mask as an 8-bit grayscale (mode ``"L"``) PIL image.

    The mask is cast to uint8 and multiplied by 255 before conversion.
    """
    pixels = mask.astype(np.uint8) * 255
    return Image.fromarray(pixels, mode="L")
def _magic_mask_video_codec_params():
    """Build the writer keyword arguments used for mask videos.

    Starts from a copy of the ``"libx264_10"`` / ``"mp4"`` codec parameters,
    sets ``macro_block_size`` to 1 and replaces a ``"yuv420p"`` pixel format
    with ``"yuv444p"``.
    """
    codec_params = dict(_get_codec_params("libx264_10", "mp4"))
    pixel_format = codec_params.get("pixelformat")
    # presumably keeps the writer from resizing frames to a block multiple — TODO confirm
    codec_params.update(macro_block_size=1)
    if pixel_format == "yuv420p":
        codec_params.update(pixelformat="yuv444p")
    return codec_params
def save_mask_video(video_path: str, masks: np.ndarray, fps: float, keywords: list[str], *, codec_type=None, output_dir=OUTPUT_DIR, abort_callback=None) -> str:
    """Write masks to an MP4 file as white-on-black RGB frames.

    Args:
        video_path: Path of the source video; only its stem is used for the
            output file name.
        masks: Masks to write, iterated along the first axis (one per frame).
        fps: Frame rate passed to the writer.
        keywords: Keywords used to build the file-name suffix.
        codec_type: Unused; kept for compatibility.
        output_dir: Directory for the output file; created if missing.
        abort_callback: Optional callable invoked before each frame is
            written. Any exception it raises stops the write (the writer is
            still closed) and propagates to the caller.

    Returns:
        The path of the written video as a string.
    """
    # codec_type is kept for compatibility; Magic Mask outputs are always MP4 libx264_10.
    output_folder = Path(output_dir)
    output_folder.mkdir(parents=True, exist_ok=True)
    stem = Path(strip_virtual_media_suffix(video_path)).stem
    keywords_suffix = truncate_keywords_for_path(keywords)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_path = os.fspath(output_folder / f"{sanitize_file_name(stem)}_magic_mask_{keywords_suffix}_{timestamp}.mp4")
    with imageio.get_writer(output_path, fps=fps, ffmpeg_log_level="error", **_magic_mask_video_codec_params()) as writer:
        for mask in masks:
            if abort_callback is not None:
                abort_callback()
            # Expand one frame at a time so the whole clip is never held as
            # an extra uint8 copy plus a 3-channel copy in memory.
            gray = mask.astype(np.uint8) * 255
            writer.append_data(np.repeat(gray[..., None], 3, axis=-1))
    return output_path
def generate_image_mask(image, keyword_text, *, batch_size=None, no_hole=True, negative_mask=False) -> tuple[Image.Image, Image.Image, list[str]]:
    """Generate a keyword mask for a single image.

    Args:
        image: Any input accepted by ``prepare_image_mask_input``.
        keyword_text: Keywords as accepted by ``parse_keywords``.
        batch_size: Forwarded to ``_run_sam3``.
        no_hole: Forwarded to ``_run_sam3``.
        negative_mask: Invert the mask when truthy.

    Returns:
        The loaded RGB image, the mask as a grayscale image, and the parsed
        keyword list.

    Raises:
        ValueError: If no usable keyword was supplied.
    """
    keywords = parse_keywords(keyword_text)
    if not keywords:
        raise ValueError("Enter at least one keyword.")
    background, frames = prepare_image_mask_input(image)
    raw_masks = _run_sam3(frames, keywords, batch_size, no_hole)
    # Only one frame was submitted, so only the first mask is relevant.
    first_mask = finalize_masks(raw_masks[0], negative_mask=negative_mask)
    return background, mask_to_image(first_mask), keywords
def generate_video_mask(video_path, keyword_text, *, batch_size=None, no_hole=True, negative_mask=False, codec_type=None, output_dir=OUTPUT_DIR) -> tuple[str, list[str]]:
    """Generate a keyword mask video for a control video.

    Args:
        video_path: Any input accepted by ``prepare_video_mask_input``.
        keyword_text: Keywords as accepted by ``parse_keywords``.
        batch_size: Forwarded to ``_run_sam3``.
        no_hole: Forwarded to ``_run_sam3``.
        negative_mask: Invert the masks when truthy.
        codec_type: Accepted but not used by this function.
        output_dir: Directory in which the mask video is saved.

    Returns:
        The path of the saved mask video and the parsed keyword list.

    Raises:
        ValueError: If no usable keyword was supplied.
    """
    keywords = parse_keywords(keyword_text)
    if not keywords:
        raise ValueError("Enter at least one keyword.")
    resolved_path, frames, fps = prepare_video_mask_input(video_path)
    raw_masks = _run_sam3(frames, keywords, batch_size, no_hole)
    final_masks = finalize_masks(raw_masks, negative_mask=negative_mask)
    saved_path = save_mask_video(resolved_path, final_masks, fps, keywords, output_dir=output_dir)
    return saved_path, keywords
def truncate_keywords_for_path(keywords: list[str]) -> str:
    """Build a short, file-name-safe suffix from the keywords.

    The keywords are joined with underscores, passed through
    ``sanitize_file_name``, stripped of leading/trailing underscores and cut
    to at most 40 characters. ``"mask"`` is returned when nothing remains.
    """
    joined = "_".join(keywords)
    cleaned = sanitize_file_name(joined, "_").strip("_")
    shortened = cleaned[:40]
    return shortened if shortened else "mask"
def build_image_editor_value(background: Image.Image, mask_image: Image.Image):
    """Assemble an image-editor value dict from a background and a mask.

    Returns:
        A dict with the background under ``"background"``, ``None`` under
        ``"composite"``, and a single layer under ``"layers"`` produced by
        ``rgb_bw_to_rgba_mask(mask_image)``.
    """
    mask_layer = rgb_bw_to_rgba_mask(mask_image)
    return {"background": background, "composite": None, "layers": [mask_layer]}
Xet Storage Details
- Size:
- 8.21 kB
- Xet hash:
- 66485fb940e99ae327e4d906dbffa01c60b398b54adc8ed17d33ce0c96bd197c
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.