Mirko Trasciatti committed on
Commit 3cddaf8 · 1 Parent(s): 69f4f56

Deploy SAM2 Video Background Remover with Gradio UI and API

Files changed (5)
  1. .gitignore +64 -0
  2. README.md +233 -7
  3. api_example.py +356 -0
  4. app.py +540 -0
  5. requirements.txt +8 -0
.gitignore ADDED
@@ -0,0 +1,64 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual environments
+ .venv
+ venv/
+ ENV/
+ env/
+
+ # IDEs
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+
+ # Gradio
+ flagged/
+ gradio_cached_examples/
+
+ # Temporary files
+ *.tmp
+ *.temp
+ tmp/
+ temp/
+
+ # Video files (if you don't want to commit test videos)
+ *.mp4
+ *.avi
+ *.mov
+ *.mkv
+ !example_*.mp4
+
+ # Model cache
+ .cache/
+ models/
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Logs
+ *.log
+ logs/
+
README.md CHANGED
@@ -1,14 +1,240 @@
  ---
- title: Chaskick
- emoji: 🦀
- colorFrom: purple
- colorTo: green
  sdk: gradio
- sdk_version: 5.49.1
  app_file: app.py
  pinned: false
  license: apache-2.0
- short_description: BKGRMV
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: SAM2 Video Background Remover
+ emoji: 🎥
+ colorFrom: blue
+ colorTo: purple
  sdk: gradio
+ sdk_version: 4.44.0
  app_file: app.py
  pinned: false
  license: apache-2.0
+ tags:
+ - computer-vision
+ - video
+ - segmentation
+ - sam2
+ - background-removal
+ - object-tracking
  ---

+ # 🎥 SAM2 Video Background Remover
+
+ Remove backgrounds from videos by tracking objects using Meta's **Segment Anything Model 2 (SAM2)**.
+
+ ## Features
+
+ ✨ **Background Removal**: Automatically remove backgrounds and keep only tracked objects
+ 🎯 **Object Tracking**: Track multiple objects across video frames
+ 🖥️ **Interactive UI**: Easy-to-use Gradio interface
+ 🔌 **REST API**: Programmatic access via API endpoints
+ ⚡ **GPU Accelerated**: Fast processing with CUDA support
+
+ ## How It Works
+
+ SAM2 is a foundation model for video segmentation that can:
+ 1. **Segment objects** based on point or box annotations
+ 2. **Track objects** automatically across all video frames
+ 3. **Handle occlusions** and object reappearance
+ 4. **Process multiple objects** simultaneously
+
+ ## Usage
+
+ ### 🖱️ Simple Mode (Web UI)
+
+ 1. Upload your video
+ 2. Specify the X,Y coordinates of the object you want to track (in the first frame)
+ 3. Click "Process Video"
+ 4. Download the result with the background removed!
+
+ **Example**: For a 640x480 video with a person in the center, use X=320, Y=240 (the snippet below shows one way to find coordinates in your own video).
+
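+ If you are unsure where your object sits in the frame, a quick local sketch like the following can help. It is not part of the Space itself: it assumes OpenCV is installed on your machine and `input_video.mp4` is a placeholder path, and it mirrors the coordinate helper shipped in `api_example.py`.
+
+ ```python
+ import cv2
+
+ # Grab the first frame of your video and save it as an image,
+ # then read the x,y position of your object off the saved picture.
+ cap = cv2.VideoCapture("input_video.mp4")  # placeholder path
+ ok, frame = cap.read()
+ cap.release()
+
+ if ok:
+     cv2.imwrite("first_frame.jpg", frame)
+     # frame.shape is (height, width, channels); annotations use x (width) then y (height)
+     print(f"Video size: {frame.shape[1]}x{frame.shape[0]} (width x height)")
+ ```
+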
+ ### 🔧 Advanced Mode (JSON Annotations)
+
+ For more control, use JSON annotations:
+
+ ```json
+ [
+   {
+     "frame_idx": 0,
+     "object_id": 1,
+     "points": [[320, 240]],
+     "labels": [1]
+   }
+ ]
+ ```
+
+ **Parameters** (a small helper for building these in Python is sketched below):
+ - `frame_idx`: Frame number to annotate (0 = first frame)
+ - `object_id`: Unique ID for each object (1, 2, 3, ...)
+ - `points`: List of [x, y] coordinates on the object
+ - `labels`: `1` for foreground point, `0` for background point
+
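+ If you generate annotations in code rather than by hand, a tiny helper keeps them consistent; the repository's `api_example.py` ships an equivalent `create_annotation` function. A minimal sketch:
+
+ ```python
+ def create_annotation(frame_idx, object_id, points, labels=None):
+     """Build one annotation dict; labels default to all-foreground (1)."""
+     if labels is None:
+         labels = [1] * len(points)
+     return {"frame_idx": frame_idx, "object_id": object_id, "points": points, "labels": labels}
+
+ # One foreground click on the first frame for object 1
+ annotations = [create_annotation(0, 1, [[320, 240]])]
+ ```
+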
+ ### 📡 API Usage
+
+ You can call this Space programmatically using the Gradio Client:
+
+ #### Python Example
+
+ ```python
+ from gradio_client import Client
+ import json
+
+ # Connect to the Space
+ client = Client("YOUR_USERNAME/sam2-video-bg-remover")
+
+ # Define what to track
+ annotations = [
+     {
+         "frame_idx": 0,
+         "object_id": 1,
+         "points": [[320, 240]],  # x, y coordinates
+         "labels": [1]            # 1 = foreground
+     }
+ ]
+
+ # Process video
+ result = client.predict(
+     video_file="./input_video.mp4",
+     annotations_json=json.dumps(annotations),
+     remove_background=True,
+     max_frames=300,  # Limit frames for faster processing
+     api_name="/segment_video_api"
+ )
+
+ print(f"Output video saved to: {result}")
+ ```
+
+ #### Track Multiple Objects
+
+ ```python
+ annotations = [
+     # First object (person)
+     {
+         "frame_idx": 0,
+         "object_id": 1,
+         "points": [[320, 240]],
+         "labels": [1]
+     },
+     # Second object (ball)
+     {
+         "frame_idx": 0,
+         "object_id": 2,
+         "points": [[500, 300]],
+         "labels": [1]
+     }
+ ]
+ ```
+
+ #### Refine Segmentation with Background Points
+
+ ```python
+ annotations = [
+     {
+         "frame_idx": 0,
+         "object_id": 1,
+         "points": [
+             [320, 240],  # Point ON the object
+             [100, 100]   # Point on background to exclude
+         ],
+         "labels": [1, 0]  # 1=foreground, 0=background
+     }
+ ]
+ ```
+
+ ### 🌐 HTTP API
+
+ You can also call the API directly via HTTP:
+
+ ```bash
+ curl -X POST https://YOUR_USERNAME-sam2-video-bg-remover.hf.space/api/predict \
+   -F "video_file=@input_video.mp4" \
+   -F 'annotations_json=[{"frame_idx":0,"object_id":1,"points":[[320,240]],"labels":[1]}]' \
+   -F "remove_background=true" \
+   -F "max_frames=300"
+ ```
+
+ ## Parameters
+
+ | Parameter | Type | Default | Description |
+ |-----------|------|---------|-------------|
+ | `video_file` | File | - | Input video file (required) |
+ | `annotations_json` | String | - | JSON array of annotations (required) |
+ | `remove_background` | Boolean | `true` | Remove background or just highlight objects |
+ | `max_frames` | Integer | `null` | Limit frames for faster processing |
+
+ ## Tips & Best Practices
+
+ ### 🎯 Getting Good Results
+
+ 1. **Choose Clear Points**: Click on the center/most distinctive part of your object
+ 2. **Add Multiple Points**: For complex objects, add 2-3 points on different parts
+ 3. **Use Background Points**: Add points with `label: 0` on areas you DON'T want
+ 4. **Annotate Key Frames**: If the object changes significantly, add annotations on multiple frames
+
+ ### ⚡ Performance Tips
+
+ 1. **Limit Frames**: Use the `max_frames` parameter for long videos
+ 2. **Use a Smaller Model**: The default is `sam2.1-hiera-tiny` for speed
+ 3. **Process Shorter Clips**: Split long videos into segments
+
+ ### 🐛 Troubleshooting
+
+ | Issue | Solution |
+ |-------|----------|
+ | Object not tracked | Add more points on different parts of the object |
+ | Background leakage | Add background points with `label: 0` |
+ | Slow processing | Reduce `max_frames` or use a shorter video |
+ | Wrong object tracked | Be more precise with point coordinates |
+
+ ## Model Information
+
+ This Space uses **facebook/sam2.1-hiera-tiny** for efficient processing. Other available models (loading sketch after this list):
+
+ - `facebook/sam2.1-hiera-tiny` - Fastest, good quality ⚡
+ - `facebook/sam2.1-hiera-small` - Balanced
+ - `facebook/sam2.1-hiera-base-plus` - Higher quality
+ - `facebook/sam2.1-hiera-large` - Best quality, slower 🎯
+
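+ Swapping checkpoints means changing the `MODEL_NAME` constant in `app.py`; loading follows the same pattern used there. A minimal sketch for the large variant, assuming a `transformers` version that ships the SAM2 video classes (requirements.txt pins >=4.57.0):
+
+ ```python
+ import torch
+ from transformers import Sam2VideoModel, Sam2VideoProcessor
+
+ MODEL_NAME = "facebook/sam2.1-hiera-large"  # any checkpoint from the list above
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ model = Sam2VideoModel.from_pretrained(MODEL_NAME).to(device)
+ processor = Sam2VideoProcessor.from_pretrained(MODEL_NAME)
+ ```
+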
+ ## Use Cases
+
+ - 🎬 **Video Production**: Remove backgrounds for green screen effects
+ - 🏃 **Sports Analysis**: Isolate athletes for motion analysis
+ - 🎮 **Content Creation**: Extract game characters or objects
+ - 🔬 **Research**: Track objects in scientific videos
+ - 📱 **Social Media**: Create engaging content with background removal
+
+ ## Limitations
+
+ - Video length affects processing time (longer = slower)
+ - GPU recommended for videos > 10 seconds
+ - Very fast-moving objects may require multiple annotations
+ - Extreme lighting changes can affect tracking quality
+
+ ## Citation
+
+ If you use this Space, please cite the SAM2 paper:
+
+ ```bibtex
+ @article{ravi2024sam2,
+   title={SAM 2: Segment Anything in Images and Videos},
+   author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and others},
+   journal={arXiv preprint arXiv:2408.00714},
+   year={2024}
+ }
+ ```
+
+ ## License
+
+ Apache 2.0
+
+ ## Links
+
+ - 📚 [SAM2 Documentation](https://huggingface.co/docs/transformers/model_doc/sam2_video)
+ - 🤗 [Model on Hugging Face](https://huggingface.co/facebook/sam2.1-hiera-tiny)
+ - 📄 [Research Paper](https://arxiv.org/abs/2408.00714)
+ - 💻 [Original Repository](https://github.com/facebookresearch/segment-anything-2)
+
+ ---
+
+ Built with ❤️ using [Transformers](https://github.com/huggingface/transformers) and [Gradio](https://gradio.app)
+
api_example.py ADDED
@@ -0,0 +1,356 @@
+ """
+ Example script showing how to use the SAM2 Video Background Remover API.
+
+ This script demonstrates various use cases:
+ 1. Simple single object tracking
+ 2. Multiple object tracking
+ 3. Refined segmentation with background points
+ 4. Batch processing multiple videos
+ """
+
+ from gradio_client import Client
+ import json
+ from pathlib import Path
+
+
+ def example_1_simple_tracking():
+     """
+     Example 1: Track a single object (e.g., person, ball, car)
+     """
+     print("=" * 60)
+     print("Example 1: Simple Single Object Tracking")
+     print("=" * 60)
+
+     # Connect to your Space
+     client = Client("furbola/chaskick")
+
+     # Simple annotation: click on the center of your object in the first frame
+     annotations = [
+         {
+             "frame_idx": 0,          # First frame
+             "object_id": 1,          # First object
+             "points": [[320, 240]],  # x, y coordinates of the object center
+             "labels": [1]            # 1 = this is a foreground point
+         }
+     ]
+
+     # Process the video
+     result = client.predict(
+         video_file="./input_video.mp4",
+         annotations_json=json.dumps(annotations),
+         remove_background=True,
+         max_frames=None,  # Process all frames
+         api_name="/segment_video_api"
+     )
+
+     print(f"✅ Output saved to: {result}")
+
+
+ def example_2_multi_object_tracking():
+     """
+     Example 2: Track multiple objects simultaneously
+     Useful for: tracking player + ball, multiple people, etc.
+     """
+     print("\n" + "=" * 60)
+     print("Example 2: Multi-Object Tracking")
+     print("=" * 60)
+
+     client = Client("furbola/chaskick")
+
+     annotations = [
+         # Object 1: Player
+         {
+             "frame_idx": 0,
+             "object_id": 1,
+             "points": [[320, 240]],
+             "labels": [1]
+         },
+         # Object 2: Ball
+         {
+             "frame_idx": 0,
+             "object_id": 2,
+             "points": [[500, 300]],
+             "labels": [1]
+         },
+         # Object 3: Another player
+         {
+             "frame_idx": 0,
+             "object_id": 3,
+             "points": [[150, 200]],
+             "labels": [1]
+         }
+     ]
+
+     result = client.predict(
+         video_file="./soccer_match.mp4",
+         annotations_json=json.dumps(annotations),
+         remove_background=True,
+         max_frames=300,  # Limit to 300 frames for speed
+         api_name="/segment_video_api"
+     )
+
+     print(f"✅ Tracked 3 objects! Output: {result}")
+
+
+ def example_3_refined_segmentation():
+     """
+     Example 3: Use both foreground AND background points for better accuracy
+     Useful when: object is complex, background is similar color, etc.
+     """
+     print("\n" + "=" * 60)
+     print("Example 3: Refined Segmentation with Negative Points")
+     print("=" * 60)
+
+     client = Client("furbola/chaskick")
+
+     annotations = [
+         {
+             "frame_idx": 0,
+             "object_id": 1,
+             "points": [
+                 [320, 240],  # ✅ Point ON the person's body
+                 [350, 250],  # ✅ Another point on the person
+                 [280, 220],  # ✅ Third point for better coverage
+                 [100, 100],  # ❌ Point on the BACKGROUND to exclude
+                 [600, 400]   # ❌ Another background point
+             ],
+             "labels": [
+                 1,  # foreground
+                 1,  # foreground
+                 1,  # foreground
+                 0,  # background (exclude this area)
+                 0   # background (exclude this area)
+             ]
+         }
+     ]
+
+     result = client.predict(
+         video_file="./person_video.mp4",
+         annotations_json=json.dumps(annotations),
+         remove_background=True,
+         max_frames=None,
+         api_name="/segment_video_api"
+     )
+
+     print(f"✅ Refined segmentation complete: {result}")
+
+
+ def example_4_temporal_annotations():
+     """
+     Example 4: Add annotations on multiple frames
+     Useful when: object changes appearance, camera cuts, occlusions
+     """
+     print("\n" + "=" * 60)
+     print("Example 4: Multi-Frame Annotations")
+     print("=" * 60)
+
+     client = Client("furbola/chaskick")
+
+     annotations = [
+         # Annotate frame 0
+         {
+             "frame_idx": 0,
+             "object_id": 1,
+             "points": [[320, 240]],
+             "labels": [1]
+         },
+         # Annotate frame 50 (object might have moved or changed)
+         {
+             "frame_idx": 50,
+             "object_id": 1,
+             "points": [[450, 300]],
+             "labels": [1]
+         },
+         # Annotate frame 100 (after a camera cut or scene change)
+         {
+             "frame_idx": 100,
+             "object_id": 1,
+             "points": [[200, 180]],
+             "labels": [1]
+         }
+     ]
+
+     result = client.predict(
+         video_file="./long_video.mp4",
+         annotations_json=json.dumps(annotations),
+         remove_background=True,
+         max_frames=None,
+         api_name="/segment_video_api"
+     )
+
+     print(f"✅ Multi-frame tracking complete: {result}")
+
+
+ def example_5_batch_processing():
+     """
+     Example 5: Process multiple videos in batch
+     """
+     print("\n" + "=" * 60)
+     print("Example 5: Batch Processing Multiple Videos")
+     print("=" * 60)
+
+     client = Client("furbola/chaskick")
+
+     # List of videos to process
+     videos = [
+         {"path": "./video1.mp4", "point": [320, 240]},
+         {"path": "./video2.mp4", "point": [400, 300]},
+         {"path": "./video3.mp4", "point": [250, 200]},
+     ]
+
+     results = []
+
+     for i, video in enumerate(videos, 1):
+         print(f"\nProcessing video {i}/{len(videos)}: {video['path']}")
+
+         annotations = [{
+             "frame_idx": 0,
+             "object_id": 1,
+             "points": [video['point']],
+             "labels": [1]
+         }]
+
+         try:
+             result = client.predict(
+                 video_file=video['path'],
+                 annotations_json=json.dumps(annotations),
+                 remove_background=True,
+                 max_frames=200,  # Limit frames for faster batch processing
+                 api_name="/segment_video_api"
+             )
+             results.append({"input": video['path'], "output": result, "status": "✅"})
+             print(f"  ✅ Success: {result}")
+         except Exception as e:
+             results.append({"input": video['path'], "output": None, "status": f"❌ {str(e)}"})
+             print(f"  ❌ Failed: {e}")
+
+     print("\n" + "=" * 60)
+     print("Batch Processing Summary:")
+     print("=" * 60)
+     for r in results:
+         print(f"{r['status']} {r['input']} -> {r['output']}")
+
+
+ def example_6_highlight_mode():
+     """
+     Example 6: Highlight objects instead of removing background
+     Useful for: visualization, debugging, object detection demos
+     """
+     print("\n" + "=" * 60)
+     print("Example 6: Highlight Mode (Keep Background)")
+     print("=" * 60)
+
+     client = Client("furbola/chaskick")
+
+     annotations = [{
+         "frame_idx": 0,
+         "object_id": 1,
+         "points": [[320, 240]],
+         "labels": [1]
+     }]
+
+     result = client.predict(
+         video_file="./input_video.mp4",
+         annotations_json=json.dumps(annotations),
+         remove_background=False,  # Keep background, just highlight the object
+         max_frames=None,
+         api_name="/segment_video_api"
+     )
+
+     print(f"✅ Object highlighted: {result}")
+
+
+ def example_7_find_coordinates():
+     """
+     Example 7: Helper to find coordinates in a video
+     Opens the first frame so you can identify x,y coordinates
+     """
+     print("\n" + "=" * 60)
+     print("Example 7: Find Coordinates Helper")
+     print("=" * 60)
+
+     import cv2
+
+     video_path = "./input_video.mp4"
+
+     # Read first frame
+     cap = cv2.VideoCapture(video_path)
+     ret, frame = cap.read()
+     cap.release()
+
+     if ret:
+         # Save first frame
+         cv2.imwrite("first_frame.jpg", frame)
+         print(f"✅ Saved first frame to: first_frame.jpg")
+         print(f"   Video size: {frame.shape[1]}x{frame.shape[0]} (width x height)")
+         print(f"   Open this image and note the x,y coordinates of your object")
+         print(f"   Then use those coordinates in your annotation!")
+     else:
+         print("❌ Could not read video")
+
+
+ # ============================================================================
+ # UTILITY FUNCTIONS
+ # ============================================================================
+
+ def create_annotation(frame_idx, object_id, points, labels=None):
+     """
+     Helper function to create annotation objects.
+
+     Args:
+         frame_idx: Frame number (0 = first frame)
+         object_id: Unique object ID (1, 2, 3, ...)
+         points: List of [x, y] coordinates, e.g., [[320, 240]]
+         labels: List of labels (1=foreground, 0=background). Defaults to all 1s.
+
+     Returns:
+         Dictionary with annotation
+     """
+     if labels is None:
+         labels = [1] * len(points)
+
+     return {
+         "frame_idx": frame_idx,
+         "object_id": object_id,
+         "points": points,
+         "labels": labels
+     }
+
+
+ def load_annotations_from_file(json_file):
+     """Load annotations from a JSON file."""
+     with open(json_file, 'r') as f:
+         return json.load(f)
+
+
+ def save_annotations_to_file(annotations, json_file):
+     """Save annotations to a JSON file."""
+     with open(json_file, 'w') as f:
+         json.dump(annotations, f, indent=2)
+
+
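+ # Illustrative usage of the helpers above (not executed by this script):
+ #
+ #     anns = [create_annotation(0, 1, [[320, 240]]),
+ #             create_annotation(0, 2, [[500, 300]])]
+ #     save_annotations_to_file(anns, "annotations.json")
+ #     anns = load_annotations_from_file("annotations.json")
+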
+ # ============================================================================
+ # MAIN
+ # ============================================================================
+
+ if __name__ == "__main__":
+     print("""
+     ╔════════════════════════════════════════════════════════════╗
+     ║        SAM2 Video Background Remover - API Examples          ║
+     ║        Choose an example to run or uncomment in the code     ║
+     ╚════════════════════════════════════════════════════════════╝
+     """)
+
+     # Uncomment the examples you want to run:
+
+     # example_1_simple_tracking()
+     # example_2_multi_object_tracking()
+     # example_3_refined_segmentation()
+     # example_4_temporal_annotations()
+     # example_5_batch_processing()
+     # example_6_highlight_mode()
+     # example_7_find_coordinates()
+
+     print("\n✅ Done! Check the output files.")
+     print("\n🎉 Your Space: https://huggingface.co/spaces/furbola/chaskick")
+
app.py ADDED
@@ -0,0 +1,540 @@
+ """
+ SAM2 Video Segmentation Space
+ Removes background from videos by tracking specified objects.
+ Provides both Gradio UI and API endpoints.
+ """
+
+ import gradio as gr
+ import torch
+ import numpy as np
+ import cv2
+ import tempfile
+ import os
+ from pathlib import Path
+ from typing import List, Tuple, Optional, Dict, Any
+ from transformers import Sam2VideoModel, Sam2VideoProcessor
+ from transformers.video_utils import load_video
+ from PIL import Image
+ import json
+
+ # Global model variables
+ MODEL_NAME = "facebook/sam2.1-hiera-tiny"  # Options: tiny, small, base-plus, large
+ device = None
+ model = None
+ processor = None
+
+
+ def initialize_model():
+     """Initialize SAM2 model and processor."""
+     global device, model, processor
+
+     # Determine device
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+         dtype = torch.float16
+     elif torch.backends.mps.is_available():
+         device = torch.device("mps")
+         dtype = torch.float32
+     else:
+         device = torch.device("cpu")
+         dtype = torch.float32
+
+     print(f"Loading SAM2 model on {device}...")
+
+     # Load model and processor
+     model = Sam2VideoModel.from_pretrained(MODEL_NAME).to(device, dtype=dtype)
+     processor = Sam2VideoProcessor.from_pretrained(MODEL_NAME)
+
+     print("Model loaded successfully!")
+     return device, model, processor
+
+
+ def extract_frames_from_video(video_path: str, max_frames: Optional[int] = None) -> Tuple[List[Image.Image], Dict]:
+     """Extract frames from video file."""
+     video_frames, info = load_video(video_path)
+
+     if max_frames and len(video_frames) > max_frames:
+         # Sample frames uniformly
+         indices = np.linspace(0, len(video_frames) - 1, max_frames, dtype=int)
+         video_frames = [video_frames[i] for i in indices]
+
+     return video_frames, info
+
+
+ def create_output_video(
+     video_frames: List[Image.Image],
+     masks: Dict[int, torch.Tensor],
+     output_path: str,
+     fps: float = 30.0,
+     remove_background: bool = True
+ ) -> str:
+     """
+     Create output video with segmented objects.
+
+     Args:
+         video_frames: Original video frames
+         masks: Dictionary mapping frame_idx to mask tensors
+         output_path: Path to save output video
+         fps: Frames per second
+         remove_background: If True, remove background; if False, highlight objects
+     """
+     if not masks:
+         raise ValueError("No masks provided")
+
+     # Get first frame to determine dimensions
+     first_frame = np.array(video_frames[0])
+     height, width = first_frame.shape[:2]
+
+     # Initialize video writer
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+     for frame_idx, frame_pil in enumerate(video_frames):
+         frame = np.array(frame_pil)
+
+         if frame_idx in masks:
+             mask = masks[frame_idx].cpu().numpy()
+
+             # Handle different mask shapes
+             if mask.ndim == 4:  # (batch, num_objects, H, W)
+                 mask = mask[0]  # Take first batch
+             if mask.ndim == 3:  # (num_objects, H, W)
+                 # Combine all object masks
+                 mask = mask.max(axis=0)
+
+             # Resize mask to frame size if needed
+             if mask.shape != (height, width):
+                 mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
+
+             # Convert to binary mask
+             mask_binary = (mask > 0.5).astype(np.uint8)
+
+             if remove_background:
+                 # Keep only the tracked objects (remove background)
+                 if frame.shape[2] == 3:  # RGB
+                     # Create RGBA with alpha channel (not written to the MP4, which has no alpha)
+                     result = np.zeros((height, width, 4), dtype=np.uint8)
+                     result[:, :, :3] = frame
+                     result[:, :, 3] = mask_binary * 255
+
+                     # Composite the masked frame onto a black background for the RGB output
+                     background = np.zeros_like(frame)
+                     mask_3d = np.repeat(mask_binary[:, :, np.newaxis], 3, axis=2)
+                     result_rgb = frame * mask_3d + background * (1 - mask_3d)
+                     frame = result_rgb.astype(np.uint8)
+             else:
+                 # Highlight tracked objects (overlay colored mask)
+                 overlay = frame.copy()
+                 overlay[mask_binary > 0] = [0, 255, 0]  # Green overlay
+                 frame = cv2.addWeighted(frame, 0.7, overlay, 0.3, 0)
+
+         # Convert RGB to BGR for OpenCV
+         frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+         out.write(frame_bgr)
+
+     out.release()
+     return output_path
+
+
+ def segment_video(
+     video_path: str,
+     annotations: List[Dict[str, Any]],
+     remove_background: bool = True,
+     max_frames: Optional[int] = None
+ ) -> str:
+     """
+     Main function to segment video based on annotations.
+
+     Args:
+         video_path: Path to input video
+         annotations: List of annotation dictionaries with format:
+             [
+                 {
+                     "frame_idx": 0,
+                     "object_id": 1,
+                     "points": [[x1, y1], [x2, y2], ...],
+                     "labels": [1, 1, ...]  # 1 for foreground, 0 for background
+                 },
+                 ...
+             ]
+         remove_background: If True, remove background; if False, highlight objects
+         max_frames: Maximum number of frames to process (None = all frames)
+
+     Returns:
+         Path to output video file
+     """
+     global device, model, processor
+
+     if model is None:
+         initialize_model()
+
+     # Load video frames
+     print("Loading video frames...")
+     video_frames, video_info = extract_frames_from_video(video_path, max_frames)
+     fps = video_info.get('fps', 30.0)
+
+     print(f"Processing {len(video_frames)} frames at {fps} FPS")
+
+     # Initialize inference session
+     dtype = torch.float16 if device.type == "cuda" else torch.float32
+     inference_session = processor.init_video_session(
+         video=video_frames,
+         inference_device=device,
+         dtype=dtype,
+     )
+
+     # Add annotations to inference session
+     print("Adding annotations...")
+     for ann in annotations:
+         frame_idx = ann["frame_idx"]
+         obj_id = ann["object_id"]
+         points = ann.get("points", [])
+         labels = ann.get("labels", [1] * len(points))
+
+         if points:
+             # Format points for processor: [[[[x, y], [x, y], ...]]]
+             formatted_points = [[points]]
+             formatted_labels = [[labels]]
+
+             processor.add_inputs_to_inference_session(
+                 inference_session=inference_session,
+                 frame_idx=frame_idx,
+                 obj_ids=obj_id,
+                 input_points=formatted_points,
+                 input_labels=formatted_labels,
+             )
+
+             # Run inference on this frame
+             outputs = model(
+                 inference_session=inference_session,
+                 frame_idx=frame_idx,
+             )
+
+     # Propagate through all frames
+     print("Propagating masks through video...")
+     video_segments = {}
+
+     for sam2_output in model.propagate_in_video_iterator(inference_session):
+         video_res_masks = processor.post_process_masks(
+             [sam2_output.pred_masks],
+             original_sizes=[[inference_session.video_height, inference_session.video_width]],
+             binarize=False
+         )[0]
+         video_segments[sam2_output.frame_idx] = video_res_masks
+
+     print(f"Generated masks for {len(video_segments)} frames")
+
+     # Create output video
+     output_path = tempfile.mktemp(suffix=".mp4")
+     print("Creating output video...")
+     create_output_video(video_frames, video_segments, output_path, fps, remove_background)
+
+     print(f"Output video saved to: {output_path}")
+     return output_path
+
+
+ # ============================================================================
+ # GRADIO INTERFACE
+ # ============================================================================
+
+ def gradio_segment_video(
+     video_file,
+     annotation_json: str,
+     remove_bg: bool = True,
+     max_frames: Optional[int] = None
+ ):
+     """
+     Gradio wrapper for video segmentation.
+
+     Args:
+         video_file: Uploaded video file
+         annotation_json: JSON string with annotations
+         remove_bg: Whether to remove background
+         max_frames: Maximum frames to process
+     """
+     try:
+         # Parse annotations
+         annotations = json.loads(annotation_json)
+
+         if not isinstance(annotations, list):
+             return None, "Error: Annotations must be a list of objects"
+
+         # Process video
+         output_path = segment_video(
+             video_path=video_file,
+             annotations=annotations,
+             remove_background=remove_bg,
+             max_frames=max_frames
+         )
+
+         return output_path, "✅ Video processed successfully!"
+
+     except json.JSONDecodeError as e:
+         return None, f"❌ JSON parsing error: {str(e)}"
+     except Exception as e:
+         return None, f"❌ Error: {str(e)}"
+
+
+ def gradio_simple_segment(
+     video_file,
+     point_x: int,
+     point_y: int,
+     frame_idx: int = 0,
+     remove_bg: bool = True,
+     max_frames: Optional[int] = 300
+ ):
+     """
+     Simple Gradio interface with single point annotation.
+     """
+     try:
+         # Create simple annotation
+         annotations = [{
+             "frame_idx": frame_idx,
+             "object_id": 1,
+             "points": [[point_x, point_y]],
+             "labels": [1]
+         }]
+
+         # Process video
+         output_path = segment_video(
+             video_path=video_file,
+             annotations=annotations,
+             remove_background=remove_bg,
+             max_frames=max_frames
+         )
+
+         return output_path, f"✅ Video processed! Tracked from point ({point_x}, {point_y}) on frame {frame_idx}"
+
+     except Exception as e:
+         return None, f"❌ Error: {str(e)}"
+
+
+ # ============================================================================
+ # API ENDPOINTS (via Gradio API)
+ # ============================================================================
+
+ def api_segment_video(video_file, annotations_json: str, remove_background: bool = True, max_frames: int = None):
+     """
+     API endpoint for video segmentation.
+     Can be called via gradio_client or direct HTTP requests.
+     """
+     annotations = json.loads(annotations_json)
+     output_path = segment_video(video_file, annotations, remove_background, max_frames)
+     return output_path
+
+
+ # ============================================================================
+ # CREATE GRADIO APP
+ # ============================================================================
+
+ def create_interface():
+     """Create the Gradio interface."""
+
+     # Initialize model
+     initialize_model()
+
+     # Create tabs for different interfaces
+     with gr.Blocks(title="SAM2 Video Segmentation - Remove Background") as app:
+         gr.Markdown("""
+ # 🎥 SAM2 Video Background Remover
+
+ Remove backgrounds from videos by tracking objects. Uses Meta's Segment Anything Model 2 (SAM2).
+
+ **Three ways to use this:**
+ 1. **Simple Mode**: Click on an object in the first frame
+ 2. **Advanced Mode**: Provide detailed JSON annotations
+ 3. **API Mode**: Use the API endpoint programmatically
+ """)
+
+         with gr.Tabs():
+             # ===================== SIMPLE MODE =====================
+             with gr.Tab("Simple Mode"):
+                 gr.Markdown("""
+ ### Quick Start
+ 1. Upload a video
+ 2. Specify the coordinates of the object you want to track
+ 3. Click "Process Video"
+
+ **Tip**: Open your video in an image viewer to find the x,y coordinates of your target object in the first frame.
+ """)
+
+                 with gr.Row():
+                     with gr.Column():
+                         simple_video_input = gr.Video(label="Upload Video")
+
+                         with gr.Row():
+                             point_x_input = gr.Number(label="Point X", value=320, precision=0)
+                             point_y_input = gr.Number(label="Point Y", value=240, precision=0)
+
+                         frame_idx_input = gr.Number(label="Frame Index", value=0, precision=0,
+                                                     info="Which frame to annotate (usually 0 for first frame)")
+
+                         remove_bg_simple = gr.Checkbox(label="Remove Background", value=True,
+                                                        info="If checked, removes background. If unchecked, highlights object.")
+
+                         max_frames_simple = gr.Number(label="Max Frames (optional)", value=300, precision=0,
+                                                       info="Limit frames for faster processing. Leave at 0 for all frames.")
+
+                         simple_btn = gr.Button("🎬 Process Video", variant="primary")
+
+                     with gr.Column():
+                         simple_output_video = gr.Video(label="Output Video")
+                         simple_status = gr.Textbox(label="Status", lines=3)
+
+                 simple_btn.click(
+                     fn=gradio_simple_segment,
+                     inputs=[simple_video_input, point_x_input, point_y_input, frame_idx_input,
+                             remove_bg_simple, max_frames_simple],
+                     outputs=[simple_output_video, simple_status]
+                 )
+
+                 gr.Markdown("""
+ ### Example:
+ For a 640x480 video with a person in the center, try: X=320, Y=240, Frame=0
+ """)
+
+             # ===================== ADVANCED MODE =====================
+             with gr.Tab("Advanced Mode (JSON)"):
+                 gr.Markdown("""
+ ### Advanced Annotations
+ Provide detailed JSON annotations for multiple objects and frames.
+
+ **JSON Format:**
+ ```json
+ [
+   {
+     "frame_idx": 0,
+     "object_id": 1,
+     "points": [[x1, y1], [x2, y2]],
+     "labels": [1, 1]
+   }
+ ]
+ ```
+
+ - `frame_idx`: Frame number to annotate
+ - `object_id`: Unique ID for each object (1, 2, 3, ...)
+ - `points`: List of [x, y] coordinates
+ - `labels`: 1 for foreground point, 0 for background point
+ """)
+
+                 with gr.Row():
+                     with gr.Column():
+                         adv_video_input = gr.Video(label="Upload Video")
+
+                         adv_annotation_input = gr.Textbox(
+                             label="Annotations (JSON)",
+                             lines=10,
+                             value='''[
+   {
+     "frame_idx": 0,
+     "object_id": 1,
+     "points": [[320, 240]],
+     "labels": [1]
+   }
+ ]''',
+                             placeholder="Enter JSON annotations here..."
+                         )
+
+                         remove_bg_adv = gr.Checkbox(label="Remove Background", value=True)
+                         max_frames_adv = gr.Number(label="Max Frames (0 = all)", value=0, precision=0)
+
+                         adv_btn = gr.Button("🎬 Process Video", variant="primary")
+
+                     with gr.Column():
+                         adv_output_video = gr.Video(label="Output Video")
+                         adv_status = gr.Textbox(label="Status", lines=3)
+
+                 adv_btn.click(
+                     fn=gradio_segment_video,
+                     inputs=[adv_video_input, adv_annotation_input, remove_bg_adv, max_frames_adv],
+                     outputs=[adv_output_video, adv_status]
+                 )
+
+             # ===================== API INFO =====================
+             with gr.Tab("API Documentation"):
+                 gr.Markdown("""
+ ## 📡 API Usage
+
+ This Space exposes an API that you can call programmatically.
+
+ ### Using Python with `gradio_client`
+
+ ```python
+ from gradio_client import Client
+ import json
+
+ # Connect to the Space
+ client = Client("YOUR_USERNAME/YOUR_SPACE_NAME")
+
+ # Define annotations
+ annotations = [
+     {
+         "frame_idx": 0,
+         "object_id": 1,
+         "points": [[320, 240]],
+         "labels": [1]
+     }
+ ]
+
+ # Call the API
+ result = client.predict(
+     video_file="path/to/video.mp4",
+     annotations_json=json.dumps(annotations),
+     remove_background=True,
+     max_frames=300,
+     api_name="/segment_video_api"
+ )
+
+ print(f"Output video: {result}")
+ ```
+
+ ### Using cURL
+
+ ```bash
+ curl -X POST https://YOUR_USERNAME-YOUR_SPACE_NAME.hf.space/api/predict \\
+   -H "Content-Type: application/json" \\
+   -F "data=@video.mp4" \\
+   -F 'annotations=[{"frame_idx":0,"object_id":1,"points":[[320,240]],"labels":[1]}]'
+ ```
+
+ ### Parameters
+
+ - **video_file**: Video file (required)
+ - **annotations_json**: JSON string with annotations (required)
+ - **remove_background**: Boolean (default: true)
+ - **max_frames**: Integer (default: null, processes all frames)
+
+ ### Response
+
+ Returns the path to the processed video file.
+ """)
+
+         # Add API endpoint
+         api_interface = gr.Interface(
+             fn=api_segment_video,
+             inputs=[
+                 gr.Video(label="Video File"),
+                 gr.Textbox(label="Annotations JSON"),
+                 gr.Checkbox(label="Remove Background", value=True),
+                 gr.Number(label="Max Frames", value=None, precision=0)
+             ],
+             outputs=gr.Video(label="Output Video"),
+             api_name="segment_video_api",
+             visible=False  # Hidden from UI, only accessible via API
+         )
+
+     return app
+
+
+ # ============================================================================
+ # LAUNCH
+ # ============================================================================
+
+ if __name__ == "__main__":
+     app = create_interface()
+     app.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=False
+     )
+
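For local, UI-free use, the core pipeline in app.py can also be imported directly. A minimal sketch, assuming the repository is cloned, its requirements are installed, the model weights can be downloaded, and `input_video.mp4` is a placeholder path:

```python
from app import segment_video

# Same annotation schema as the API: one foreground click on the first frame.
annotations = [{"frame_idx": 0, "object_id": 1, "points": [[320, 240]], "labels": [1]}]

output_path = segment_video(
    "input_video.mp4",      # placeholder path
    annotations,
    remove_background=True,
    max_frames=200,          # keep small for a quick test
)
print(f"Segmented video written to {output_path}")
```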
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ transformers>=4.57.0
+ torch>=2.0.0
+ gradio>=4.0.0
+ opencv-python-headless>=4.8.0
+ numpy>=1.24.0
+ Pillow>=10.0.0
+ accelerate>=0.20.0
+