# motionfix-demo / app.py
from geometry_utils import diffout2motion
import gradio as gr
import spaces
import torch
import random
import os
from pathlib import Path
from aitviewer.headless import HeadlessRenderer
from aitviewer.configuration import CONFIG as AITVIEWER_CONFIG
access_token_smpl = os.environ.get('HF_SMPL_TOKEN')
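# ZeroGPU sanity check from the Spaces template: the `spaces` runtime is
# assumed to let this CUDA allocation succeed even at import time.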
zero = torch.Tensor([0]).cuda()
print(zero.device) # <-- 'cuda:0' 🤗
DEFAULT_TEXT = "A person is "
from aitviewer.models.smpl import SMPLLayer
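# Fetch the SMPL-H body models from a token-gated Hub repo (HF_SMPL_TOKEN).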
def get_smpl_models():
REPO_ID = 'athn-nik/smpl_models'
from huggingface_hub import snapshot_download
return snapshot_download(repo_id=REPO_ID, allow_patterns="smplh*",
token=access_token_smpl)
def get_renderer():
    # HeadlessRenderer needs an offscreen GL context; point aitviewer at the
    # downloaded SMPL-H models before instantiating it.
    smpl_models_path = str(Path(get_smpl_models()))
    AITVIEWER_CONFIG.update_conf({'playback_fps': 30,
                                  'auto_set_floor': True,
                                  'smplx_models': smpl_models_path,
                                  'z_up': True})
    return HeadlessRenderer()
WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
<h1>MotionFix: Text-Driven 3D Human Motion Editing</h1>
<h3>
<a href="https://is.mpg.de/person/~nathanasiou" target="_blank" rel="noopener noreferrer">Nikos Athanasiou</a><sup>1</sup>,
<a href="https://is.mpg.de/person/acseke" target="_blank" rel="noopener noreferrer">Alpar Cseke</a><sup>1</sup>,
<br>
<a href="https://ps.is.mpg.de/person/mdiomataris" target="_blank" rel="noopener noreferrer">Markos Diomataris</a><sup>1, 3</sup>,
<a href="https://is.mpg.de/person/black" target="_blank" rel="noopener noreferrer">Michael J. Black</a><sup>1</sup>,
<a href="https://imagine.enpc.fr/~varolg/" target="_blank" rel="noopener noreferrer">G&uuml;l Varol</a><sup>2</sup>,
</h3>
<h3>
<sup>1</sup>Max Planck Institute for Intelligent Systems, T&uuml;bingen, Germany;
<sup>2</sup>LIGM, &Eacute;cole des Ponts, Univ Gustave Eiffel, CNRS, France;
<sup>3</sup>ETH Z&uuml;rich, Switzerland
</h3>
</div>
<div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
<a href='https://arxiv.org/abs/'><img src='https://img.shields.io/badge/Arxiv-2405.20340-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
<a href='https://arxiv.org/pdf/'><img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'></a>
<a href='https://motionfix.is.tue.mpg.de'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
<a href='https://youtube.com/'><img src='https://img.shields.io/badge/YouTube-red?style=flat&logo=youtube&logoColor=white'></a>
</div>
""")
@spaces.GPU
def greet(n):
print(zero.device) # <-- 'cuda:0' 🤗
try:
number = float(n)
except ValueError:
return "Invalid input. Please enter a number."
return f"Hello {zero + number} Tensor"
def clear():
return ""
@spaces.GPU
def show_video(input_text):
from normalization import Normalizer
normalizer = Normalizer()
from diffusion import create_diffusion
from text_encoder import ClipTextEncoder
from tmed_denoiser import TMED_denoiser
    model_ckpt = download_models()
    checkpoint = torch.load(model_ckpt, map_location='cuda')
    # The checkpoint stores weights under a 'denoiser.' prefix; strip it so
    # the keys match the bare TMED_denoiser module.
    checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
    tmed_denoiser = TMED_denoiser().to('cuda')
    tmed_denoiser.load_state_dict(checkpoint, strict=False)
    tmed_denoiser.eval()
text_encoder = ClipTextEncoder()
texts_cond = [input_text]
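    # Gaussian diffusion with 300 steps, a squared-cosine noise schedule, and
    # x0-prediction (the network outputs the clean motion directly rather
    # than the noise).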
diffusion_process = create_diffusion(timestep_respacing=None,
learn_sigma=False, sigma_small=True,
diffusion_steps=300,
noise_schedule='squaredcos_cap_v2',
predict_xstart=True)
bsz = 1
seqlen_tgt = 180
no_of_texts = len(texts_cond)
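    # Prepend two empty prompts per input: the denoiser presumably runs three
    # guidance branches (unconditional, source-motion only, motion + text),
    # matching the two scales gd_text and gd_motion passed below.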
texts_cond = ['']*no_of_texts + texts_cond
texts_cond = ['']*no_of_texts + texts_cond
text_emb, text_mask = text_encoder(texts_cond)
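    # No source motion is wired in yet, so the 512-d motion condition is all
    # zeros and the masks mark every frame of the 180-frame target as valid.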
cond_emb_motion = torch.zeros(seqlen_tgt, bsz,
512,
device='cuda')
cond_motion_mask = torch.ones((bsz, seqlen_tgt),
dtype=bool, device='cuda')
mask_target = torch.ones((bsz, seqlen_tgt),
dtype=bool, device='cuda')
diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
text_mask.to(cond_emb_motion.device),
cond_emb_motion,
cond_motion_mask,
mask_target,
diffusion_process,
init_vec=None,
init_from='noise',
gd_text=4.0,
gd_motion=2.0,
steps_num=300)
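    # diffout2motion denormalizes the diffusion output and assembles the final
    # motion features (first 3 dims: translation, rest: axis-angle rotations).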
edited_motion = diffout2motion(diff_out, normalizer).squeeze()
    from renderer import render_motion, color_map, pack_to_render
    # Building the renderer and SMPL layer per call is slow, but it keeps the
    # GPU/GL setup out of module import, which matters on ZeroGPU.
    AIT_RENDERER = get_renderer()
    SMPL_LAYER = SMPLLayer(model_type='smplh', ext='npz', gender='neutral')
    edited_mot_to_render = pack_to_render(rots=edited_motion[..., 3:],
                                          trans=edited_motion[..., :3])
    # Random suffix so concurrent requests don't overwrite each other's videos.
    suffix = random.randint(1, 1000)
    fname = render_motion(AIT_RENDERER, [edited_mot_to_render],
                          f"movie_example--{suffix}",
                          pose_repr='aa',
                          color=[color_map['generated']],
                          smpl_layer=SMPL_LAYER)
return fname
def retrieve_video(retrieve_text):
    # Retrieval is still a placeholder; returning None leaves the video
    # component empty.
    return None
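# Hypothetical sketch of the eventual retrieval flow (encode_text and the
# gallery tensor below are assumptions, not the actual TMR API):
#   latent = tmr.encode_text([retrieve_text])
#   best = (gallery_latents @ latent.T).argmax()
#   then render the best-matching motion the same way show_video does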
from huggingface_hub import hf_hub_download
def download_models():
REPO_ID = 'athn-nik/example-model'
return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")
def download_tmr():
    REPO_ID = 'athn-nik/example-model'
    from huggingface_hub import snapshot_download
    return snapshot_download(repo_id=REPO_ID, allow_patterns="tmr*",
                             token=access_token_smpl)
def random_text():
    # Placeholder for the "Random" button: returns fixed text for now.
    return "Random text"
with gr.Blocks() as demo:
gr.Markdown(WEBSITE)
with gr.Row():
with gr.Column(scale=8):
            retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to retrieve:",
                                       show_label=True, label="Retrieval Text", value=DEFAULT_TEXT)
with gr.Column(scale=1):
clear_button_retrieval = gr.Button("Clear Retrieval Text")
with gr.Row():
with gr.Column(scale=8):
input_text = gr.Textbox(placeholder="Type the edit text you want:",
show_label=True, label="Input Text", value=DEFAULT_TEXT)
with gr.Column(scale=1):
clear_button_edit = gr.Button("Clear Edit Text")
with gr.Row():
video_output = gr.Video(label="Generated Video", height=240, width=320)
retrieved_video_output = gr.Video(label="Retrieved Motion", height=240, width=320)
with gr.Row():
edit_button = gr.Button("Edit")
retrieve_button = gr.Button("Retrieve")
random_button = gr.Button("Random")
def process_and_show_video(input_text):
fname = show_video(input_text)
return fname
def process_and_retrieve_video(input_text):
fname = retrieve_video(input_text)
return fname
    from retrieval_loader import get_tmr_model
    # Load TMR once at startup; it will back the Retrieve button once
    # retrieve_video is implemented.
    tmr = get_tmr_model(download_tmr())
edit_button.click(process_and_show_video, inputs=input_text, outputs=video_output)
retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=retrieved_video_output)
clear_button_edit.click(clear, outputs=input_text)
clear_button_retrieval.click(clear, outputs=retrieve_text)
    random_button.click(random_text, outputs=input_text)
demo.launch()