"""
SAM2 Video Segmentation Space - SIMPLIFIED VERSION
Removes background from videos by tracking specified objects.
"""

import gradio as gr
import torch
import numpy as np
import cv2
import tempfile
import json
import os
from typing import List, Tuple, Optional, Dict, Any
from transformers import Sam2VideoModel, Sam2VideoProcessor
from PIL import Image
import spaces


MODEL_NAME = "facebook/sam2.1-hiera-tiny"
device = None
model = None
processor = None

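# The model is loaded lazily: create_app() calls initialize_model() once at
# startup, and every @spaces.GPU entry point re-checks `model is None` so a
# cold GPU worker can initialize on first use.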
def initialize_model():
    """Initialize SAM2 model and processor."""
    global device, model, processor

    if torch.cuda.is_available():
        device = torch.device("cuda")
        dtype = torch.float32
        print(f"CUDA available: {torch.cuda.is_available()}")
        print(f"CUDA device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        dtype = torch.float32
    else:
        device = torch.device("cpu")
        dtype = torch.float32

    print(f"Loading SAM2 model on {device} with dtype {dtype}...")

    model = Sam2VideoModel.from_pretrained(MODEL_NAME).to(device, dtype=dtype)
    processor = Sam2VideoProcessor.from_pretrained(MODEL_NAME)

    print("Model loaded successfully!")
    return device, model, processor

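# Decoding with OpenCV keeps frames in display orientation; each frame is
# converted BGR -> RGB and wrapped in a PIL Image before being handed to the
# SAM2 processor.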
def load_video_cv2(video_path):
    """Load video using OpenCV to preserve orientation."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(frame_rgb))

    cap.release()
    return frames, {'fps': fps}

@spaces.GPU
def segment_video_advanced(video_file, points_json):
    """
    Advanced video segmentation with multiple tracking points across frames.

    This allows re-initializing tracking at different frames when the object is lost.
    Critical for tracking small, fast-moving objects like soccer balls.

    Args:
        video_file: Video file path
        points_json: JSON string with format:
            {
                "object_id": "ball",
                "tracking_points": [
                    {"frame": 0, "x": 360, "y": 900},
                    {"frame": 50, "x": 380, "y": 850},
                    {"frame": 100, "x": 400, "y": 800}
                ],
                "remove_bg": true
            }

    Returns:
        (output_video_path, status_message_with_confidence)
    """
    global device, model, processor

    if model is None:
        initialize_model()

    try:
        if video_file is None:
            return None, "❌ Error: No video file provided"

        video_path = str(video_file)

        if not os.path.exists(video_path):
            return None, f"❌ Error: Video file not found: {video_path}"

        try:
            config = json.loads(points_json)
        except json.JSONDecodeError as e:
            return None, f"❌ Error: Invalid JSON format: {str(e)}"

        object_id = config.get('object_id', 'object')
        tracking_points = config.get('tracking_points', [])
        remove_bg = config.get('remove_bg', True)

        if not tracking_points:
            return None, "❌ Error: No tracking points provided"

        print(f"Processing '{object_id}' with {len(tracking_points)} tracking points")

        video_frames, video_info = load_video_cv2(video_path)
        fps = video_info.get('fps', 30.0)
        first_frame = np.array(video_frames[0])
        height, width = first_frame.shape[:2]

        dtype = torch.float32
        inference_session = processor.init_video_session(
            video=video_frames,
            inference_device=device,
            dtype=dtype,
        )

        for i, point in enumerate(tracking_points):
            frame_idx = int(point.get('frame', 0))
            point_x = int(point.get('x', width // 2))
            point_y = int(point.get('y', height // 2))

            print(f"  Point {i+1}/{len(tracking_points)}: frame {frame_idx}, ({point_x}, {point_y})")

            processor.add_inputs_to_inference_session(
                inference_session=inference_session,
                frame_idx=frame_idx,
                obj_ids=1,
                input_points=[[[[point_x, point_y]]]],
                input_labels=[[[1]]],
            )

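        # Drive per-frame inference manually (segment_video_multi uses
        # propagate_in_video_iterator instead); pixel values are only computed
        # for frames the inference session has not already processed.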
        video_segments = {}
        confidence_scores = []

        print(f"Propagating through all {len(video_frames)} frames...")

        with torch.inference_mode():
            for f_idx, frame_pil in enumerate(video_frames):
                pixel_values = None
                if inference_session.processed_frames is None or f_idx not in inference_session.processed_frames:
                    pixel_values = processor(images=frame_pil, device=device, return_tensors="pt").pixel_values[0]

                sam2_output = model(
                    inference_session=inference_session,
                    frame=pixel_values,
                    frame_idx=f_idx
                )

                H = inference_session.video_height
                W = inference_session.video_width
                pred_masks = sam2_output.pred_masks.detach().cpu()
                video_res_masks = processor.post_process_masks(
                    [pred_masks],
                    original_sizes=[[H, W]],
                    binarize=False
                )[0]

                video_segments[f_idx] = video_res_masks

                mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
                mask_confidence = float(mask_float.mean())
                confidence_scores.append(mask_confidence)

        print(f"✅ Got masks for all {len(video_segments)} frames")

        mean_confidence = np.mean(confidence_scores) if confidence_scores else 0.0
        min_confidence = np.min(confidence_scores) if confidence_scores else 0.0

        output_path = tempfile.mktemp(suffix=f"_{object_id}.mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

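        # Render the output video: the post-processed masks can carry extra
        # leading dimensions, so they are reduced to a single 2D mask, resized
        # to the frame size if needed, and thresholded at 0.5 before the
        # background is blacked out.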
        for frame_idx, frame_pil in enumerate(video_frames):
            frame = np.array(frame_pil)

            if frame_idx in video_segments:
                mask = video_segments[frame_idx].cpu().numpy()

                if mask.ndim == 4:
                    mask = mask[0]
                if mask.ndim == 3:
                    mask = mask.max(axis=0)

                if mask.shape != (height, width):
                    mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

                mask_binary = (mask > 0.5).astype(np.uint8)

                if remove_bg:
                    background = np.zeros_like(frame)
                    mask_3d = np.repeat(mask_binary[:, :, np.newaxis], 3, axis=2)
                    frame = frame * mask_3d + background * (1 - mask_3d)

            frame_bgr = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR)
            out.write(frame_bgr)

        out.release()

        status = {
            "status": "success",
            "object_id": object_id,
            "frames_processed": len(video_segments),
            "total_frames": len(video_frames),
            "tracking_points_used": len(tracking_points),
            "confidence": {
                "mean": round(mean_confidence, 3),
                "min": round(min_confidence, 3)
            }
        }

        status_msg = f"✅ Success! Processed {len(video_segments)}/{len(video_frames)} frames\n"
        status_msg += f"   Tracking points: {len(tracking_points)}\n"
        status_msg += f"   Confidence: mean={mean_confidence:.3f}, min={min_confidence:.3f}"

        if os.path.exists(output_path):
            return output_path, status_msg
        else:
            return None, "❌ Error: Output file was not created"

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in segment_video_advanced: {error_details}")
        return None, f"❌ Error: {str(e)}"

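
# Illustrative sketch (not wired into the UI): how the multi-point JSON config
# for segment_video_advanced can be built and the function called directly.
# The video path, frame indices, and coordinates below are placeholders, not
# values taken from a real clip.
def _example_advanced_call(video_path: str = "original.mp4") -> None:
    config = {
        "object_id": "ball",
        "tracking_points": [
            {"frame": 0, "x": 360, "y": 900},
            {"frame": 50, "x": 380, "y": 850},
        ],
        "remove_bg": True,
    }
    output_path, status = segment_video_advanced(video_path, json.dumps(config))
    print(status, output_path)
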
@spaces.GPU
def segment_video_simple(video_file, point_x, point_y, frame_idx, remove_bg):
    """Simple video segmentation with a single point."""
    global device, model, processor

    if model is None:
        initialize_model()

    try:
        if video_file is None:
            return None, "❌ Error: No video file provided"

        video_path = str(video_file)

        if not os.path.exists(video_path):
            return None, f"❌ Error: Video file not found: {video_path}"

        print(f"Processing video from: {video_path}")

        point_x = int(float(point_x))
        point_y = int(float(point_y))
        frame_idx = int(float(frame_idx))

        video_frames, video_info = load_video_cv2(video_path)
        fps = video_info.get('fps', 30.0)

        dtype = torch.float32
        inference_session = processor.init_video_session(
            video=video_frames,
            inference_device=device,
            dtype=dtype,
        )

        processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=frame_idx,
            obj_ids=1,
            input_points=[[[[point_x, point_y]]]],
            input_labels=[[[1]]],
        )

        video_segments = {}
        print(f"Propagating through {len(video_frames)} frames from frame {frame_idx}...")

        with torch.inference_mode():
            for f_idx, frame_pil in enumerate(video_frames):
                pixel_values = None
                if inference_session.processed_frames is None or f_idx not in inference_session.processed_frames:
                    pixel_values = processor(images=frame_pil, device=device, return_tensors="pt").pixel_values[0]

                sam2_output = model(
                    inference_session=inference_session,
                    frame=pixel_values,
                    frame_idx=f_idx
                )

                H = inference_session.video_height
                W = inference_session.video_width
                pred_masks = sam2_output.pred_masks.detach().cpu()
                video_res_masks = processor.post_process_masks(
                    [pred_masks],
                    original_sizes=[[H, W]],
                    binarize=False
                )[0]

                video_segments[f_idx] = video_res_masks

        print(f"✅ Got masks for {len(video_segments)} frames")

        output_path = tempfile.mktemp(suffix=".mp4")
        first_frame = np.array(video_frames[0])
        height, width = first_frame.shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        for frame_idx, frame_pil in enumerate(video_frames):
            frame = np.array(frame_pil)

            if frame_idx in video_segments:
                mask = video_segments[frame_idx].cpu().numpy()

                if mask.ndim == 4:
                    mask = mask[0]
                if mask.ndim == 3:
                    mask = mask.max(axis=0)

                if mask.shape != (height, width):
                    mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

                mask_binary = (mask > 0.5).astype(np.uint8)

                if remove_bg:
                    background = np.zeros_like(frame)
                    mask_3d = np.repeat(mask_binary[:, :, np.newaxis], 3, axis=2)
                    frame = frame * mask_3d + background * (1 - mask_3d)

            frame_bgr = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR)
            out.write(frame_bgr)

        out.release()

        if os.path.exists(output_path):
            return output_path, f"✅ Success! Processed {len(video_segments)} frames"
        else:
            return None, "❌ Error: Output file was not created"

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in segment_video_simple: {error_details}")
        return None, f"❌ Error: {str(e)}"

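# Note: the simple endpoint receives point_x / point_y / frame_idx as strings
# because the UI uses gr.Textbox inputs; a direct call looks like (illustrative
# placeholder path and coordinates):
#     segment_video_simple("original.mp4", "320", "240", "0", True)
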
@spaces.GPU
def segment_video_multi(video_file, objects_json, remove_bg=True):
    """
    Multi-object video segmentation.

    Args:
        video_file: Video file path
        objects_json: JSON string with format:
            [
                {"id": "ball", "init_frame": 0, "point_x": 360, "point_y": 640},
                {"id": "player", "init_frame": 0, "point_x": 320, "point_y": 500}
            ]
        remove_bg: If True, remove background. If False, overlay masks on original video.

    Returns:
        Tuple: (video1_path, video2_path, status_json)
            - video1_path: First object video
            - video2_path: Second object video (or None if only one object)
            - status_json: Status message with confidence scores in JSON format
    """
    global device, model, processor

    if model is None:
        initialize_model()

    try:
        if video_file is None:
            return None, None, json.dumps({"error": "No video file provided"})

        video_path = str(video_file)

        if not os.path.exists(video_path):
            return None, None, json.dumps({"error": f"Video file not found: {video_path}"})

        try:
            objects = json.loads(objects_json)
        except json.JSONDecodeError as e:
            return None, None, json.dumps({"error": f"Invalid JSON format: {str(e)}"})

        if not isinstance(objects, list) or len(objects) == 0:
            return None, None, json.dumps({"error": "objects must be a non-empty list"})

        print(f"Processing {len(objects)} objects from video: {video_path}")

        video_frames, video_info = load_video_cv2(video_path)
        fps = video_info.get('fps', 30.0)
        first_frame = np.array(video_frames[0])
        height, width = first_frame.shape[:2]

        results = []
        output_files = []

        for obj_idx, obj in enumerate(objects):
            obj_id = obj.get('id', f'object_{obj_idx}')
            init_frame = int(obj.get('init_frame', 0))
            point_x = int(obj.get('point_x', width // 2))
            point_y = int(obj.get('point_y', height // 2))

            print(f"Processing object '{obj_id}' at frame {init_frame}, point ({point_x}, {point_y})")

            try:
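                # Reload the model between objects so each object is tracked
                # from a clean state, and force every parameter to float32 so
                # the dtype matches the video session created below.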
                if obj_idx > 0:
                    print(f"  Reinitializing model for object {obj_idx+1}...")
                    torch.cuda.empty_cache()
                    model = Sam2VideoModel.from_pretrained(MODEL_NAME).to(device, dtype=torch.float32)

                    for param in model.parameters():
                        param.data = param.data.to(torch.float32)

                dtype = torch.float32
                model.to(device, dtype=dtype)

                for param in model.parameters():
                    if param.dtype != torch.float32:
                        param.data = param.data.to(torch.float32)

                inference_session = processor.init_video_session(
                    video=video_frames,
                    inference_device=device,
                    dtype=dtype,
                )

                processor.add_inputs_to_inference_session(
                    inference_session=inference_session,
                    frame_idx=init_frame,
                    obj_ids=1,
                    input_points=[[[[point_x, point_y]]]],
                    input_labels=[[[1]]],
                )

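                # Propagate the single prompt from init_frame across the video;
                # the prints below report the per-object mask count and how many
                # frames fall before vs. after init_frame.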
                video_segments = {}
                confidence_scores = []

                print(f"  Using propagate_in_video_iterator for bidirectional propagation from frame {init_frame}...")

                for sam2_output in model.propagate_in_video_iterator(inference_session):
                    video_res_masks = processor.post_process_masks(
                        [sam2_output.pred_masks],
                        original_sizes=[[inference_session.video_height, inference_session.video_width]],
                        binarize=False
                    )[0]
                    video_segments[sam2_output.frame_idx] = video_res_masks

                    mask_float = video_res_masks.float() if video_res_masks.dtype == torch.bool else video_res_masks
                    confidence_scores.append(float(mask_float.mean()))

                print(f"  ✅ Got masks for {len(video_segments)} frames (init_frame was {init_frame})")

                frames_before = sum(1 for f in video_segments.keys() if f < init_frame)
                frames_after = sum(1 for f in video_segments.keys() if f >= init_frame)
                print(f"  📊 Frames before init_frame: {frames_before}, after: {frames_after}")

                mean_confidence = np.mean(confidence_scores) if confidence_scores else 0.0
                min_confidence = np.min(confidence_scores) if confidence_scores else 0.0

                output_path = tempfile.mktemp(suffix=f"_{obj_id}.mp4")
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

                for frame_idx, frame_pil in enumerate(video_frames):
                    frame = np.array(frame_pil)

                    if frame_idx in video_segments:
                        mask = video_segments[frame_idx].cpu().numpy()

                        if mask.ndim == 4:
                            mask = mask[0]
                        if mask.ndim == 3:
                            mask = mask.max(axis=0)

                        if mask.shape != (height, width):
                            mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

                        mask_binary = (mask > 0.5).astype(np.uint8)
                        mask_3d = np.repeat(mask_binary[:, :, np.newaxis], 3, axis=2)

                        if remove_bg:
                            background = np.zeros_like(frame)
                            frame = frame * mask_3d + background * (1 - mask_3d)
                        else:
                            overlay_color = np.array([0, 255, 0], dtype=np.uint8)
                            overlay = np.ones_like(frame) * overlay_color
                            alpha = 0.5
                            frame = (frame * (1 - alpha * mask_3d) + overlay * alpha * mask_3d).astype(np.uint8)

                    frame_bgr = cv2.cvtColor(frame.astype(np.uint8), cv2.COLOR_RGB2BGR)
                    out.write(frame_bgr)

                out.release()

                if os.path.exists(output_path):
                    output_files.append(output_path)
                    results.append({
                        "object_id": obj_id,
                        "status": "success",
                        "frames_processed": len(video_segments),
                        "tracking_confidence": {
                            "mean": round(mean_confidence, 3),
                            "min": round(min_confidence, 3)
                        }
                    })
                else:
                    results.append({
                        "object_id": obj_id,
                        "status": "error",
                        "error": "Output file not created"
                    })

            except Exception as e:
                results.append({
                    "object_id": obj_id,
                    "status": "error",
                    "error": str(e)
                })
                print(f"Error processing object '{obj_id}': {e}")

        status_json = json.dumps({
            "status": "completed",
            "objects_processed": len(results),
            "results": results
        }, indent=2)

        video1 = output_files[0] if len(output_files) > 0 else None
        video2 = output_files[1] if len(output_files) > 1 else None

        return video1, video2, status_json

    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in segment_video_multi: {error_details}")
        return None, None, json.dumps({"error": str(e), "details": error_details})

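
# Illustrative sketch (not wired into the UI): calling the multi-object
# endpoint directly and the rough shape of the status JSON it returns. The
# coordinates and frame indices are placeholders.
def _example_multi_call(video_path: str = "original.mp4") -> None:
    objects = [
        {"id": "player", "init_frame": 0, "point_x": 180, "point_y": 320},
        {"id": "ball", "init_frame": 0, "point_x": 360, "point_y": 640},
    ]
    video1, video2, status_json = segment_video_multi(
        video_path, json.dumps(objects), remove_bg=True
    )
    # status_json looks like:
    # {"status": "completed", "objects_processed": 2, "results": [...]}
    print(status_json, video1, video2)
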
def create_app():
    initialize_model()

    with gr.Blocks(title="SAM2 Video Background Remover") as app:
        gr.Markdown("""
        # 🎥 SAM2 Video Background Remover

        Remove backgrounds from videos by tracking objects with Meta's SAM2.
        """)

        with gr.Tabs():

            with gr.Tab("Single Point"):
                gr.Markdown("""
                **Track one object at a time.**

                1. Upload a video
                2. Enter X, Y coordinates of the object to track
                3. Click "Process Video"
                """)

                with gr.Row():
                    with gr.Column():
                        video_input_simple = gr.File(
                            label="Upload Video",
                            file_types=["video"]
                        )

                        with gr.Row():
                            point_x = gr.Textbox(label="Point X", value="320")
                            point_y = gr.Textbox(label="Point Y", value="240")

                        frame_idx = gr.Textbox(label="Frame Index", value="0")
                        remove_bg = gr.Checkbox(label="Remove Background", value=True)

                        process_btn_simple = gr.Button("🎬 Process Video", variant="primary")

                    with gr.Column():
                        output_video_simple = gr.File(label="Output Video")
                        status_text_simple = gr.Textbox(label="Status", lines=3)

                process_btn_simple.click(
                    fn=segment_video_simple,
                    inputs=[video_input_simple, point_x, point_y, frame_idx, remove_bg],
                    outputs=[output_video_simple, status_text_simple]
                )

with gr.Tab("Multi-Point (Advanced)"): |
|
|
gr.Markdown(""" |
|
|
**Track one object with MULTIPLE points across frames.** |
|
|
|
|
|
Perfect for objects that move fast or get occluded (like soccer balls). |
|
|
Re-initialize tracking at different frames to maintain accuracy. |
|
|
|
|
|
1. Upload a video |
|
|
2. Provide JSON with multiple tracking points |
|
|
3. Click "Process Advanced" |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
video_input_advanced = gr.File( |
|
|
label="Upload Video", |
|
|
file_types=["video"] |
|
|
) |
|
|
|
|
|
points_advanced_json = gr.Textbox( |
|
|
label="Tracking Configuration (JSON)", |
|
|
lines=15, |
|
|
value=json.dumps({ |
|
|
"object_id": "ball", |
|
|
"tracking_points": [ |
|
|
{"frame": 0, "x": 360, "y": 900}, |
|
|
{"frame": 50, "x": 380, "y": 850}, |
|
|
{"frame": 100, "x": 400, "y": 800} |
|
|
], |
|
|
"remove_bg": True |
|
|
}, indent=2), |
|
|
placeholder='See example for format' |
|
|
) |
|
|
|
|
|
process_btn_advanced = gr.Button("π¬ Process Advanced", variant="primary") |
|
|
|
|
|
with gr.Column(): |
|
|
output_video_advanced = gr.File(label="Output Video") |
|
|
status_text_advanced = gr.Textbox(label="Status", lines=8) |
|
|
|
|
|
process_btn_advanced.click( |
|
|
fn=segment_video_advanced, |
|
|
inputs=[video_input_advanced, points_advanced_json], |
|
|
outputs=[output_video_advanced, status_text_advanced] |
|
|
) |
|
|
|
|
|
gr.Markdown(""" |
|
|
### π Multi-Point Tracking Explained |
|
|
|
|
|
**When to use this:** |
|
|
- Ball goes off-screen and comes back |
|
|
- Ball gets occluded by player |
|
|
- Fast motion causes tracking to drift |
|
|
- Small object is hard to track continuously |
|
|
|
|
|
**How it works:** |
|
|
1. Mark the ball at frame 0 where it's stationary |
|
|
2. Mark again at frame 50 where it's visible after kick |
|
|
3. Mark again at frame 100 where it lands |
|
|
4. SAM2 propagates between these anchor points |
|
|
|
|
|
**Result:** Much more robust tracking across entire video! |
|
|
""") |
|
|
|
|
|
|
|
|
with gr.Tab("Multiple Objects"): |
|
|
gr.Markdown(""" |
|
|
**Track multiple objects in one video (e.g., ball AND player).** |
|
|
|
|
|
1. Upload a video |
|
|
2. Provide objects JSON (default values are for testing with `original.mp4`) |
|
|
3. Click "Process Multi-Object" |
|
|
|
|
|
You'll receive separate video files for each tracked object! |
|
|
|
|
|
**β οΈ TESTING NOTE**: Order swapped - PLAYER FIRST (init_frame=165), BALL SECOND (init_frame=0) to test if propagation works regardless of init_frame order. |
|
|
|
|
|
**Note**: Default JSON values are optimized for the test video `original.mp4` (portrait soccer video). |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
video_input_multi = gr.File( |
|
|
label="Upload Video", |
|
|
file_types=["video"] |
|
|
) |
|
|
|
|
|
objects_json = gr.Textbox( |
|
|
label="Objects JSON (PLAYER FIRST for testing!)", |
|
|
lines=12, |
|
|
value=json.dumps([ |
|
|
{"id": "player", "init_frame": 165, "point_x": 180, "point_y": 320}, |
|
|
{"id": "ball", "init_frame": 0, "point_x": 360, "point_y": 640} |
|
|
], indent=2), |
|
|
placeholder='[{"id": "player", "init_frame": 165, "point_x": 180, "point_y": 320}]' |
|
|
) |
|
|
|
|
|
remove_bg_multi = gr.Checkbox( |
|
|
label="Remove Background", |
|
|
value=True, |
|
|
info="If checked, shows only tracked objects. If unchecked, overlays masks on original video." |
|
|
) |
|
|
|
|
|
process_btn_multi = gr.Button("π¬ Process Multi-Object", variant="primary") |
|
|
|
|
|
with gr.Column(): |
|
|
output_video_multi_1 = gr.File(label="Output Video 1") |
|
|
output_video_multi_2 = gr.File(label="Output Video 2") |
|
|
status_text_multi = gr.Textbox(label="Status (JSON)", lines=15) |
|
|
|
|
|
process_btn_multi.click( |
|
|
fn=segment_video_multi, |
|
|
inputs=[video_input_multi, objects_json, remove_bg_multi], |
|
|
outputs=[output_video_multi_1, output_video_multi_2, status_text_multi] |
|
|
) |
|
|
|
|
|
gr.Markdown(""" |
|
|
--- |
|
|
### π Coordinate System |
|
|
- **Coordinates (X, Y)**: Absolute pixel coordinates |
|
|
- **Origin**: Top-left corner (0, 0) |
|
|
- **Portrait video (720x1280)**: Center β X=360, Y=640 |
|
|
- **Landscape video (1920x1080)**: Center β X=960, Y=540 |
|
|
|
|
|
### π― Tips for Best Results |
|
|
- **Frame Index**: Choose a frame where the object is clearly visible and unoccluded |
|
|
- **Point Selection**: Click on a distinctive part of the object |
|
|
- **Multiple Points**: For complex shapes (like players), use multiple objects with different points |
|
|
- **Tracking Quality**: The API returns confidence scores for each object |
|
|
|
|
|
### πΉ Output Format |
|
|
- **Format**: MP4 (H.264) |
|
|
- **Background**: Black (RGB: 0, 0, 0) - can be made transparent in post-processing |
|
|
- **FPS**: Same as input video |
|
|
- **Resolution**: Same as input video |
|
|
|
|
|
### β‘ Performance |
|
|
- Processing time: ~30-60 seconds for typical videos on GPU |
|
|
- Videos are processed sequentially (one at a time in multi-object mode) |
|
|
- GPU acceleration via ZeroGPU |
|
|
""") |
|
|
|
|
|
return app |
|
|
|
|
|
|
|
|
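# Illustrative sketch (not used by the app): calling this Space remotely with
# gradio_client. The Space ID is a placeholder, and the api_name assumes
# Gradio's default of naming the endpoint after the wrapped function.
def _example_remote_call() -> None:
    from gradio_client import Client, handle_file

    client = Client("your-username/sam2-video-space")  # placeholder Space ID
    output_video, status = client.predict(
        handle_file("original.mp4"),  # placeholder local video path
        "320", "240", "0", True,
        api_name="/segment_video_simple",
    )
    print(status, output_video)
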
if __name__ == "__main__":
    app = create_app()
    app.launch(share=True, show_error=True)