Ryan-PR committed
Commit 103eee6 · verified · 1 parent: f77a896

Update app.py

Files changed (1)
app.py +25 -22
app.py CHANGED
@@ -341,7 +341,7 @@ def track_video(n_frames, video_state):
 
     images = [cv2.resize(img, (W_, H_)) for img in images]
     video_state["origin_images"] = images
-    images = np.array(images)
+    images_np = np.array(images)
 
     sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
     config = "sam2_hiera_l.yaml"
@@ -350,19 +350,19 @@ def track_video(n_frames, video_state):
     )
 
     inference_state = video_predictor_local.init_state(
-        images=images / 255, device="cuda"
+        images=images_np / 255, device="cuda"
     )
 
     if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
-        mask = torch.from_numpy(video_state["masks"][0])[:, :, 0]
+        mask0 = torch.from_numpy(video_state["masks"][0])[:, :, 0]
     else:
-        mask = torch.from_numpy(video_state["masks"][0])
+        mask0 = torch.from_numpy(video_state["masks"][0])
 
     video_predictor_local.add_new_mask(
         inference_state=inference_state,
         frame_idx=0,
         obj_id=obj_id,
-        mask=mask,
+        mask=mask0,
     )
 
     output_frames = []
@@ -375,7 +375,7 @@ def track_video(n_frames, video_state):
     for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor_local.propagate_in_video(
         inference_state
     ):
-        frame = images[out_frame_idx].astype(np.float32) / 255.0
+        frame = images_np[out_frame_idx].astype(np.float32) / 255.0
         mask = np.zeros((H, W, 3), dtype=np.float32)
         for i, logit in enumerate(out_mask_logits):
             out_mask = logit.cpu().squeeze().detach().numpy()
@@ -388,8 +388,6 @@ def track_video(n_frames, video_state):
         painted = np.uint8(np.clip(painted * 255, 0, 255))
         output_frames.append(painted)
 
-    video_state["masks"] = mask_frames
-
     video_file = f"/tmp/{time.time()}-{random.random()}-tracked_output.mp4"
     clip = ImageSequenceClip(output_frames, fps=15)
     clip.write_videofile(
@@ -404,7 +402,7 @@ def track_video(n_frames, video_state):
     except Exception as e:
         print("Error checking video file:", repr(e))
 
-    return video_file, video_state
+    return video_file, images, mask_frames
 
 
 @spaces.GPU(duration=150)
@@ -415,10 +413,11 @@ def inference_and_return_video(
     ref_patch_ratio,
     fg_threshold,
     seed,
-    video_state,
+    video_frames,
+    mask_frames,
     ref_state,
 ):
-    if video_state["origin_images"] is None or video_state["masks"] is None:
+    if video_frames is None or mask_frames is None:
        print("No video frames or video masks.")
        return None, None, None
 
@@ -426,11 +425,11 @@ def inference_and_return_video(
         print("Reference image or reference mask missing.")
         return None, None, None
 
-    images = video_state["origin_images"]
-    masks = video_state["masks"]
+    images = video_frames
+    masks = mask_frames
 
-    video_frames = []
-    mask_frames = []
+    video_frames_pil = []
+    mask_frames_pil = []
     for img, msk in zip(images, masks):
         if not isinstance(img, np.ndarray):
             img = np.asarray(img)
@@ -447,10 +446,10 @@ def inference_and_return_video(
         m2 = (m2 > 0.5).astype(np.uint8) * 255
         msk_pil = Image.fromarray(m2, mode="L")
 
-        video_frames.append(img_pil)
-        mask_frames.append(msk_pil)
+        video_frames_pil.append(img_pil)
+        mask_frames_pil.append(msk_pil)
 
-    num_frames = len(video_frames)
+    num_frames = len(video_frames_pil)
 
     h0, w0 = images[0].shape[:2]
     if h0 > w0:
@@ -470,8 +469,8 @@ def inference_and_return_video(
     pipe.to("cuda")
     with torch.no_grad():
         retex_frames, mesh_frames, ref_img_out = pipe(
-            video=video_frames,
-            mask=mask_frames,
+            video=video_frames_pil,
+            mask=mask_frames_pil,
             reference_image=ref_img_pil,
             reference_mask=ref_mask_pil,
             conditioning_scale=1.0,
@@ -608,6 +607,9 @@ with gr.Blocks() as demo:
         }
     )
 
+    video_frames_state = gr.State(None)
+    mask_frames_state = gr.State(None)
+
     gr.Markdown(f"<div style='text-align:center;'>{text}</div>")
 
     with gr.Column():
@@ -754,7 +756,8 @@ with gr.Blocks() as demo:
            ref_patch_slider,
            fg_threshold_slider,
            seed_slider,
-            video_state,
+            video_frames_state,
+            mask_frames_state,
            ref_state,
        ],
        outputs=[remove_video, mesh_video, ref_image_final],
@@ -777,7 +780,7 @@ with gr.Blocks() as demo:
    track_btn.click(
        track_video,
        inputs=[n_frames_slider, video_state],
-        outputs=[video_output, video_state],
+        outputs=[video_output, video_frames_state, mask_frames_state],
    )
 
    ref_image_input.change(
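For readers tracing the tracking-side changes: track_video now keeps the resized frames in a separate images_np array, seeds SAM2 with the first-frame mask, propagates it through the clip, and returns the frames and masks instead of writing them back into video_state. A minimal skeleton of that flow, written against the stock facebookresearch/sam2 API (the Space's patched predictor takes an in-memory images= array where stock init_state takes a video_path; the checkpoint path and the stand-in mask below are assumptions):

import torch
from sam2.build_sam import build_sam2_video_predictor

# Stock SAM2 video predictor; the config name matches the diff,
# the checkpoint path is assumed -- point it at your local download.
predictor = build_sam2_video_predictor(
    "sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt", device="cuda"
)
state = predictor.init_state(video_path="./frames_dir")  # directory of JPEG frames

# Seed tracking with a first-frame mask (a stand-in rectangle here; the app
# uses the mask drawn in the UI, squeezed to H x W as in the diff).
mask0 = torch.zeros(480, 640, dtype=torch.bool)
mask0[100:300, 200:400] = True
predictor.add_new_mask(inference_state=state, frame_idx=0, obj_id=1, mask=mask0)

# Propagate through the clip; each step yields per-object mask logits.
for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
    masks = (mask_logits > 0.0).cpu().numpy()  # logits -> binary masks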
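The painted frames are then written straight to an mp4 via moviepy, as in the unchanged context lines of the diff. The write step in isolation (a sketch; moviepy 2.x exports ImageSequenceClip at the top level, 1.x under moviepy.editor):

import numpy as np
from moviepy import ImageSequenceClip  # moviepy 1.x: from moviepy.editor import ImageSequenceClip

# 30 random RGB frames stand in for the painted tracking output.
frames = [np.random.randint(0, 255, (240, 320, 3), dtype=np.uint8) for _ in range(30)]
clip = ImageSequenceClip(frames, fps=15)         # 30 frames at 15 fps = 2 s clip
clip.write_videofile("/tmp/tracked_output.mp4")  # encoded via ffmpeg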
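On the inference side, the renamed video_frames_pil/mask_frames_pil lists are filled by the existing per-frame loop, which binarizes each float mask and wraps it as a single-channel PIL image. That conversion, self-contained (the random mask is a stand-in for a SAM2 output):

import numpy as np
from PIL import Image

m = np.random.rand(480, 640).astype(np.float32)  # stand-in float mask in [0, 1]
m2 = (m > 0.5).astype(np.uint8) * 255            # threshold, scale to {0, 255}
msk_pil = Image.fromarray(m2, mode="L")          # 8-bit grayscale ("L") mask
print(msk_pil.size, msk_pil.mode)                # (640, 480) L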
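The wiring change is the heart of the commit: video_state no longer ferries masks between callbacks. track_video returns the frames and masks into two new gr.State holders, and inference_and_return_video reads them directly. A minimal sketch of that hand-off, assuming only standard Gradio gr.State/click wiring (the callback bodies are stand-ins, not the Space's code):

import gradio as gr

def track(n_frames):
    # Stand-in for track_video: return a video path plus per-frame data.
    frames = [f"frame-{i}" for i in range(int(n_frames))]
    masks = [f"mask-{i}" for i in range(int(n_frames))]
    return "tracked.mp4", frames, masks

def inference(frames, masks):
    # Stand-in for inference_and_return_video: both states start as None.
    if frames is None or masks is None:
        return "No video frames or video masks."
    return f"processing {len(frames)} frames"

with gr.Blocks() as demo:
    video_frames_state = gr.State(None)  # holds frames between callbacks
    mask_frames_state = gr.State(None)   # holds masks between callbacks
    n = gr.Slider(1, 100, value=10, label="frames")
    out = gr.Textbox()
    track_btn = gr.Button("Track")
    run_btn = gr.Button("Run")
    track_btn.click(track, inputs=[n],
                    outputs=[out, video_frames_state, mask_frames_state])
    run_btn.click(inference, inputs=[video_frames_state, mask_frames_state],
                  outputs=[out])

demo.launch()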