prithivMLmods commited on
Commit
ec6fe6f
·
verified ·
1 Parent(s): c234825

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -5
app.py CHANGED
@@ -17,7 +17,6 @@ from transformers import (
17
  Sam3VideoModel, Sam3VideoProcessor
18
  )
19
 
20
- # --- THEME CONFIGURATION ---
21
  colors.steel_blue = colors.Color(
22
  name="steel_blue",
23
  c50="#EBF3F8",
@@ -80,7 +79,6 @@ class CustomBlueTheme(Soft):
80
 
81
  app_theme = CustomBlueTheme()
82
 
83
- # --- MODEL MANAGEMENT & UTILS ---
84
  MODEL_CACHE = {}
85
  device = "cuda" if torch.cuda.is_available() else "cpu"
86
  print(f"Using compute device: {device}")
@@ -103,7 +101,6 @@ def load_segmentation_model(model_key):
103
 
104
  try:
105
  if model_key == "img_seg_model":
106
- # Using generic internal names
107
  seg_model = Sam3Model.from_pretrained("facebook/sam3").to(device)
108
  seg_processor = Sam3Processor.from_pretrained("facebook/sam3")
109
  MODEL_CACHE[model_key] = (seg_model, seg_processor)
@@ -185,7 +182,8 @@ def run_image_segmentation(source_img, text_query, conf_thresh=0.5):
185
  mask_threshold=0.5,
186
  target_sizes=model_inputs.get("original_sizes").tolist()
187
  )[0]
188
-
 
189
  annotation_list = []
190
  raw_masks = processed_results['masks'].cpu().numpy()
191
  raw_scores = processed_results['scores'].cpu().numpy()
@@ -250,7 +248,6 @@ def run_video_segmentation(source_vid, text_query, frame_limit, time_limit):
250
  except Exception as e:
251
  return None, f"Error during video processing: {str(e)}"
252
 
253
- # --- GUI ---
254
  custom_css="""
255
  #col-container { margin: 0 auto; max-width: 1100px; }
256
  #main-title h1 { font-size: 2.1em !important; }
@@ -274,6 +271,17 @@ with gr.Blocks(css=custom_css, theme=app_theme) as main_interface:
274
  with gr.Column(scale=1.5):
275
  image_result = gr.AnnotatedImage(label="Segmented Result", height=450)
276
 
 
 
 
 
 
 
 
 
 
 
 
277
  btn_process_img.click(
278
  fn=run_image_segmentation,
279
  inputs=[image_input, txt_prompt_img, conf_slider],
@@ -296,6 +304,17 @@ with gr.Blocks(css=custom_css, theme=app_theme) as main_interface:
296
  video_result = gr.Video(label="Processed Video")
297
  process_status = gr.Textbox(label="System Status", interactive=False)
298
 
 
 
 
 
 
 
 
 
 
 
 
299
  btn_process_vid.click(
300
  run_video_segmentation,
301
  inputs=[video_input, txt_prompt_vid, frame_limiter, time_limiter],
 
17
  Sam3VideoModel, Sam3VideoProcessor
18
  )
19
 
 
20
  colors.steel_blue = colors.Color(
21
  name="steel_blue",
22
  c50="#EBF3F8",
 
79
 
80
  app_theme = CustomBlueTheme()
81
 
 
82
  MODEL_CACHE = {}
83
  device = "cuda" if torch.cuda.is_available() else "cpu"
84
  print(f"Using compute device: {device}")
 
101
 
102
  try:
103
  if model_key == "img_seg_model":
 
104
  seg_model = Sam3Model.from_pretrained("facebook/sam3").to(device)
105
  seg_processor = Sam3Processor.from_pretrained("facebook/sam3")
106
  MODEL_CACHE[model_key] = (seg_model, seg_processor)
 
182
  mask_threshold=0.5,
183
  target_sizes=model_inputs.get("original_sizes").tolist()
184
  )[0]
185
+
186
+ # Use AnnotatedImage format
187
  annotation_list = []
188
  raw_masks = processed_results['masks'].cpu().numpy()
189
  raw_scores = processed_results['scores'].cpu().numpy()
 
248
  except Exception as e:
249
  return None, f"Error during video processing: {str(e)}"
250
 
 
251
  custom_css="""
252
  #col-container { margin: 0 auto; max-width: 1100px; }
253
  #main-title h1 { font-size: 2.1em !important; }
 
271
  with gr.Column(scale=1.5):
272
  image_result = gr.AnnotatedImage(label="Segmented Result", height=450)
273
 
274
+ gr.Examples(
275
+ examples=[
276
+ ["examples/player.jpg", "player in white", 0.5],
277
+ ],
278
+ inputs=[image_input, txt_prompt_img, conf_slider],
279
+ outputs=[image_result],
280
+ fn=run_image_segmentation,
281
+ cache_examples=False,
282
+ label="Image Examples"
283
+ )
284
+
285
  btn_process_img.click(
286
  fn=run_image_segmentation,
287
  inputs=[image_input, txt_prompt_img, conf_slider],
 
304
  video_result = gr.Video(label="Processed Video")
305
  process_status = gr.Textbox(label="System Status", interactive=False)
306
 
307
+ gr.Examples(
308
+ examples=[
309
+ ["examples/sample_video.mp4", "ball", 60, 60],
310
+ ],
311
+ inputs=[video_input, txt_prompt_vid, frame_limiter, time_limiter],
312
+ outputs=[video_result, process_status],
313
+ fn=run_video_segmentation,
314
+ cache_examples=False,
315
+ label="Video Examples"
316
+ )
317
+
318
  btn_process_vid.click(
319
  run_video_segmentation,
320
  inputs=[video_input, txt_prompt_vid, frame_limiter, time_limiter],