Update app.py
app.py
@@ -341,7 +341,7 @@ def track_video(n_frames, video_state):

     images = [cv2.resize(img, (W_, H_)) for img in images]
     video_state["origin_images"] = images
-
+    images_np = np.array(images)

     sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
     config = "sam2_hiera_l.yaml"
@@ -350,19 +350,19 @@ def track_video(n_frames, video_state):
     )

     inference_state = video_predictor_local.init_state(
-        images=
+        images=images_np / 255, device="cuda"
     )

     if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
-
+        mask0 = torch.from_numpy(video_state["masks"][0])[:, :, 0]
     else:
-
+        mask0 = torch.from_numpy(video_state["masks"][0])

     video_predictor_local.add_new_mask(
         inference_state=inference_state,
         frame_idx=0,
         obj_id=obj_id,
-        mask=
+        mask=mask0,
     )

     output_frames = []
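In the hunk above, `init_state` is given the frame stack scaled to [0, 1] and `add_new_mask` is given a single-channel mask, so a three-channel first-frame mask is collapsed to one channel before seeding the tracker. A minimal sketch of just that preprocessing, with placeholder arrays standing in for the app's frames and mask:

```python
import numpy as np
import torch

# Placeholder inputs standing in for the app's data: a stack of uint8 RGB
# frames and the annotated mask for frame 0 (may be H x W or H x W x 3).
frames_u8 = np.zeros((8, 480, 640, 3), dtype=np.uint8)
mask_np = np.zeros((480, 640, 3), dtype=np.uint8)

# Frames scaled to [0, 1], as passed to init_state(images=...).
frames_01 = frames_u8 / 255

# Collapse a 3-channel mask to a single channel before add_new_mask(mask=...).
mask0 = torch.from_numpy(mask_np)
if mask0.ndim == 3:
    mask0 = mask0[:, :, 0]
```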
@@ -375,7 +375,7 @@ def track_video(n_frames, video_state):
     for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor_local.propagate_in_video(
         inference_state
     ):
-        frame =
+        frame = images_np[out_frame_idx].astype(np.float32) / 255.0
         mask = np.zeros((H, W, 3), dtype=np.float32)
         for i, logit in enumerate(out_mask_logits):
             out_mask = logit.cpu().squeeze().detach().numpy()
@@ -388,8 +388,6 @@ def track_video(n_frames, video_state):
         painted = np.uint8(np.clip(painted * 255, 0, 255))
         output_frames.append(painted)

-    video_state["masks"] = mask_frames
-
     video_file = f"/tmp/{time.time()}-{random.random()}-tracked_output.mp4"
     clip = ImageSequenceClip(output_frames, fps=15)
     clip.write_videofile(
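The unchanged lines in this hunk write the painted frames out with moviepy's `ImageSequenceClip`. A small self-contained sketch of that write path, assuming the classic moviepy 1.x import; the placeholder frames and output path are illustrative:

```python
import numpy as np
from moviepy.editor import ImageSequenceClip  # classic moviepy 1.x import path

# Placeholder frames: a list of H x W x 3 uint8 RGB arrays, like the painted frames above.
output_frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(30)]

clip = ImageSequenceClip(output_frames, fps=15)
clip.write_videofile("/tmp/tracked_output.mp4", codec="libx264", audio=False)
```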
@@ -404,7 +402,7 @@ def track_video(n_frames, video_state):
     except Exception as e:
         print("Error checking video file:", repr(e))

-    return video_file,
+    return video_file, images, mask_frames


 @spaces.GPU(duration=150)
@@ -415,10 +413,11 @@ def inference_and_return_video(
     ref_patch_ratio,
     fg_threshold,
     seed,
-
+    video_frames,
+    mask_frames,
     ref_state,
 ):
-    if
+    if video_frames is None or mask_frames is None:
         print("No video frames or video masks.")
         return None, None, None

@@ -426,11 +425,11 @@ def inference_and_return_video(
         print("Reference image or reference mask missing.")
         return None, None, None

-    images =
-    masks =
+    images = video_frames
+    masks = mask_frames

-
-
+    video_frames_pil = []
+    mask_frames_pil = []
     for img, msk in zip(images, masks):
         if not isinstance(img, np.ndarray):
             img = np.asarray(img)
@@ -447,10 +446,10 @@ def inference_and_return_video(
         m2 = (m2 > 0.5).astype(np.uint8) * 255
         msk_pil = Image.fromarray(m2, mode="L")

-
-
+        video_frames_pil.append(img_pil)
+        mask_frames_pil.append(msk_pil)

-    num_frames = len(
+    num_frames = len(video_frames_pil)

     h0, w0 = images[0].shape[:2]
     if h0 > w0:
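This hunk collects the per-frame conversions into the two new PIL lists. For reference, the conversion pairs an RGB frame with a binarized single-channel mask; a standalone sketch with placeholder arrays in place of the app's data:

```python
import numpy as np
from PIL import Image

# Placeholder frame/mask pair standing in for one element of images/masks.
frame = np.zeros((480, 640, 3), dtype=np.uint8)
soft_mask = np.zeros((480, 640), dtype=np.float32)  # values in [0, 1]

img_pil = Image.fromarray(frame)                  # RGB frame as a PIL image
m2 = (soft_mask > 0.5).astype(np.uint8) * 255     # binarize to {0, 255}
msk_pil = Image.fromarray(m2, mode="L")           # single-channel PIL mask

video_frames_pil = [img_pil]
mask_frames_pil = [msk_pil]
```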
@@ -470,8 +469,8 @@ def inference_and_return_video(
     pipe.to("cuda")
     with torch.no_grad():
         retex_frames, mesh_frames, ref_img_out = pipe(
-            video=
-            mask=
+            video=video_frames_pil,
+            mask=mask_frames_pil,
             reference_image=ref_img_pil,
             reference_mask=ref_mask_pil,
             conditioning_scale=1.0,
@@ -608,6 +607,9 @@ with gr.Blocks() as demo:
         }
     )

+    video_frames_state = gr.State(None)
+    mask_frames_state = gr.State(None)
+
     gr.Markdown(f"<div style='text-align:center;'>{text}</div>")

     with gr.Column():
@@ -754,7 +756,8 @@ with gr.Blocks() as demo:
             ref_patch_slider,
             fg_threshold_slider,
             seed_slider,
-
+            video_frames_state,
+            mask_frames_state,
             ref_state,
         ],
         outputs=[remove_video, mesh_video, ref_image_final],
@@ -777,7 +780,7 @@ with gr.Blocks() as demo:
     track_btn.click(
         track_video,
         inputs=[n_frames_slider, video_state],
-        outputs=[video_output,
+        outputs=[video_output, video_frames_state, mask_frames_state],
     )

     ref_image_input.change(
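Taken together, the UI hunks replace the old implicit hand-off (stashing the tracked masks back into `video_state`) with two explicit `gr.State` holders: `track_video` now returns the tracked frames and masks as extra outputs, and `inference_and_return_video` receives them as extra inputs. A minimal self-contained sketch of that Gradio pattern, with illustrative names in place of the app's components and callbacks:

```python
import gradio as gr

def produce(n):
    # Stand-in for track_video: return something to display plus data to stash in State.
    frames = list(range(int(n)))
    masks = [f"mask-{i}" for i in frames]
    return f"tracked {len(frames)} frames", frames, masks

def consume(frames, masks):
    # Stand-in for inference_and_return_video: read the stashed data back.
    if frames is None or masks is None:
        return "No video frames or video masks."
    return f"processing {len(frames)} frames / {len(masks)} masks"

with gr.Blocks() as demo:
    frames_state = gr.State(None)  # analogous to video_frames_state
    masks_state = gr.State(None)   # analogous to mask_frames_state

    n_frames = gr.Number(value=8, label="n_frames")
    status = gr.Textbox(label="status")
    result = gr.Textbox(label="result")
    track_btn = gr.Button("Track")
    run_btn = gr.Button("Run")

    # The first callback fills the State holders; the second consumes them.
    track_btn.click(produce, inputs=[n_frames], outputs=[status, frames_state, masks_state])
    run_btn.click(consume, inputs=[frames_state, masks_state], outputs=[result])

if __name__ == "__main__":
    demo.launch()
```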