"""
SAM2 Video Segmentation Space - SIMPLIFIED VERSION
Removes background from videos by tracking specified objects.
"""
import gradio as gr
import torch
import numpy as np
import cv2
import tempfile
import json
import os
from typing import List, Tuple, Optional, Dict, Any
from transformers import Sam2VideoModel, Sam2VideoProcessor
from PIL import Image
import spaces
# Global model variables
MODEL_NAME = "facebook/sam2.1-hiera-tiny"
device = None
model = None
processor = None
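# Populated by initialize_model(); the @spaces.GPU handlers also re-check and
# load the model if the worker starts with model still set to None.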
def initialize_model():
"""Initialize SAM2 model and processor."""
global device, model, processor
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float32 # Use float32 for universal GPU compatibility
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")
elif torch.backends.mps.is_available():
device = torch.device("mps")
dtype = torch.float32
else:
device = torch.device("cpu")
dtype = torch.float32
print(f"Loading SAM2 model on {device} with dtype {dtype}...")
model = Sam2VideoModel.from_pretrained(MODEL_NAME).to(device, dtype=dtype)
processor = Sam2VideoProcessor.from_pretrained(MODEL_NAME)
print("Model loaded successfully!")
return device, model, processor
def load_video_cv2(video_path):
"""Load video using OpenCV to preserve orientation."""
cap = cv2.VideoCapture(video_path)
frames = []
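# cap.get() returns 0.0 when FPS metadata is missing; fall back to 30 fps.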
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
# Convert BGR to RGB
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(Image.fromarray(frame_rgb))
cap.release()
return frames, {'fps': fps}
@spaces.GPU
def segment_video_advanced(video_file, points_json):
"""
Advanced video segmentation with multiple tracking points across frames.
This allows re-initializing tracking at different frames when the object is lost.
Critical for tracking small, fast-moving objects like soccer balls.
Args:
video_file: Video file path
points_json: JSON string with format:
{
"object_id": "ball",
"tracking_points": [
{"frame": 0, "x": 360, "y": 900},
{"frame": 50, "x": 380, "y": 850},
{"frame": 100, "x": 400, "y": 800}
],
"remove_bg": true
}
Returns:
(output_video_path, status_message_with_confidence)
"""
global device, model, processor
if model is None:
initialize_model()
try:
if video_file is None:
return None, "❌ Error: No video file provided"
video_path = str(video_file)
if not os.path.exists(video_path):
return None, f"❌ Error: Video file not found: {video_path}"
# Parse points JSON
try:
config = json.loads(points_json)
except json.JSONDecodeError as e:
return None, f"❌ Error: Invalid JSON format: {str(e)}"
object_id = config.get('object_id', 'object')
tracking_points = config.get('tracking_points', [])
remove_bg = config.get('remove_bg', True)
if not tracking_points:
return None, "❌ Error: No tracking points provided"
print(f"Processing '{object_id}' with {len(tracking_points)} tracking points")
# Load video
video_frames, video_info = load_video_cv2(video_path)
fps = video_info.get('fps', 30.0)
first_frame = np.array(video_frames[0])
height, width = first_frame.shape[:2]
# Initialize inference session
dtype = torch.float32
inference_session = processor.init_video_session(
video=video_frames,
inference_device=device,
dtype=dtype,
)
# Add all tracking points
for i, point in enumerate(tracking_points):
frame_idx = int(point.get('frame', 0))
point_x = int(point.get('x', width // 2))
point_y = int(point.get('y', height // 2))
print(f" Point {i+1}/{len(tracking_points)}: frame {frame_idx}, ({point_x}, {point_y})")
# Add annotation at this frame
processor.add_inputs_to_inference_session(
inference_session=inference_session,
frame_idx=frame_idx,
obj_ids=1, # Same object ID for all points
input_points=[[[[point_x, point_y]]]],
input_labels=[[[1]]],
)
# Skip initial model inference - go straight to propagation
# The propagation loop will handle this frame when it reaches it
# Propagate through ALL frames explicitly (frame-by-frame)
# This ensures proper bidirectional propagation from all tracking points
video_segments = {}
confidence_scores = []
print(f"Propagating through all {len(video_frames)} frames...")
with torch.inference_mode():
for f_idx, frame_pil in enumerate(video_frames):
pixel_values = None
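# Only preprocess frames the session has not already ingested; otherwise
# pixel_values stays None and frame=None is passed to the model below.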
if inference_session.processed_frames is None or f_idx not in inference_session.processed_frames:
# Let model handle dtype conversion automatically
pixel_values = processor(images=frame_pil, device=device, return_tensors="pt").pixel_values[0]
sam2_output = model(
inference_session=inference_session,
frame=pixel_values,
frame_idx=f_idx
)
H = inference_session.video_height
W = inference_session.video_width
pred_masks = sam2_output.pred_masks.detach().cpu()
video_res_masks = processor.post_process_masks(
[pred_masks],
original_sizes=[[H, W]],
binarize=False # Keep as float, not boolean
)[0]
video_segments[f_idx] = video_res_masks
# Calculate confidence (ensure float conversion)
mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
mask_confidence = float(mask_float.mean())
confidence_scores.append(mask_confidence)
print(f"βœ… Got masks for all {len(video_segments)} frames")
# Calculate tracking quality
mean_confidence = np.mean(confidence_scores) if confidence_scores else 0.0
min_confidence = np.min(confidence_scores) if confidence_scores else 0.0
# Create output video
output_path = tempfile.mktemp(suffix=f"_{object_id}.mp4")
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
for frame_idx, frame_pil in enumerate(video_frames):
frame = np.array(frame_pil)
if frame_idx in video_segments:
# We have a mask for this frame - apply it
mask = video_segments[frame_idx].cpu().numpy()
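# Collapse any leading dimensions (object / mask channels) to a single HxW map.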
if mask.ndim == 4:
mask = mask[0]
if mask.ndim == 3:
mask = mask.max(axis=0)
if mask.shape != (height, width):
mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
mask_binary = (mask > 0.5).astype(np.uint8)
if remove_bg:
background = np.zeros_like(frame)
mask_3d = np.repeat(mask_binary[:, :, np.newaxis], 3, axis=2)
frame = frame * mask_3d + background * (1 - mask_3d)
# Note: If frame_idx not in video_segments, we keep the original frame
# This shouldn't happen as SAM2 propagates to all frames
frame_bgr = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR)
out.write(frame_bgr)
out.release()
# Build status message with confidence
status = {
"status": "success",
"object_id": object_id,
"frames_processed": len(video_segments),
"total_frames": len(video_frames),
"tracking_points_used": len(tracking_points),
"confidence": {
"mean": round(mean_confidence, 3),
"min": round(min_confidence, 3)
}
}
status_msg = f"βœ… Success! Processed {len(video_segments)}/{len(video_frames)} frames\n"
status_msg += f" Tracking points: {len(tracking_points)}\n"
status_msg += f" Confidence: mean={mean_confidence:.3f}, min={min_confidence:.3f}"
if os.path.exists(output_path):
return output_path, status_msg
else:
return None, "❌ Error: Output file was not created"
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"Error in segment_video_advanced: {error_details}")
return None, f"❌ Error: {str(e)}"
@spaces.GPU
def segment_video_simple(video_file, point_x, point_y, frame_idx, remove_bg):
"""Simple video segmentation with a single point."""
global device, model, processor
if model is None:
initialize_model()
try:
# Handle video_file - gr.File passes it as a string path directly
if video_file is None:
return None, "❌ Error: No video file provided"
# gr.File returns the file path as a string
video_path = str(video_file)
if not os.path.exists(video_path):
return None, f"❌ Error: Video file not found: {video_path}"
print(f"Processing video from: {video_path}")
# Convert inputs
point_x = int(float(point_x))
point_y = int(float(point_y))
frame_idx = int(float(frame_idx))
# Load video using OpenCV to preserve orientation
video_frames, video_info = load_video_cv2(video_path)
fps = video_info.get('fps', 30.0)
# Initialize inference session
dtype = torch.float32 # Use float32 for universal compatibility
inference_session = processor.init_video_session(
video=video_frames,
inference_device=device,
dtype=dtype,
)
# Add annotation
processor.add_inputs_to_inference_session(
inference_session=inference_session,
frame_idx=frame_idx,
obj_ids=1,
input_points=[[[[point_x, point_y]]]],
input_labels=[[[1]]],
)
# Skip initial model inference - go straight to propagation
# The propagation loop will handle this frame when it reaches it
# Propagate through ALL frames explicitly (frame-by-frame)
# This ensures proper bidirectional propagation from init_frame
video_segments = {}
print(f"Propagating through {len(video_frames)} frames from frame {frame_idx}...")
with torch.inference_mode():
for f_idx, frame_pil in enumerate(video_frames):
pixel_values = None
if inference_session.processed_frames is None or f_idx not in inference_session.processed_frames:
# Let model handle dtype conversion automatically
pixel_values = processor(images=frame_pil, device=device, return_tensors="pt").pixel_values[0]
sam2_output = model(
inference_session=inference_session,
frame=pixel_values,
frame_idx=f_idx
)
H = inference_session.video_height
W = inference_session.video_width
pred_masks = sam2_output.pred_masks.detach().cpu()
video_res_masks = processor.post_process_masks(
[pred_masks],
original_sizes=[[H, W]],
binarize=False # Keep as float, not boolean
)[0]
video_segments[f_idx] = video_res_masks
print(f"βœ… Got masks for {len(video_segments)} frames")
# Create output video
output_path = tempfile.mktemp(suffix=".mp4")
first_frame = np.array(video_frames[0])
height, width = first_frame.shape[:2]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
for frame_idx, frame_pil in enumerate(video_frames):
frame = np.array(frame_pil)
if frame_idx in video_segments:
mask = video_segments[frame_idx].cpu().numpy()
if mask.ndim == 4:
mask = mask[0]
if mask.ndim == 3:
mask = mask.max(axis=0)
if mask.shape != (height, width):
mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
mask_binary = (mask > 0.5).astype(np.uint8)
if remove_bg:
background = np.zeros_like(frame)
mask_3d = np.repeat(mask_binary[:, :, np.newaxis], 3, axis=2)
frame = frame * mask_3d + background * (1 - mask_3d)
frame_bgr = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR)
out.write(frame_bgr)
out.release()
# Return the video file path (Gradio will handle it)
if os.path.exists(output_path):
return output_path, f"βœ… Success! Processed {len(video_segments)} frames"
else:
return None, f"❌ Error: Output file was not created"
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"Error in segment_video_simple: {error_details}")
return None, f"❌ Error: {str(e)}"
@spaces.GPU
def segment_video_multi(video_file, objects_json, remove_bg=True):
"""
Multi-object video segmentation.
Args:
video_file: Video file path
objects_json: JSON string with format:
[
{"id": "ball", "init_frame": 0, "point_x": 360, "point_y": 640},
{"id": "player", "init_frame": 0, "point_x": 320, "point_y": 500}
]
remove_bg: If True, remove background. If False, overlay masks on original video.
Returns:
Tuple: (video1_path, video2_path, status_json)
- video1_path: First object video
- video2_path: Second object video (or None if only one object)
- status_json: Status message with confidence scores in JSON format
"""
global device, model, processor
if model is None:
initialize_model()
try:
# Parse input
if video_file is None:
return None, None, json.dumps({"error": "No video file provided"})
video_path = str(video_file)
if not os.path.exists(video_path):
return None, None, json.dumps({"error": f"Video file not found: {video_path}"})
# Parse objects JSON
try:
objects = json.loads(objects_json)
except json.JSONDecodeError as e:
return None, None, json.dumps({"error": f"Invalid JSON format: {str(e)}"})
if not isinstance(objects, list) or len(objects) == 0:
return None, None, json.dumps({"error": "objects must be a non-empty list"})
print(f"Processing {len(objects)} objects from video: {video_path}")
# Load video once
video_frames, video_info = load_video_cv2(video_path)
fps = video_info.get('fps', 30.0)
first_frame = np.array(video_frames[0])
height, width = first_frame.shape[:2]
# Process each object separately
results = []
output_files = []
for obj_idx, obj in enumerate(objects):
obj_id = obj.get('id', f'object_{obj_idx}')
init_frame = int(obj.get('init_frame', 0))
point_x = int(obj.get('point_x', width // 2))
point_y = int(obj.get('point_y', height // 2))
print(f"Processing object '{obj_id}' at frame {init_frame}, point ({point_x}, {point_y})")
try:
# CRITICAL: Reinitialize model for each object to avoid dtype contamination
# This is necessary because ZeroGPU may change model dtype between objects
if obj_idx > 0:
print(f" Reinitializing model for object {obj_idx+1}...")
torch.cuda.empty_cache()
model = Sam2VideoModel.from_pretrained(MODEL_NAME).to(device, dtype=torch.float32)
# Force ALL parameters to float32
for param in model.parameters():
param.data = param.data.to(torch.float32)
# Ensure model is in correct dtype
dtype = torch.float32
model.to(device, dtype=dtype)
# Double-check all parameters are float32
for param in model.parameters():
if param.dtype != torch.float32:
param.data = param.data.to(torch.float32)
# Initialize inference session for this object
inference_session = processor.init_video_session(
video=video_frames,
inference_device=device,
dtype=dtype,
)
# Add annotation for this object
processor.add_inputs_to_inference_session(
inference_session=inference_session,
frame_idx=init_frame,
obj_ids=1,
input_points=[[[[point_x, point_y]]]],
input_labels=[[[1]]],
)
# Skip initial model inference - go straight to propagation
# The propagation loop will handle init_frame when it reaches it
# Use propagate_in_video_iterator for BIDIRECTIONAL propagation
# According to SAM2 docs, this should propagate both forward AND backward
# from the annotated frame (init_frame)
video_segments = {}
confidence_scores = []
print(f" Using propagate_in_video_iterator for bidirectional propagation from frame {init_frame}...")
for sam2_output in model.propagate_in_video_iterator(inference_session):
video_res_masks = processor.post_process_masks(
[sam2_output.pred_masks],
original_sizes=[[inference_session.video_height, inference_session.video_width]],
binarize=False
)[0]
video_segments[sam2_output.frame_idx] = video_res_masks
# Calculate confidence
mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
confidence_scores.append(float(mask_float.mean()))
print(f" βœ… Got masks for {len(video_segments)} frames (init_frame was {init_frame})")
# Debug: Verify bidirectional propagation worked
frames_before = sum(1 for f in video_segments.keys() if f < init_frame)
frames_after = sum(1 for f in video_segments.keys() if f >= init_frame)
print(f" πŸ“Š Frames before init_frame: {frames_before}, after: {frames_after}")
# Calculate tracking quality metrics
mean_confidence = np.mean(confidence_scores) if confidence_scores else 0.0
min_confidence = np.min(confidence_scores) if confidence_scores else 0.0
# Create output video for this object
output_path = tempfile.mktemp(suffix=f"_{obj_id}.mp4")
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
for frame_idx, frame_pil in enumerate(video_frames):
frame = np.array(frame_pil)
if frame_idx in video_segments:
mask = video_segments[frame_idx].cpu().numpy()
if mask.ndim == 4:
mask = mask[0]
if mask.ndim == 3:
mask = mask.max(axis=0)
if mask.shape != (height, width):
mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
mask_binary = (mask > 0.5).astype(np.uint8)
mask_3d = np.repeat(mask_binary[:, :, np.newaxis], 3, axis=2)
if remove_bg:
# Remove background - show only tracked object
background = np.zeros_like(frame)
frame = frame * mask_3d + background * (1 - mask_3d)
else:
# Overlay colored mask on original video
# Create a colored overlay (e.g., semi-transparent green)
overlay_color = np.array([0, 255, 0], dtype=np.uint8) # Green
overlay = np.ones_like(frame) * overlay_color
alpha = 0.5 # Transparency
frame = (frame * (1 - alpha * mask_3d) + overlay * alpha * mask_3d).astype(np.uint8)
frame_bgr = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR)
out.write(frame_bgr)
out.release()
if os.path.exists(output_path):
output_files.append(output_path)
results.append({
"object_id": obj_id,
"status": "success",
"frames_processed": len(video_segments),
"tracking_confidence": {
"mean": round(mean_confidence, 3),
"min": round(min_confidence, 3)
}
})
else:
results.append({
"object_id": obj_id,
"status": "error",
"error": "Output file not created"
})
except Exception as e:
results.append({
"object_id": obj_id,
"status": "error",
"error": str(e)
})
print(f"Error processing object '{obj_id}': {e}")
# Return results
status_json = json.dumps({
"status": "completed",
"objects_processed": len(results),
"results": results
}, indent=2)
# Return up to 2 videos (for Gradio UI compatibility)
video1 = output_files[0] if len(output_files) > 0 else None
video2 = output_files[1] if len(output_files) > 1 else None
return video1, video2, status_json
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"Error in segment_video_multi: {error_details}")
return None, None, json.dumps({"error": str(e), "details": error_details})
# Create Gradio interface
def create_app():
initialize_model()
with gr.Blocks(title="SAM2 Video Background Remover") as app:
gr.Markdown("""
# 🎥 SAM2 Video Background Remover
Remove backgrounds from videos by tracking objects with Meta's SAM2.
""")
with gr.Tabs():
# Tab 1: Simple single-object tracking
with gr.Tab("Single Point"):
gr.Markdown("""
**Track one object at a time.**
1. Upload a video
2. Enter X, Y coordinates of the object to track
3. Click "Process Video"
""")
with gr.Row():
with gr.Column():
video_input_simple = gr.File(
label="Upload Video",
file_types=["video"]
)
with gr.Row():
point_x = gr.Textbox(label="Point X", value="320")
point_y = gr.Textbox(label="Point Y", value="240")
frame_idx = gr.Textbox(label="Frame Index", value="0")
remove_bg = gr.Checkbox(label="Remove Background", value=True)
process_btn_simple = gr.Button("🎬 Process Video", variant="primary")
with gr.Column():
output_video_simple = gr.File(label="Output Video")
status_text_simple = gr.Textbox(label="Status", lines=3)
process_btn_simple.click(
fn=segment_video_simple,
inputs=[video_input_simple, point_x, point_y, frame_idx, remove_bg],
outputs=[output_video_simple, status_text_simple]
)
# Tab 2: Advanced multi-point tracking
with gr.Tab("Multi-Point (Advanced)"):
gr.Markdown("""
**Track one object with MULTIPLE points across frames.**
Perfect for objects that move fast or get occluded (like soccer balls).
Re-initialize tracking at different frames to maintain accuracy.
1. Upload a video
2. Provide JSON with multiple tracking points
3. Click "Process Advanced"
""")
with gr.Row():
with gr.Column():
video_input_advanced = gr.File(
label="Upload Video",
file_types=["video"]
)
points_advanced_json = gr.Textbox(
label="Tracking Configuration (JSON)",
lines=15,
value=json.dumps({
"object_id": "ball",
"tracking_points": [
{"frame": 0, "x": 360, "y": 900},
{"frame": 50, "x": 380, "y": 850},
{"frame": 100, "x": 400, "y": 800}
],
"remove_bg": True
}, indent=2),
placeholder='See example for format'
)
process_btn_advanced = gr.Button("🎬 Process Advanced", variant="primary")
with gr.Column():
output_video_advanced = gr.File(label="Output Video")
status_text_advanced = gr.Textbox(label="Status", lines=8)
process_btn_advanced.click(
fn=segment_video_advanced,
inputs=[video_input_advanced, points_advanced_json],
outputs=[output_video_advanced, status_text_advanced]
)
gr.Markdown("""
### 📖 Multi-Point Tracking Explained
**When to use this:**
- Ball goes off-screen and comes back
- Ball gets occluded by player
- Fast motion causes tracking to drift
- Small object is hard to track continuously
**How it works:**
1. Mark the ball at frame 0 where it's stationary
2. Mark again at frame 50 where it's visible after kick
3. Mark again at frame 100 where it lands
4. SAM2 propagates between these anchor points
**Result:** Much more robust tracking across the entire video!
""")
# Tab 3: Multi-object tracking
with gr.Tab("Multiple Objects"):
gr.Markdown("""
**Track multiple objects in one video (e.g., ball AND player).**
1. Upload a video
2. Provide objects JSON (default values are for testing with `original.mp4`)
3. Click "Process Multi-Object"
You'll receive separate video files for each tracked object!
**⚠️ Testing note**: The default JSON lists the player first (init_frame=165) and the ball second (init_frame=0) to verify that propagation works regardless of the order of the annotated frames.
**Note**: Default JSON values are optimized for the test video `original.mp4` (portrait soccer video).
""")
with gr.Row():
with gr.Column():
video_input_multi = gr.File(
label="Upload Video",
file_types=["video"]
)
objects_json = gr.Textbox(
label="Objects JSON (PLAYER FIRST for testing!)",
lines=12,
value=json.dumps([
{"id": "player", "init_frame": 165, "point_x": 180, "point_y": 320},
{"id": "ball", "init_frame": 0, "point_x": 360, "point_y": 640}
], indent=2),
placeholder='[{"id": "player", "init_frame": 165, "point_x": 180, "point_y": 320}]'
)
remove_bg_multi = gr.Checkbox(
label="Remove Background",
value=True,
info="If checked, shows only tracked objects. If unchecked, overlays masks on original video."
)
process_btn_multi = gr.Button("🎬 Process Multi-Object", variant="primary")
with gr.Column():
output_video_multi_1 = gr.File(label="Output Video 1")
output_video_multi_2 = gr.File(label="Output Video 2")
status_text_multi = gr.Textbox(label="Status (JSON)", lines=15)
process_btn_multi.click(
fn=segment_video_multi,
inputs=[video_input_multi, objects_json, remove_bg_multi],
outputs=[output_video_multi_1, output_video_multi_2, status_text_multi]
)
gr.Markdown("""
---
### 📍 Coordinate System
- **Coordinates (X, Y)**: Absolute pixel coordinates
- **Origin**: Top-left corner (0, 0)
- **Portrait video (720x1280)**: Center ≈ X=360, Y=640
- **Landscape video (1920x1080)**: Center ≈ X=960, Y=540
### 🎯 Tips for Best Results
- **Frame Index**: Choose a frame where the object is clearly visible and unoccluded
- **Point Selection**: Click on a distinctive part of the object
- **Multiple Points**: For complex shapes (like players), use multiple objects with different points
- **Tracking Quality**: The API returns confidence scores for each object
### 📹 Output Format
- **Format**: MP4 (OpenCV `mp4v` codec)
- **Background**: Black (RGB: 0, 0, 0) - can be made transparent in post-processing
- **FPS**: Same as input video
- **Resolution**: Same as input video
### ⚡ Performance
- Processing time: ~30-60 seconds for typical videos on GPU
- Videos are processed sequentially (one at a time in multi-object mode)
- GPU acceleration via ZeroGPU
""")
return app
if __name__ == "__main__":
app = create_app()
app.launch(share=True, show_error=True)
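# --- Illustrative client-side sketch (not executed by this Space) ---
# A minimal example of calling the "Single Point" endpoint with gradio_client,
# assuming the Space is deployed and reachable. The Space id and api_name are
# assumptions: check the Space's "Use via API" page for the real endpoint names.
#
# from gradio_client import Client, handle_file
#
# client = Client("chaskick/sam2-video-background-remover")  # hypothetical Space id
# output_file, status = client.predict(
#     handle_file("original.mp4"),       # video_file (gr.File input)
#     "360",                             # point_x
#     "640",                             # point_y
#     "0",                               # frame_idx
#     True,                              # remove_bg
#     api_name="/segment_video_simple",  # assumed endpoint name
# )
# print(status, output_file)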