""" SAM2 Video Segmentation Space - SIMPLIFIED VERSION Removes background from videos by tracking specified objects. """ import gradio as gr import torch import numpy as np import cv2 import tempfile import json import os from typing import List, Tuple, Optional, Dict, Any from transformers import Sam2VideoModel, Sam2VideoProcessor from PIL import Image import spaces # Global model variables MODEL_NAME = "facebook/sam2.1-hiera-tiny" device = None model = None processor = None def initialize_model(): """Initialize SAM2 model and processor.""" global device, model, processor if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.float32 # Use float32 for universal GPU compatibility print(f"CUDA available: {torch.cuda.is_available()}") print(f"CUDA device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}") elif torch.backends.mps.is_available(): device = torch.device("mps") dtype = torch.float32 else: device = torch.device("cpu") dtype = torch.float32 print(f"Loading SAM2 model on {device} with dtype {dtype}...") model = Sam2VideoModel.from_pretrained(MODEL_NAME).to(device, dtype=dtype) processor = Sam2VideoProcessor.from_pretrained(MODEL_NAME) print("Model loaded successfully!") return device, model, processor def load_video_cv2(video_path): """Load video using OpenCV to preserve orientation.""" cap = cv2.VideoCapture(video_path) frames = [] fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 while cap.isOpened(): ret, frame = cap.read() if not ret: break # Convert BGR to RGB frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frames.append(Image.fromarray(frame_rgb)) cap.release() return frames, {'fps': fps} @spaces.GPU def segment_video_advanced(video_file, points_json): """ Advanced video segmentation with multiple tracking points across frames. This allows re-initializing tracking at different frames when the object is lost. Critical for tracking small, fast-moving objects like soccer balls. Args: video_file: Video file path points_json: JSON string with format: { "object_id": "ball", "tracking_points": [ {"frame": 0, "x": 360, "y": 900}, {"frame": 50, "x": 380, "y": 850}, {"frame": 100, "x": 400, "y": 800} ], "remove_bg": true } Returns: (output_video_path, status_message_with_confidence) """ global device, model, processor if model is None: initialize_model() try: if video_file is None: return None, "❌ Error: No video file provided" video_path = str(video_file) if not os.path.exists(video_path): return None, f"❌ Error: Video file not found: {video_path}" # Parse points JSON try: config = json.loads(points_json) except json.JSONDecodeError as e: return None, f"❌ Error: Invalid JSON format: {str(e)}" object_id = config.get('object_id', 'object') tracking_points = config.get('tracking_points', []) remove_bg = config.get('remove_bg', True) if not tracking_points: return None, "❌ Error: No tracking points provided" print(f"Processing '{object_id}' with {len(tracking_points)} tracking points") # Load video video_frames, video_info = load_video_cv2(video_path) fps = video_info.get('fps', 30.0) first_frame = np.array(video_frames[0]) height, width = first_frame.shape[:2] # Initialize inference session dtype = torch.float32 inference_session = processor.init_video_session( video=video_frames, inference_device=device, dtype=dtype, ) # Add all tracking points for i, point in enumerate(tracking_points): frame_idx = int(point.get('frame', 0)) point_x = int(point.get('x', width // 2)) point_y = int(point.get('y', height // 2)) print(f" Point {i+1}/{len(tracking_points)}: frame {frame_idx}, ({point_x}, {point_y})") # Add annotation at this frame processor.add_inputs_to_inference_session( inference_session=inference_session, frame_idx=frame_idx, obj_ids=1, # Same object ID for all points input_points=[[[[point_x, point_y]]]], input_labels=[[[1]]], ) # Skip initial model inference - go straight to propagation # The propagation loop will handle this frame when it reaches it # Propagate through ALL frames explicitly (frame-by-frame) # This ensures proper bidirectional propagation from all tracking points video_segments = {} confidence_scores = [] print(f"Propagating through all {len(video_frames)} frames...") with torch.inference_mode(): for f_idx, frame_pil in enumerate(video_frames): pixel_values = None if inference_session.processed_frames is None or f_idx not in inference_session.processed_frames: # Let model handle dtype conversion automatically pixel_values = processor(images=frame_pil, device=device, return_tensors="pt").pixel_values[0] sam2_output = model( inference_session=inference_session, frame=pixel_values, frame_idx=f_idx ) H = inference_session.video_height W = inference_session.video_width pred_masks = sam2_output.pred_masks.detach().cpu() video_res_masks = processor.post_process_masks( [pred_masks], original_sizes=[[H, W]], binarize=False # Keep as float, not boolean )[0] video_segments[f_idx] = video_res_masks # Calculate confidence (ensure float conversion) mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks mask_confidence = float(mask_float.mean()) confidence_scores.append(mask_confidence) print(f"✅ Got masks for all {len(video_segments)} frames") # Calculate tracking quality mean_confidence = np.mean(confidence_scores) if confidence_scores else 0.0 min_confidence = np.min(confidence_scores) if confidence_scores else 0.0 # Create output video output_path = tempfile.mktemp(suffix=f"_{object_id}.mp4") fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) for frame_idx, frame_pil in enumerate(video_frames): frame = np.array(frame_pil) if frame_idx in video_segments: # We have a mask for this frame - apply it mask = video_segments[frame_idx].cpu().numpy() if mask.ndim == 4: mask = mask[0] if mask.ndim == 3: mask = mask.max(axis=0) if mask.shape != (height, width): mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST) mask_binary = (mask > 0.5).astype(np.uint8) if remove_bg: background = np.zeros_like(frame) mask_3d = np.repeat(mask_binary[:, :, np.newaxis], 3, axis=2) frame = frame * mask_3d + background * (1 - mask_3d) # Note: If frame_idx not in video_segments, we keep the original frame # This shouldn't happen as SAM2 propagates to all frames frame_bgr = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR) out.write(frame_bgr) out.release() # Build status message with confidence status = { "status": "success", "object_id": object_id, "frames_processed": len(video_segments), "total_frames": len(video_frames), "tracking_points_used": len(tracking_points), "confidence": { "mean": round(mean_confidence, 3), "min": round(min_confidence, 3) } } status_msg = f"✅ Success! Processed {len(video_segments)}/{len(video_frames)} frames\n" status_msg += f" Tracking points: {len(tracking_points)}\n" status_msg += f" Confidence: mean={mean_confidence:.3f}, min={min_confidence:.3f}" if os.path.exists(output_path): return output_path, status_msg else: return None, "❌ Error: Output file was not created" except Exception as e: import traceback error_details = traceback.format_exc() print(f"Error in segment_video_advanced: {error_details}") return None, f"❌ Error: {str(e)}" @spaces.GPU def segment_video_simple(video_file, point_x, point_y, frame_idx, remove_bg): """Simple video segmentation with a single point.""" global device, model, processor if model is None: initialize_model() try: # Handle video_file - gr.File passes it as a string path directly if video_file is None: return None, "❌ Error: No video file provided" # gr.File returns the file path as a string video_path = str(video_file) if not os.path.exists(video_path): return None, f"❌ Error: Video file not found: {video_path}" print(f"Processing video from: {video_path}") # Convert inputs point_x = int(float(point_x)) point_y = int(float(point_y)) frame_idx = int(float(frame_idx)) # Load video using OpenCV to preserve orientation video_frames, video_info = load_video_cv2(video_path) fps = video_info.get('fps', 30.0) # Initialize inference session dtype = torch.float32 # Use float32 for universal compatibility inference_session = processor.init_video_session( video=video_frames, inference_device=device, dtype=dtype, ) # Add annotation processor.add_inputs_to_inference_session( inference_session=inference_session, frame_idx=frame_idx, obj_ids=1, input_points=[[[[point_x, point_y]]]], input_labels=[[[1]]], ) # Skip initial model inference - go straight to propagation # The propagation loop will handle this frame when it reaches it # Propagate through ALL frames explicitly (frame-by-frame) # This ensures proper bidirectional propagation from init_frame video_segments = {} print(f"Propagating through {len(video_frames)} frames from frame {frame_idx}...") with torch.inference_mode(): for f_idx, frame_pil in enumerate(video_frames): pixel_values = None if inference_session.processed_frames is None or f_idx not in inference_session.processed_frames: # Let model handle dtype conversion automatically pixel_values = processor(images=frame_pil, device=device, return_tensors="pt").pixel_values[0] sam2_output = model( inference_session=inference_session, frame=pixel_values, frame_idx=f_idx ) H = inference_session.video_height W = inference_session.video_width pred_masks = sam2_output.pred_masks.detach().cpu() video_res_masks = processor.post_process_masks( [pred_masks], original_sizes=[[H, W]], binarize=False # Keep as float, not boolean )[0] video_segments[f_idx] = video_res_masks print(f"✅ Got masks for {len(video_segments)} frames") # Create output video output_path = tempfile.mktemp(suffix=".mp4") first_frame = np.array(video_frames[0]) height, width = first_frame.shape[:2] fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) for frame_idx, frame_pil in enumerate(video_frames): frame = np.array(frame_pil) if frame_idx in video_segments: mask = video_segments[frame_idx].cpu().numpy() if mask.ndim == 4: mask = mask[0] if mask.ndim == 3: mask = mask.max(axis=0) if mask.shape != (height, width): mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST) mask_binary = (mask > 0.5).astype(np.uint8) if remove_bg: background = np.zeros_like(frame) mask_3d = np.repeat(mask_binary[:, :, np.newaxis], 3, axis=2) frame = frame * mask_3d + background * (1 - mask_3d) frame_bgr = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR) out.write(frame_bgr) out.release() # Return the video file path (Gradio will handle it) if os.path.exists(output_path): return output_path, f"✅ Success! Processed {len(video_segments)} frames" else: return None, f"❌ Error: Output file was not created" except Exception as e: import traceback error_details = traceback.format_exc() print(f"Error in segment_video_simple: {error_details}") return None, f"❌ Error: {str(e)}" @spaces.GPU def segment_video_multi(video_file, objects_json, remove_bg=True): """ Multi-object video segmentation. Args: video_file: Video file path objects_json: JSON string with format: [ {"id": "ball", "init_frame": 0, "point_x": 360, "point_y": 640}, {"id": "player", "init_frame": 0, "point_x": 320, "point_y": 500} ] remove_bg: If True, remove background. If False, overlay masks on original video. Returns: Tuple: (video1_path, video2_path, status_json) - video1_path: First object video - video2_path: Second object video (or None if only one object) - status_json: Status message with confidence scores in JSON format """ global device, model, processor if model is None: initialize_model() try: # Parse input if video_file is None: return None, None, json.dumps({"error": "No video file provided"}) video_path = str(video_file) if not os.path.exists(video_path): return None, None, json.dumps({"error": f"Video file not found: {video_path}"}) # Parse objects JSON try: objects = json.loads(objects_json) except json.JSONDecodeError as e: return None, None, json.dumps({"error": f"Invalid JSON format: {str(e)}"}) if not isinstance(objects, list) or len(objects) == 0: return None, None, json.dumps({"error": "objects must be a non-empty list"}) print(f"Processing {len(objects)} objects from video: {video_path}") # Load video once video_frames, video_info = load_video_cv2(video_path) fps = video_info.get('fps', 30.0) first_frame = np.array(video_frames[0]) height, width = first_frame.shape[:2] # Process each object separately results = [] output_files = [] for obj_idx, obj in enumerate(objects): obj_id = obj.get('id', f'object_{obj_idx}') init_frame = int(obj.get('init_frame', 0)) point_x = int(obj.get('point_x', width // 2)) point_y = int(obj.get('point_y', height // 2)) print(f"Processing object '{obj_id}' at frame {init_frame}, point ({point_x}, {point_y})") try: # CRITICAL: Reinitialize model for each object to avoid dtype contamination # This is necessary because ZeroGPU may change model dtype between objects if obj_idx > 0: print(f" Reinitializing model for object {obj_idx+1}...") torch.cuda.empty_cache() model = Sam2VideoModel.from_pretrained(MODEL_NAME).to(device, dtype=torch.float32) # Force ALL parameters to float32 for param in model.parameters(): param.data = param.data.to(torch.float32) # Ensure model is in correct dtype dtype = torch.float32 model.to(device, dtype=dtype) # Double-check all parameters are float32 for param in model.parameters(): if param.dtype != torch.float32: param.data = param.data.to(torch.float32) # Initialize inference session for this object inference_session = processor.init_video_session( video=video_frames, inference_device=device, dtype=dtype, ) # Add annotation for this object processor.add_inputs_to_inference_session( inference_session=inference_session, frame_idx=init_frame, obj_ids=1, input_points=[[[[point_x, point_y]]]], input_labels=[[[1]]], ) # Skip initial model inference - go straight to propagation # The propagation loop will handle init_frame when it reaches it # Use propagate_in_video_iterator for BIDIRECTIONAL propagation # According to SAM2 docs, this should propagate both forward AND backward # from the annotated frame (init_frame) video_segments = {} confidence_scores = [] print(f" Using propagate_in_video_iterator for bidirectional propagation from frame {init_frame}...") for sam2_output in model.propagate_in_video_iterator(inference_session): video_res_masks = processor.post_process_masks( [sam2_output.pred_masks], original_sizes=[[inference_session.video_height, inference_session.video_width]], binarize=False )[0] video_segments[sam2_output.frame_idx] = video_res_masks # Calculate confidence mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks confidence_scores.append(float(mask_float.mean())) print(f" ✅ Got masks for {len(video_segments)} frames (init_frame was {init_frame})") # Debug: Verify bidirectional propagation worked frames_before = sum(1 for f in video_segments.keys() if f < init_frame) frames_after = sum(1 for f in video_segments.keys() if f >= init_frame) print(f" 📊 Frames before init_frame: {frames_before}, after: {frames_after}") # Calculate tracking quality metrics mean_confidence = np.mean(confidence_scores) if confidence_scores else 0.0 min_confidence = np.min(confidence_scores) if confidence_scores else 0.0 # Create output video for this object output_path = tempfile.mktemp(suffix=f"_{obj_id}.mp4") fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) for frame_idx, frame_pil in enumerate(video_frames): frame = np.array(frame_pil) if frame_idx in video_segments: mask = video_segments[frame_idx].cpu().numpy() if mask.ndim == 4: mask = mask[0] if mask.ndim == 3: mask = mask.max(axis=0) if mask.shape != (height, width): mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST) mask_binary = (mask > 0.5).astype(np.uint8) mask_3d = np.repeat(mask_binary[:, :, np.newaxis], 3, axis=2) if remove_bg: # Remove background - show only tracked object background = np.zeros_like(frame) frame = frame * mask_3d + background * (1 - mask_3d) else: # Overlay colored mask on original video # Create a colored overlay (e.g., semi-transparent green) overlay_color = np.array([0, 255, 0], dtype=np.uint8) # Green overlay = np.ones_like(frame) * overlay_color alpha = 0.5 # Transparency frame = (frame * (1 - alpha * mask_3d) + overlay * alpha * mask_3d).astype(np.uint8) frame_bgr = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR) out.write(frame_bgr) out.release() if os.path.exists(output_path): output_files.append(output_path) results.append({ "object_id": obj_id, "status": "success", "frames_processed": len(video_segments), "tracking_confidence": { "mean": round(mean_confidence, 3), "min": round(min_confidence, 3) } }) else: results.append({ "object_id": obj_id, "status": "error", "error": "Output file not created" }) except Exception as e: results.append({ "object_id": obj_id, "status": "error", "error": str(e) }) print(f"Error processing object '{obj_id}': {e}") # Return results status_json = json.dumps({ "status": "completed", "objects_processed": len(results), "results": results }, indent=2) # Return up to 2 videos (for Gradio UI compatibility) video1 = output_files[0] if len(output_files) > 0 else None video2 = output_files[1] if len(output_files) > 1 else None return video1, video2, status_json except Exception as e: import traceback error_details = traceback.format_exc() print(f"Error in segment_video_multi: {error_details}") return None, None, json.dumps({"error": str(e), "details": error_details}) # Create Gradio interface def create_app(): initialize_model() with gr.Blocks(title="SAM2 Video Background Remover") as app: gr.Markdown(""" # 🎥 SAM2 Video Background Remover Remove backgrounds from videos by tracking objects with Meta's SAM2. """) with gr.Tabs(): # Tab 1: Simple single-object tracking with gr.Tab("Single Point"): gr.Markdown(""" **Track one object at a time.** 1. Upload a video 2. Enter X, Y coordinates of the object to track 3. Click "Process Video" """) with gr.Row(): with gr.Column(): video_input_simple = gr.File( label="Upload Video", file_types=["video"] ) with gr.Row(): point_x = gr.Textbox(label="Point X", value="320") point_y = gr.Textbox(label="Point Y", value="240") frame_idx = gr.Textbox(label="Frame Index", value="0") remove_bg = gr.Checkbox(label="Remove Background", value=True) process_btn_simple = gr.Button("🎬 Process Video", variant="primary") with gr.Column(): output_video_simple = gr.File(label="Output Video") status_text_simple = gr.Textbox(label="Status", lines=3) process_btn_simple.click( fn=segment_video_simple, inputs=[video_input_simple, point_x, point_y, frame_idx, remove_bg], outputs=[output_video_simple, status_text_simple] ) # Tab 2: Advanced multi-point tracking with gr.Tab("Multi-Point (Advanced)"): gr.Markdown(""" **Track one object with MULTIPLE points across frames.** Perfect for objects that move fast or get occluded (like soccer balls). Re-initialize tracking at different frames to maintain accuracy. 1. Upload a video 2. Provide JSON with multiple tracking points 3. Click "Process Advanced" """) with gr.Row(): with gr.Column(): video_input_advanced = gr.File( label="Upload Video", file_types=["video"] ) points_advanced_json = gr.Textbox( label="Tracking Configuration (JSON)", lines=15, value=json.dumps({ "object_id": "ball", "tracking_points": [ {"frame": 0, "x": 360, "y": 900}, {"frame": 50, "x": 380, "y": 850}, {"frame": 100, "x": 400, "y": 800} ], "remove_bg": True }, indent=2), placeholder='See example for format' ) process_btn_advanced = gr.Button("🎬 Process Advanced", variant="primary") with gr.Column(): output_video_advanced = gr.File(label="Output Video") status_text_advanced = gr.Textbox(label="Status", lines=8) process_btn_advanced.click( fn=segment_video_advanced, inputs=[video_input_advanced, points_advanced_json], outputs=[output_video_advanced, status_text_advanced] ) gr.Markdown(""" ### 📖 Multi-Point Tracking Explained **When to use this:** - Ball goes off-screen and comes back - Ball gets occluded by player - Fast motion causes tracking to drift - Small object is hard to track continuously **How it works:** 1. Mark the ball at frame 0 where it's stationary 2. Mark again at frame 50 where it's visible after kick 3. Mark again at frame 100 where it lands 4. SAM2 propagates between these anchor points **Result:** Much more robust tracking across entire video! """) # Tab 3: Multi-object tracking with gr.Tab("Multiple Objects"): gr.Markdown(""" **Track multiple objects in one video (e.g., ball AND player).** 1. Upload a video 2. Provide objects JSON (default values are for testing with `original.mp4`) 3. Click "Process Multi-Object" You'll receive separate video files for each tracked object! **⚠️ TESTING NOTE**: Order swapped - PLAYER FIRST (init_frame=165), BALL SECOND (init_frame=0) to test if propagation works regardless of init_frame order. **Note**: Default JSON values are optimized for the test video `original.mp4` (portrait soccer video). """) with gr.Row(): with gr.Column(): video_input_multi = gr.File( label="Upload Video", file_types=["video"] ) objects_json = gr.Textbox( label="Objects JSON (PLAYER FIRST for testing!)", lines=12, value=json.dumps([ {"id": "player", "init_frame": 165, "point_x": 180, "point_y": 320}, {"id": "ball", "init_frame": 0, "point_x": 360, "point_y": 640} ], indent=2), placeholder='[{"id": "player", "init_frame": 165, "point_x": 180, "point_y": 320}]' ) remove_bg_multi = gr.Checkbox( label="Remove Background", value=True, info="If checked, shows only tracked objects. If unchecked, overlays masks on original video." ) process_btn_multi = gr.Button("🎬 Process Multi-Object", variant="primary") with gr.Column(): output_video_multi_1 = gr.File(label="Output Video 1") output_video_multi_2 = gr.File(label="Output Video 2") status_text_multi = gr.Textbox(label="Status (JSON)", lines=15) process_btn_multi.click( fn=segment_video_multi, inputs=[video_input_multi, objects_json, remove_bg_multi], outputs=[output_video_multi_1, output_video_multi_2, status_text_multi] ) gr.Markdown(""" --- ### 📐 Coordinate System - **Coordinates (X, Y)**: Absolute pixel coordinates - **Origin**: Top-left corner (0, 0) - **Portrait video (720x1280)**: Center ≈ X=360, Y=640 - **Landscape video (1920x1080)**: Center ≈ X=960, Y=540 ### 🎯 Tips for Best Results - **Frame Index**: Choose a frame where the object is clearly visible and unoccluded - **Point Selection**: Click on a distinctive part of the object - **Multiple Points**: For complex shapes (like players), use multiple objects with different points - **Tracking Quality**: The API returns confidence scores for each object ### 📹 Output Format - **Format**: MP4 (H.264) - **Background**: Black (RGB: 0, 0, 0) - can be made transparent in post-processing - **FPS**: Same as input video - **Resolution**: Same as input video ### ⚡ Performance - Processing time: ~30-60 seconds for typical videos on GPU - Videos are processed sequentially (one at a time in multi-object mode) - GPU acceleration via ZeroGPU """) return app if __name__ == "__main__": app = create_app() app.launch(share=True, show_error=True)