# Refacade / app.py
import os
import time
import random
import gradio as gr
import cv2
import numpy as np
from PIL import Image
import subprocess
import importlib
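# Dependency bootstrap: the Space may start without wan installed or with a
# flash-attn wheel that fails to import, so install/repair both before the
# model imports further below.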
def ensure_wan():
try:
import wan # noqa
print("[setup] wan already installed.")
except ImportError:
cmd = "pip install --no-build-isolation 'wan@git+https://github.com/Wan-Video/Wan2.1'"
env = dict(os.environ)
print(f"[setup] Installing wan2.1: {cmd}")
subprocess.run(cmd, shell=True, check=True, env=env)
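# flash-attn wheels occasionally fail to import against the runtime's
# torch/CUDA build; probe the import and rebuild from source if it is broken.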
def ensure_flash_attn():
try:
import flash_attn # noqa
from flash_attn.flash_attn_interface import flash_attn_func # type: ignore
print("[setup] flash-attn seems OK.")
return
except Exception as e:
print("[setup] flash-attn broken, will rebuild from source:", repr(e))
cmd = (
"pip uninstall -y flash-attn flash_attn || true && "
"pip install flash-attn==2.7.2.post1 --no-build-isolation"
)
print(f"[setup] Rebuilding flash-attn: {cmd}")
subprocess.run(cmd, shell=True, check=True)
importlib.invalidate_caches()
ensure_flash_attn()
ensure_wan()
os.makedirs("./sam2/SAM2-Video-Predictor/checkpoints/", exist_ok=True)
from huggingface_hub import snapshot_download
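# Download the SAM2 and Refacade checkpoints from the Hugging Face Hub into
# the local paths expected by the loaders below.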
def download_sam2():
snapshot_download(
repo_id="facebook/sam2-hiera-large",
local_dir="./sam2/SAM2-Video-Predictor/checkpoints/",
)
print("Download sam2 completed")
def download_refacade():
snapshot_download(
repo_id="fishze/Refacade",
local_dir="./models/",
)
print("Download refacade completed")
download_sam2()
download_refacade()
import torch
import torch.nn.functional as F
from decord import VideoReader, cpu
from moviepy.editor import ImageSequenceClip
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
import spaces
from pipeline import RefacadePipeline
from vace.models.wan.modules.model_mm import VaceMMModel
from vace.models.wan.modules.model_tr import VaceWanModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from wan.text2video import FlowUniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image, load_video
from vae import WanVAE
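# Global settings: overlay colors for mask visualization, the working
# resolution for the segmentation UI, and device placement for the diffusion
# pipeline (CUDA) and the startup SAM2 predictors (CPU).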
COLOR_PALETTE = [
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(255, 255, 0),
(255, 0, 255),
(0, 255, 255),
(255, 128, 0),
(128, 0, 255),
(0, 128, 255),
(128, 255, 0),
]
video_length = 81
W = 1024
H = W
device = "cuda"
sam_device = "cpu"
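# Load the Wan VAE, the texture remover, and the Refacade transformer from the
# downloaded checkpoints, wrap them in a RefacadePipeline, and build the SAM2
# image/video predictors used for interactive segmentation and tracking.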
def get_pipe_image_and_video_predictor():
vae = WanVAE(
vae_pth="./models/vae/Wan2.1_VAE.pth",
dtype=torch.float16,
)
pipe_device = "cuda"
texture_remover = VaceWanModel.from_config(
"./models/texture_remover/texture_remover.json"
)
ckpt = torch.load(
"./models/texture_remover/texture_remover.pth",
map_location="cpu",
)
texture_remover.load_state_dict(ckpt)
texture_remover = texture_remover.to(dtype=torch.float16, device=pipe_device)
model = VaceMMModel.from_config(
"./models/refacade/refacade.json"
)
ckpt = torch.load(
"./models/refacade/refacade.pth",
map_location="cpu",
)
model.load_state_dict(ckpt)
model = model.to(dtype=torch.float16, device=pipe_device)
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=1000,
shift=1,
)
pipe = RefacadePipeline(
vae=vae,
transformer=model,
texture_remover=texture_remover,
scheduler=sample_scheduler,
)
pipe.to(pipe_device)
sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
config = "sam2_hiera_l.yaml"
video_predictor = build_sam2_video_predictor(config, sam2_checkpoint, device=sam_device)
model_sam = build_sam2(config, sam2_checkpoint, device=sam_device)
model_sam.image_size = 1024
image_predictor = SAM2ImagePredictor(sam_model=model_sam)
return pipe, image_predictor, video_predictor
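# Read the first frame of the uploaded video, resize it so its shorter side
# matches the working resolution, and reset all click/mask state for a new
# session.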
def get_video_info(video_path, video_state):
video_state["input_points"] = []
video_state["scaled_points"] = []
video_state["input_labels"] = []
video_state["frame_idx"] = 0
vr = VideoReader(video_path, ctx=cpu(0))
first_frame = vr[0].asnumpy()
del vr
if first_frame.shape[0] > first_frame.shape[1]:
W_ = W
H_ = int(W_ * first_frame.shape[0] / first_frame.shape[1])
else:
H_ = H
W_ = int(H_ * first_frame.shape[1] / first_frame.shape[0])
first_frame = cv2.resize(first_frame, (W_, H_))
video_state["origin_images"] = np.expand_dims(first_frame, axis=0)
video_state["inference_state"] = None
video_state["video_path"] = video_path
video_state["masks"] = None
video_state["painted_images"] = None
image = Image.fromarray(first_frame)
return image
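# Handle a click on the first frame: run SAM2 image segmentation on the
# accumulated positive/negative points and return the frame with the mask and
# click markers painted on.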
def segment_frame(evt: gr.SelectData, label, video_state):
if video_state["origin_images"] is None:
return None
x, y = evt.index
new_point = [x, y]
label_value = 1 if label == "Positive" else 0
video_state["input_points"].append(new_point)
video_state["input_labels"].append(label_value)
height, width = video_state["origin_images"][0].shape[0:2]
scaled_points = []
for pt in video_state["input_points"]:
sx = pt[0] / width
sy = pt[1] / height
scaled_points.append([sx, sy])
video_state["scaled_points"] = scaled_points
image_predictor.set_image(video_state["origin_images"][0])
mask, _, _ = image_predictor.predict(
point_coords=video_state["scaled_points"],
point_labels=video_state["input_labels"],
multimask_output=False,
normalize_coords=False,
)
mask = np.squeeze(mask)
mask = cv2.resize(mask, (width, height))
mask = mask[:, :, None]
color = (
np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32)
/ 255.0
)
color = color[None, None, :]
org_image = video_state["origin_images"][0].astype(np.float32) / 255.0
painted_image = (1 - mask * 0.5) * org_image + mask * 0.5 * color
painted_image = np.uint8(np.clip(painted_image * 255, 0, 255))
video_state["painted_images"] = np.expand_dims(painted_image, axis=0)
video_state["masks"] = np.expand_dims(mask[:, :, 0], axis=0)
for i in range(len(video_state["input_points"])):
point = video_state["input_points"][i]
if video_state["input_labels"][i] == 0:
cv2.circle(painted_image, point, radius=3, color=(0, 0, 255), thickness=-1)
else:
cv2.circle(painted_image, point, radius=3, color=(255, 0, 0), thickness=-1)
return Image.fromarray(painted_image)
def clear_clicks(video_state):
video_state["input_points"] = []
video_state["input_labels"] = []
video_state["scaled_points"] = []
video_state["inference_state"] = None
video_state["masks"] = None
video_state["painted_images"] = None
return (
Image.fromarray(video_state["origin_images"][0])
if video_state["origin_images"] is not None
else None
)
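# Store the uploaded reference image and reset its click/mask state.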
def set_ref_image(ref_img, ref_state):
if ref_img is None:
return None
if isinstance(ref_img, Image.Image):
img_np = np.array(ref_img)
else:
img_np = ref_img
ref_state["origin_image"] = img_np
ref_state["input_points"] = []
ref_state["input_labels"] = []
ref_state["scaled_points"] = []
ref_state["mask"] = None
return Image.fromarray(img_np)
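# Same click-to-segment flow as segment_frame, but applied to the reference
# image.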
def segment_ref_frame(evt: gr.SelectData, label, ref_state):
if ref_state["origin_image"] is None:
return None
x, y = evt.index
new_point = [x, y]
label_value = 1 if label == "Positive" else 0
ref_state["input_points"].append(new_point)
ref_state["input_labels"].append(label_value)
img = ref_state["origin_image"]
h, w = img.shape[:2]
scaled_points = []
for pt in ref_state["input_points"]:
sx = pt[0] / w
sy = pt[1] / h
scaled_points.append([sx, sy])
ref_state["scaled_points"] = scaled_points
image_predictor.set_image(img)
mask, _, _ = image_predictor.predict(
point_coords=scaled_points,
point_labels=ref_state["input_labels"],
multimask_output=False,
normalize_coords=False,
)
mask = np.squeeze(mask)
mask = cv2.resize(mask, (w, h))
mask = mask[:, :, None]
ref_state["mask"] = mask[:, :, 0]
color = (
np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32)
/ 255.0
)
color = color[None, None, :]
org_image = img.astype(np.float32) / 255.0
painted = (1 - mask * 0.5) * org_image + mask * 0.5 * color
painted = np.uint8(np.clip(painted * 255, 0, 255))
for i in range(len(ref_state["input_points"])):
point = ref_state["input_points"][i]
if ref_state["input_labels"][i] == 0:
cv2.circle(painted, point, radius=3, color=(0, 0, 255), thickness=-1)
else:
cv2.circle(painted, point, radius=3, color=(255, 0, 0), thickness=-1)
return Image.fromarray(painted)
def clear_ref_clicks(ref_state):
ref_state["input_points"] = []
ref_state["input_labels"] = []
ref_state["scaled_points"] = []
ref_state["mask"] = None
if ref_state["origin_image"] is None:
return None
return Image.fromarray(ref_state["origin_image"])
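# Propagate the first-frame mask through up to n_frames frames with a SAM2
# video predictor rebuilt on CUDA inside the @spaces.GPU window, then write a
# preview video with the tracked mask overlaid.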
@spaces.GPU(duration=40)
@torch.no_grad()
def track_video(n_frames, video_state):
input_points = video_state["input_points"]
input_labels = video_state["input_labels"]
frame_idx = video_state["frame_idx"]
obj_id = video_state["obj_id"]
scaled_points = video_state["scaled_points"]
vr = VideoReader(video_state["video_path"], ctx=cpu(0))
height, width = vr[0].shape[0:2]
images = [vr[i].asnumpy() for i in range(min(len(vr), n_frames))]
del vr
if images[0].shape[0] > images[0].shape[1]:
W_ = W
H_ = int(W_ * images[0].shape[0] / images[0].shape[1])
else:
H_ = H
W_ = int(H_ * images[0].shape[1] / images[0].shape[0])
images = [cv2.resize(img, (W_, H_)) for img in images]
video_state["origin_images"] = images
images_np = np.array(images)
sam2_checkpoint = "./sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
config = "sam2_hiera_l.yaml"
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
video_predictor_local = build_sam2_video_predictor(
config, sam2_checkpoint, device="cuda"
)
inference_state = video_predictor_local.init_state(
images=images_np / 255, device="cuda"
)
if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
mask0 = torch.from_numpy(video_state["masks"][0])[:, :, 0]
else:
mask0 = torch.from_numpy(video_state["masks"][0])
video_predictor_local.add_new_mask(
inference_state=inference_state,
frame_idx=0,
obj_id=obj_id,
mask=mask0,
)
output_frames = []
mask_frames = []
color = (
np.array(
COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)],
dtype=np.float32,
)
/ 255.0
)
color = color[None, None, :]
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor_local.propagate_in_video(
inference_state
):
frame = images_np[out_frame_idx].astype(np.float32) / 255.0
mask = np.zeros((H, W, 3), dtype=np.float32)
for i, logit in enumerate(out_mask_logits):
out_mask = logit.cpu().squeeze().detach().numpy()
out_mask = (out_mask[:, :, None] > 0).astype(np.float32)
mask += out_mask
mask = np.clip(mask, 0, 1)
mask = cv2.resize(mask, (W_, H_))
mask_frames.append(mask)
painted = (1 - mask * 0.5) * frame + mask * 0.5 * color
painted = np.uint8(np.clip(painted * 255, 0, 255))
output_frames.append(painted)
video_file = f"/tmp/{time.time()}-{random.random()}-tracked_output.mp4"
clip = ImageSequenceClip(output_frames, fps=15)
clip.write_videofile(
video_file, codec="libx264", audio=False, verbose=False, logger=None
)
print("Tracking done")
print("Tracking done, file:", video_file)
try:
exists = os.path.exists(video_file)
size = os.path.getsize(video_file) if exists else -1
print("File exists?", exists, "size:", size)
except Exception as e:
print("Error checking video file:", repr(e))
return video_file, images, mask_frames
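# Run the Refacade pipeline on the tracked frames and masks together with the
# reference image/mask, returning the retextured video, the untextured-object
# video, and the jigsawed reference image shown in the UI.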
@spaces.GPU(duration=100)
def inference_and_return_video(
dilate_radius,
num_inference_steps,
guidance_scale,
ref_patch_ratio,
fg_threshold,
seed,
video_frames,
mask_frames,
ref_state,
):
if video_frames is None or mask_frames is None:
print("No video frames or video masks.")
return None, None, None
if ref_state["origin_image"] is None or ref_state["mask"] is None:
print("Reference image or reference mask missing.")
return None, None, None
images = video_frames
masks = mask_frames
video_frames_pil = []
mask_frames_pil = []
for img, msk in zip(images, masks):
if not isinstance(img, np.ndarray):
img = np.asarray(img)
img_pil = Image.fromarray(img.astype(np.uint8))
if isinstance(msk, np.ndarray):
if msk.ndim == 3:
m2 = msk[..., 0]
else:
m2 = msk
else:
m2 = np.asarray(msk)
m2 = (m2 > 0.5).astype(np.uint8) * 255
msk_pil = Image.fromarray(m2, mode="L")
video_frames_pil.append(img_pil)
mask_frames_pil.append(msk_pil)
num_frames = len(video_frames_pil)
h0, w0 = images[0].shape[:2]
if h0 > w0:
height = 832
width = 480
else:
height = 480
width = 832
ref_img_np = ref_state["origin_image"]
ref_mask_np = ref_state["mask"]
ref_img_pil = Image.fromarray(ref_img_np.astype(np.uint8))
ref_mask_bin = (ref_mask_np > 0.5).astype(np.uint8) * 255
ref_mask_pil = Image.fromarray(ref_mask_bin, mode="L")
pipe.to("cuda")
with torch.no_grad():
retex_frames, mesh_frames, ref_img_out = pipe(
video=video_frames_pil,
mask=mask_frames_pil,
reference_image=ref_img_pil,
reference_mask=ref_mask_pil,
conditioning_scale=1.0,
height=height,
width=width,
num_frames=num_frames,
dilate_radius=int(dilate_radius),
num_inference_steps=int(num_inference_steps),
guidance_scale=float(guidance_scale),
reference_patch_ratio=float(ref_patch_ratio),
fg_thresh=float(fg_threshold),
generator=torch.Generator(device="cuda").manual_seed(int(seed)),
return_dict=False,
)
retex_frames_uint8 = (np.clip(retex_frames[0], 0.0, 1.0) * 255).astype(np.uint8)
mesh_frames_uint8 = (np.clip(mesh_frames[0], 0.0, 1.0) * 255).astype(np.uint8)
retex_output_frames = list(retex_frames_uint8)
mesh_output_frames = list(mesh_frames_uint8)
if ref_img_out.dtype != np.uint8:
ref_img_out = (np.clip(ref_img_out, 0.0, 1.0) * 255).astype(np.uint8)
retex_video_file = f"/tmp/{time.time()}-{random.random()}-refacade_output.mp4"
retex_clip = ImageSequenceClip(retex_output_frames, fps=16)
retex_clip.write_videofile(
retex_video_file, codec="libx264", audio=False, verbose=False, logger=None
)
mesh_video_file = f"/tmp/{time.time()}-{random.random()}-mesh_output.mp4"
mesh_clip = ImageSequenceClip(mesh_output_frames, fps=16)
mesh_clip.write_videofile(
mesh_video_file, codec="libx264", audio=False, verbose=False, logger=None
)
ref_image_to_show = ref_img_out
return retex_video_file, mesh_video_file, ref_image_to_show
text = """
<div style='text-align:center; font-size:32px; font-family: Arial, Helvetica, sans-serif;'>
Refaçade Video Retexture Demo
</div>
<div style="display: flex; justify-content: center; align-items: center; gap: 10px; flex-wrap: nowrap;">
<a href="https://huggingface.co/fishze/Refacade"><img alt="Huggingface Model" src="https://img.shields.io/badge/%F0%9F%A4%97%20Huggingface-Model-brightgreen"></a>
<a href="https://github.com/fishZe233/Refacade"><img alt="Github" src="https://img.shields.io/badge/Refaçade-github-black"></a>
<a href="https://arxiv.org/abs/2512.04534"><img alt="arXiv" src="https://img.shields.io/badge/Refaçade-arXiv-b31b1b"></a>
<a href="https://refacade.github.io/"><img alt="Demo Page" src="https://img.shields.io/badge/Website-Demo%20Page-yellow"></a>
</div>
<div style='text-align:center; font-size:20px; margin-top: 10px; font-family: Arial, Helvetica, sans-serif;'>
Youze Huang<sup>*</sup>, Penghui Ruan<sup>*</sup>, Bojia Zi<sup>*</sup>, Xianbiao Qi<sup>†</sup>, Jianan Wang, Rong Xiao
</div>
<div style='text-align:center; font-size:14px; color: #888; margin-top: 5px; font-family: Arial, Helvetica, sans-serif;'>
<sup>*</sup> Equal contribution &nbsp; &nbsp; <sup>†</sup> Corresponding author
</div>
"""
pipe, image_predictor, video_predictor = get_pipe_image_and_video_predictor()
css = """
#my-btn {
width: 60% !important;
margin: 0 auto;
}
#my-video1 {
width: 60% !important;
height: 35% !important;
margin: 0 auto;
}
#my-video {
width: 60% !important;
height: 35% !important;
margin: 0 auto;
}
#my-md {
margin: 0 auto;
}
#my-btn2 {
width: 60% !important;
margin: 0 auto;
}
#my-btn2 button {
width: 120px !important;
max-width: 120px !important;
min-width: 120px !important;
height: 70px !important;
max-height: 70px !important;
min-height: 70px !important;
margin: 8px !important;
border-radius: 8px !important;
overflow: hidden !important;
white-space: normal !important;
}
#my-btn3 {
width: 60% !important;
margin: 0 auto;
}
#ref_title {
text-align: center;
}
#ref-image {
width: 60% !important;
height: 35% !important;
margin: 0 auto;
}
#ref-mask {
width: 60% !important;
height: 35% !important;
margin: 0 auto;
}
#mesh-row {
width: 60% !important;
margin: 0 auto;
}
"""
with gr.Blocks() as demo:
gr.HTML(f"<style>{css}</style>")
video_state = gr.State(
{
"origin_images": None,
"inference_state": None,
"masks": None,
"painted_images": None,
"video_path": None,
"input_points": [],
"scaled_points": [],
"input_labels": [],
"frame_idx": 0,
"obj_id": 1,
}
)
ref_state = gr.State(
{
"origin_image": None,
"input_points": [],
"input_labels": [],
"scaled_points": [],
"mask": None,
}
)
video_frames_state = gr.State(None)
mask_frames_state = gr.State(None)
gr.Markdown(f"<div style='text-align:center;'>{text}</div>")
with gr.Column():
gr.Markdown("Step1: Upload a Source Video", elem_id="ref_title")
video_input = gr.Video(label="Upload Video", elem_id="my-video1")
gr.Examples(
examples=[
["./examples/1.mp4"],
["./examples/2.mp4"],
["./examples/3.mp4"],
["./examples/4.mp4"],
["./examples/5.mp4"],
["./examples/6.mp4"],
],
inputs=[video_input],
label="You can upload or choose a source video below to retexture.",
elem_id="my-btn2"
)
gr.Markdown("Step2: Extract the First Frame & Click for Segmentation", elem_id="ref_title")
get_info_btn = gr.Button("Extract First Frame", elem_id="my-btn")
image_output = gr.Image(
label="First Frame Segmentation",
interactive=True,
elem_id="my-video",
)
with gr.Row(elem_id="my-btn"):
point_prompt = gr.Radio(
["Positive", "Negative"], label="Click Type", value="Positive"
)
clear_btn = gr.Button("Clear All Clicks")
gr.Markdown("Step3: Track to Get Video Mask", elem_id="ref_title")
with gr.Row(elem_id="my-btn"):
n_frames_slider = gr.Slider(
minimum=1, maximum=81, value=33, step=1, label="Tracking Frames (4N+1)"
)
track_btn = gr.Button("Track")
video_output = gr.Video(label="Tracking Result", elem_id="my-video")
gr.Markdown("Step4: Upload a Reference Image & Click for Reference Segmentation", elem_id="ref_title")
ref_image_input = gr.Image(
label="Upload Reference Image", elem_id="ref-image", interactive=True
)
gr.Examples(
examples=[
["./examples/reference_image/1.png"],
["./examples/reference_image/2.png"],
["./examples/reference_image/3.png"],
["./examples/reference_image/4.png"],
["./examples/reference_image/5.png"],
["./examples/reference_image/6.png"],
["./examples/reference_image/7.png"],
["./examples/reference_image/8.png"],
["./examples/reference_image/9.png"],
],
inputs=[ref_image_input],
label="You can upload or choose a reference image below to retexture.",
elem_id="my-btn3"
)
ref_image_display = gr.Image(
label="Reference Mask Segmentation",
elem_id="ref-mask",
interactive=True,
)
with gr.Row(elem_id="my-btn"):
ref_point_prompt = gr.Radio(
["Positive", "Negative"], label="Ref Click Type", value="Positive"
)
ref_clear_btn = gr.Button("Clear Ref Clicks")
gr.Markdown("Step5: Retexture", elem_id="ref_title")
with gr.Column(elem_id="my-btn"):
dilate_radius_slider = gr.Slider(
minimum=1,
maximum=10,
value=3,
step=1,
label="Mask Dilation Radius",
)
inference_steps_slider = gr.Slider(
minimum=10,
maximum=30,
value=20,
step=1,
label="Num Inference Steps",
)
guidance_slider = gr.Slider(
minimum=1.0,
maximum=3.0,
value=1.5,
step=0.1,
label="Guidance Scale",
)
ref_patch_slider = gr.Slider(
minimum=0.05,
maximum=1.0,
value=0.1,
step=0.05,
label="Reference Patch Ratio",
)
fg_threshold_slider = gr.Slider(
minimum=0.7,
maximum=1.0,
value=1.0,
step=0.01,
label="Jigsaw Patches' Foreground Coverage Threshold",
)
seed_slider = gr.Slider(
minimum=0,
maximum=2147483647,
value=42,
step=1,
label="Seed",
)
remove_btn = gr.Button("Retexture", elem_id="my-btn")
with gr.Row(elem_id="mesh-row"):
mesh_video = gr.Video(label="Untextured Object")
ref_image_final = gr.Image(
label="Jigsawed Reference Image",
interactive=False,
)
remove_video = gr.Video(label="Retexture Results", elem_id="my-video")
remove_btn.click(
inference_and_return_video,
inputs=[
dilate_radius_slider,
inference_steps_slider,
guidance_slider,
ref_patch_slider,
fg_threshold_slider,
seed_slider,
video_frames_state,
mask_frames_state,
ref_state,
],
outputs=[remove_video, mesh_video, ref_image_final],
)
get_info_btn.click(
get_video_info,
inputs=[video_input, video_state],
outputs=image_output,
)
image_output.select(
fn=segment_frame,
inputs=[point_prompt, video_state],
outputs=image_output,
)
clear_btn.click(clear_clicks, inputs=video_state, outputs=image_output)
track_btn.click(
track_video,
inputs=[n_frames_slider, video_state],
outputs=[video_output, video_frames_state, mask_frames_state],
)
ref_image_input.change(
set_ref_image,
inputs=[ref_image_input, ref_state],
outputs=ref_image_display,
)
ref_image_display.select(
fn=segment_ref_frame,
inputs=[ref_point_prompt, ref_state],
outputs=ref_image_display,
)
ref_clear_btn.click(
clear_ref_clicks, inputs=ref_state, outputs=ref_image_display
)
demo.launch()