# motionfix-demo / app.py
from geometry_utils import diffout2motion
import gradio as gr
import spaces
import torch
import random
import os
from pathlib import Path
import smplx
from body_renderer import get_render
import joblib
# import cv2
# import moderngl
# ctx = moderngl.create_context(standalone=True)
# print(ctx)
access_token_smpl = os.environ.get('HF_SMPL_TOKEN')
os.environ["PYOPENGL_PLATFORM"] = "egl"
zero = torch.Tensor([0]).cuda()
print(zero.device) # <-- 'cuda:0' 🤗
DEFAULT_TEXT = "do it slower"
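# Download the SMPL-H body models from the token-gated Hugging Face repo;
# the returned path is used below to build the SMPL-H layer for rendering.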
def get_smpl_models():
REPO_ID = 'athn-nik/smpl_models'
from huggingface_hub import snapshot_download
return snapshot_download(repo_id=REPO_ID, allow_patterns="smplh*",
token=access_token_smpl)
WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
<h1>MotionFix: Text-Driven 3D Human Motion Editing</h1>
<h3>
<a href="https://is.mpg.de/person/~nathanasiou" target="_blank" rel="noopener noreferrer">Nikos Athanasiou</a><sup>1</sup>,
<a href="https://is.mpg.de/person/acseke" target="_blank" rel="noopener noreferrer">Alpar Cseke</a><sup>1</sup>,
<br>
<a href="https://ps.is.mpg.de/person/mdiomataris" target="_blank" rel="noopener noreferrer">Markos Diomataris</a><sup>1, 3</sup>,
<a href="https://is.mpg.de/person/black" target="_blank" rel="noopener noreferrer">Michael J. Black</a><sup>1</sup>,
<a href="https://imagine.enpc.fr/~varolg/" target="_blank" rel="noopener noreferrer">G&uuml;l Varol</a><sup>2</sup>
</h3>
<h3>
<sup>1</sup>Max Planck Institute for Intelligent Systems, T&uuml;bingen, Germany;
<sup>2</sup>LIGM, &Eacute;cole des Ponts, Univ Gustave Eiffel, CNRS, France;
<sup>3</sup>ETH Z&uuml;rich, Switzerland
</h3>
</div>
<div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
<a href='https://arxiv.org/pdf/2408.00712'><img src='https://img.shields.io/badge/Arxiv-2408.00712-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
<a href='https://motionfix.is.tue.mpg.de'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
<a href='https://www.youtube.com/watch?v=cFa6V6Ua-TY'><img src='https://img.shields.io/badge/YouTube-red?style=flat&logo=youtube&logoColor=white'></a>
</div>
""")
CREDITS = ("""<div class="embed_hidden" style="text-align: center;">
<h3>
The renderer of this demo is adapted from the renderer of
<a href="https://geometry.stanford.edu/projects/humor/" target="_blank" rel="noopener noreferrer">HuMoR</a>,
with the help of <a href="https://ps.is.mpg.de/person/trakshit" target="_blank" rel="noopener noreferrer">Tithi Rakshit</a> :)
</h3>
</div>
""")
WEB_source = ("""<div class="embed_hidden" style="text-align: center;">
<h1>Pick a motion to edit!</h1>
<h3>
Pick the source motion you want to edit.
<hr class="double">
</h3>
</div>
""")
WEB_target = ("""<div class="embed_hidden" style="text-align: center;">
<h1>Now type the text to edit that motion!</h1>
<h3>
The edited motion will appear here!
<hr class="double">
</h3>
</div>
""")
@spaces.GPU
def greet(n):
print(zero.device) # <-- 'cuda:0' 🤗
try:
number = float(n)
except ValueError:
return "Invalid input. Please enter a number."
return f"Hello {zero + number} Tensor"
def clear():
return ""
def show_video(input_text, key_to_use):
from normalization import Normalizer
normalizer = Normalizer()
from diffusion import create_diffusion
from text_encoder import ClipTextEncoder
from tmed_denoiser import TMED_denoiser
model_ckpt = download_models()
checkpoint = torch.load(model_ckpt)
motion_to_edit = download_motion_from_dataset(key_to_use)
ds_sample = joblib.load(motion_to_edit)
source_motion_norm = ds_sample['source_feats_norm'].to('cuda')
seqlen_tgt = ds_sample['target_feats_norm'].shape[0]
seqlen_src = ds_sample['source_feats_norm'].shape[0]
# import ipdb; ipdb.set_trace()
checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
tmed_denoiser = TMED_denoiser().to('cuda')
tmed_denoiser.load_state_dict(checkpoint, strict=False)
tmed_denoiser.eval()
text_encoder = ClipTextEncoder()
texts_cond = [input_text]
diffusion_process = create_diffusion(timestep_respacing=None,
learn_sigma=False, sigma_small=True,
diffusion_steps=300,
noise_schedule='squaredcos_cap_v2',
predict_xstart=True)
bsz = 1
no_of_texts = len(texts_cond)
texts_cond = ['']*no_of_texts + texts_cond
texts_cond = ['']*no_of_texts + texts_cond
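    # The text batch is tripled with empty strings, presumably to cover the
    # unconditional and source-motion-only branches used by the two
    # classifier-free guidance weights (gd_text, gd_motion) passed below.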
text_emb, text_mask = text_encoder(texts_cond)
cond_emb_motion = source_motion_norm.unsqueeze(0).permute(1, 0, 2)
cond_motion_mask = torch.ones((bsz, seqlen_src),
dtype=bool, device='cuda')
mask_target = torch.ones((bsz, seqlen_tgt),
dtype=bool, device='cuda')
diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
text_mask.to(cond_emb_motion.device),
cond_emb_motion,
cond_motion_mask,
mask_target,
diffusion_process,
init_vec=None,
init_from='noise',
gd_text=2.0,
gd_motion=2.0,
steps_num=300)
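    # Convert the denoised, normalized features back to a motion sequence
    # (first 3 dims: translation, remaining dims: 6D rotations, see below).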
edited_motion = diffout2motion(diff_out, normalizer).squeeze()
# import ipdb; ipdb.set_trace()
# aitrenderer = get_renderer()
# SMPL_LAYER = SMPLLayer(model_type='smplh', ext='npz', gender='neutral')
# edited_mot_to_render = pack_to_render(rots=edited_motion[..., 3:],
# trans=edited_motion[..., :3])
SMPL_MODELS_PATH = str(Path(get_smpl_models()))
    body_model = smplx.SMPLHLayer(f"{SMPL_MODELS_PATH}/smplh",
                                  model_type='smplh',
                                  gender='neutral', ext='npz')
from transform3d import transform_body_pose
edited_motion_aa = transform_body_pose(edited_motion[:, 3:],
'6d->aa')
if os.path.exists('./output_movie.mp4'):
os.remove('./output_movie.mp4')
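    # Render with the SMPL-H neutral body: translation, global orientation
    # (first 3 axis-angle dims), and body pose (remaining axis-angle dims).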
fname = get_render(body_model,
[edited_motion[..., :3].detach().cpu()],
[edited_motion_aa[..., :3].detach().cpu()],
[edited_motion_aa[..., 3:].detach().cpu()],
output_path='./output_movie.mp4',
text='', colors=['sky blue'])
# import ipdb; ipdb.set_trace()
# fname = render_motion(AIT_RENDERER, [edited_mot_to_render],
# f"movie_example--{str(xx)}",
# pose_repr='aa',
# color=[color_map['generated']],
# smpl_layer=SMPL_LAYER)
print(fname)
print(os.path.abspath(fname))
return fname
from huggingface_hub import hf_hub_download
def download_models():
REPO_ID = 'athn-nik/example-model'
return hf_hub_download(REPO_ID, filename="tmed_compressed.ckpt")
def download_motion_from_dataset(key_to_dl):
REPO_ID = 'athn-nik/example-model'
from huggingface_hub import snapshot_download
keytodl = key_to_dl
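    # NOTE: the requested key is currently overridden with a fixed sample id.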
keytodl = '000008'
path_for_ds = snapshot_download(repo_id=REPO_ID,
allow_patterns=f"dataset_inputs/{keytodl}",
token=access_token_smpl)
path_for_ds_sample = path_for_ds + f'/dataset_inputs/{keytodl}.pth.tar'
return path_for_ds_sample
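# TMR (text-to-motion retrieval) assets, used by retrieve_video to match a text
# query against precomputed source-motion embeddings.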
def download_tmr():
REPO_ID = 'athn-nik/example-model'
# return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")
from huggingface_hub import snapshot_download
return snapshot_download(repo_id=REPO_ID, allow_patterns="tmr*",
token=access_token_smpl)
def download_motionfix():
REPO_ID = 'athn-nik/example-model'
# return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")
from huggingface_hub import snapshot_download
return snapshot_download(repo_id=REPO_ID, allow_patterns="motionfix*",
token=access_token_smpl)
def download_motionfix_dataset():
REPO_ID = 'athn-nik/example-model'
dataset_downloaded_path = hf_hub_download(REPO_ID, filename="tmed_compressed.ckpt")
dataset_dict = joblib.load(dataset_downloaded_path)
return dataset_dict
def download_embeddings():
REPO_ID = 'athn-nik/example-model'
# return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")
from huggingface_hub import snapshot_download
return snapshot_download(repo_id=REPO_ID, allow_patterns="embeddings*",
token=access_token_smpl)
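# Fetch everything the demo needs once at startup: the MotionFix pairs, the
# precomputed source-motion embeddings for retrieval, and the dataset dict
# (the latter is not referenced again in this file).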
MFIX_p = download_motionfix() + '/motionfix'
SOURCE_MOTS_p = download_embeddings() + '/embeddings'
MFIX_DATASET_DICT = download_motionfix_dataset()
def random_source_motion(set_to_pick):
# import ipdb;ipdb.set_trace()
mfix_train, mfix_test = load_motionfix(MFIX_p)
if set_to_pick == 'all':
current_set = mfix_test | mfix_train
elif set_to_pick == 'train':
current_set = mfix_train
elif set_to_pick == 'test':
current_set = mfix_test
random_key = random.choice(list(current_set.keys()))
curvid = current_set[random_key]['motion_a']
text_annot = current_set[random_key]['annotation']
return curvid, text_annot, random_key, text_annot
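# Rank all source motions by TMR similarity to the query text and return the
# best match's video and annotation.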
def retrieve_video(retrieve_text):
tmr_text_encoder = get_tmr_model(download_tmr())
# import ipdb;ipdb.set_trace()
# text_encoded = tmr_text_encoder([retrieve_text])
motion_embeds = None
from gen_utils import read_json
import numpy as np
motion_embeds = torch.load(SOURCE_MOTS_p+'/source_motions_embeddings.pt')
    motion_keyids = np.array(read_json(SOURCE_MOTS_p + '/keyids_embeddings.json'))
mfix_train, mfix_test = load_motionfix(MFIX_p)
all_mots = mfix_test | mfix_train
scores = tmr_text_encoder.compute_scores(retrieve_text, embs=motion_embeds)
sorted_idxs = np.argsort(-scores)
best_keyids = motion_keyids[sorted_idxs]
# best_scores = scores[sorted_idxs]
top_mot = best_keyids[0]
curvid = all_mots[top_mot]['motion_a']
text_annot = all_mots[top_mot]['annotation']
return curvid, text_annot
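# Gradio UI: source-motion picker in the left column, text-driven editing in the right column.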
with gr.Blocks(css="""
.gradio-row {
display: flex;
gap: 20px;
}
.gradio-column {
flex: 1;
}
.gradio-container {
display: flex;
flex-direction: column;
gap: 10px;
}
.gradio-button-row {
display: flex;
gap: 10px;
}
.gradio-textbox-row {
display: flex;
gap: 10px;
align-items: center;
}
.gradio-edit-row {
gap: 10px;
align-items: center;
}
.gradio-textbox-with-button {
display: flex;
align-items: center;
}
.gradio-textbox-with-button input {
flex-grow: 1;
}
""") as demo:
gr.Markdown(WEBSITE)
random_key_state = gr.State()
with gr.Row(elem_id="gradio-row"):
with gr.Column(scale=5, elem_id="gradio-column"):
gr.Markdown(WEB_source)
with gr.Row(elem_id="gradio-button-row"):
# iterative_button = gr.Button("Iterative")
# retrieve_button = gr.Button("TMRetrieve")
random_button = gr.Button("Random")
with gr.Row(elem_id="gradio-textbox-row"):
with gr.Column(scale=5, elem_id="gradio-textbox-with-button"):
# retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to Retrieve:",
# show_label=True, label="Retrieval Text",
# value=DEFAULT_TEXT)
clear_button_retrieval = gr.Button("Clear", scale=0)
with gr.Row(elem_id="gradio-textbox-row"):
                suggested_edit_text = gr.Textbox(placeholder="Example edit text for this motion:",
                                                 show_label=True, label="Suggested Edit Text",
                                                 value='')
xxx = 'https://motion-editing.s3.eu-central-1.amazonaws.com/collection_wo_walks_runs/rendered_pairs/011327_120_240-002682_120_240.mp4'
            set_to_pick = gr.Radio(['all', 'train', 'test'],
                                   value='all',
                                   label="Set to pick from",
                                   info="The source motion is picked from the whole dataset or only from the train or test split.")
# import ipdb; ipdb.set_trace()
retrieved_video_output = gr.Video(label="Retrieved Motion",
# value=xxx,
height=360, width=480)
with gr.Column(scale=5, elem_id="gradio-column"):
gr.Markdown(WEB_target)
with gr.Row(elem_id="gradio-edit-row"):
clear_button_edit = gr.Button("Clear", scale=0)
edit_button = gr.Button("Edit", scale=0)
with gr.Row(elem_id="gradio-textbox-row"):
input_text = gr.Textbox(placeholder="Type the edit text you want:",
show_label=False, label="Input Text",
value=DEFAULT_TEXT)
video_output = gr.Video(label="Generated Video", height=360,
width=480)
def process_and_show_video(input_text, random_key_state):
fname = show_video(input_text, random_key_state)
return fname
def process_and_retrieve_video(input_text):
fname = retrieve_video(input_text)
return fname
from retrieval_loader import get_tmr_model
from dataset_utils import load_motionfix
edit_button.click(process_and_show_video, inputs=[input_text, random_key_state], outputs=video_output)
# retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=[retrieved_video_output, suggested_edit_text])
random_button.click(random_source_motion, inputs=set_to_pick,
outputs=[retrieved_video_output,
suggested_edit_text,
random_key_state,
input_text])
clear_button_edit.click(clear, outputs=input_text)
# clear_button_retrieval.click(clear, outputs=retrieve_text)
gr.Markdown(CREDITS)
demo.launch(share=True)