from geometry_utils import diffout2motion
import gradio as gr
import spaces
import torch
import random
import os
from pathlib import Path
from aitviewer.headless import HeadlessRenderer
from aitviewer.configuration import CONFIG as AITVIEWER_CONFIG
from aitviewer.models.smpl import SMPLLayer
# import cv2
# import moderngl
# ctx = moderngl.create_context(standalone=True)
# print(ctx)

# Token for the gated SMPL body-model repository, provided as a Space secret.
access_token_smpl = os.environ.get('HF_SMPL_TOKEN')

zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cuda:0' 🤗

DEFAULT_TEXT = "A person is "
def get_smpl_models():
    REPO_ID = 'athn-nik/smpl_models'
    from huggingface_hub import snapshot_download
    return snapshot_download(repo_id=REPO_ID, allow_patterns="smplh*",
                             token=access_token_smpl)
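
# Headless aitviewer renderer configured to use the downloaded SMPL-H models,
# playing back at 30 fps with the floor auto-set and a z-up world.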
def get_renderer():
    from aitviewer.headless import HeadlessRenderer
    from aitviewer.configuration import CONFIG as AITVIEWER_CONFIG
    smpl_models_path = str(Path(get_smpl_models()))
    AITVIEWER_CONFIG.update_conf({'playback_fps': 30,
                                  'auto_set_floor': True,
                                  'smplx_models': smpl_models_path,
                                  'z_up': True})
    return HeadlessRenderer()
WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
    <h1>MotionFix: Text-Driven 3D Human Motion Editing</h1>
    <h3>
        <a href="https://is.mpg.de/person/~nathanasiou" target="_blank" rel="noopener noreferrer">Nikos Athanasiou</a><sup>1</sup>,
        <a href="https://is.mpg.de/person/acseke" target="_blank" rel="noopener noreferrer">Alpar Cseke</a><sup>1</sup>,
        <br>
        <a href="https://ps.is.mpg.de/person/mdiomataris" target="_blank" rel="noopener noreferrer">Markos Diomataris</a><sup>1, 3</sup>,
        <a href="https://is.mpg.de/person/black" target="_blank" rel="noopener noreferrer">Michael J. Black</a><sup>1</sup>,
        <a href="https://imagine.enpc.fr/~varolg/" target="_blank" rel="noopener noreferrer">Gül Varol</a><sup>2</sup>
    </h3>
    <h3>
        <sup>1</sup>Max Planck Institute for Intelligent Systems, Tübingen, Germany;
        <sup>2</sup>LIGM, École des Ponts, Univ Gustave Eiffel, CNRS, France;
        <sup>3</sup>ETH Zürich, Switzerland
    </h3>
</div>
<div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
    <a href='https://arxiv.org/abs/'><img src='https://img.shields.io/badge/Arxiv-2405.20340-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
    <a href='https://arxiv.org/pdf/'><img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'></a>
    <a href='https://motionfix.is.tue.mpg.de'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
    <a href='https://youtube.com/'><img src='https://img.shields.io/badge/YouTube-red?style=flat&logo=youtube&logoColor=white'></a>
</div>
""")
def greet(n):
    print(zero.device)  # <-- 'cuda:0' 🤗
    try:
        number = float(n)
    except ValueError:
        return "Invalid input. Please enter a number."
    return f"Hello {zero + number} Tensor"


def clear():
    return ""
def show_video(input_text):
    from normalization import Normalizer
    normalizer = Normalizer()
    from diffusion import create_diffusion
    from text_encoder import ClipTextEncoder
    from tmed_denoiser import TMED_denoiser

    # Load the TMED denoiser weights; the checkpoint stores them under a
    # 'denoiser.' prefix, which is stripped before loading.
    model_ckpt = download_models()
    checkpoint = torch.load(model_ckpt)
    checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
    tmed_denoiser = TMED_denoiser().to('cuda')
    tmed_denoiser.load_state_dict(checkpoint, strict=False)
    tmed_denoiser.eval()
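    # Build the conditioning. The source-motion condition below is a zero
    # tensor (no retrieved motion is passed in), so the edit text is the only
    # real conditioning signal in this demo path.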
    text_encoder = ClipTextEncoder()
    texts_cond = [input_text]
    diffusion_process = create_diffusion(timestep_respacing=None,
                                         learn_sigma=False, sigma_small=True,
                                         diffusion_steps=300,
                                         noise_schedule='squaredcos_cap_v2',
                                         predict_xstart=True)
    bsz = 1
    seqlen_tgt = 180
    no_of_texts = len(texts_cond)
    # Prepend two batches of empty prompts, presumably the unconditional and
    # motion-only passes used by the two guidance scales (gd_text, gd_motion)
    # further below.
    texts_cond = ['']*no_of_texts + texts_cond
    texts_cond = ['']*no_of_texts + texts_cond
    text_emb, text_mask = text_encoder(texts_cond)
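    # Zero source-motion features and all-ones masks for a single 180-frame
    # target sequence.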
    cond_emb_motion = torch.zeros(seqlen_tgt, bsz, 512, device='cuda')
    cond_motion_mask = torch.ones((bsz, seqlen_tgt), dtype=bool, device='cuda')
    mask_target = torch.ones((bsz, seqlen_tgt), dtype=bool, device='cuda')
    # Reverse diffusion: 300 denoising steps starting from pure noise, with
    # text guidance 4.0 and motion guidance 2.0.
    diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
                                                text_mask.to(cond_emb_motion.device),
                                                cond_emb_motion,
                                                cond_motion_mask,
                                                mask_target,
                                                diffusion_process,
                                                init_vec=None,
                                                init_from='noise',
                                                gd_text=4.0,
                                                gd_motion=2.0,
                                                steps_num=300)
    # Un-normalize the diffusion output and convert it back to a motion
    # sequence (translation + axis-angle rotations).
    edited_motion = diffout2motion(diff_out, normalizer).squeeze()
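    # Render the edited motion with a neutral SMPL-H body; the output filename
    # gets a random suffix to avoid collisions between requests.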
    from renderer import render_motion, color_map, pack_to_render
    AIT_RENDERER = get_renderer()
    SMPL_LAYER = SMPLLayer(model_type='smplh', ext='npz', gender='neutral')
    edited_mot_to_render = pack_to_render(rots=edited_motion[..., 3:],
                                          trans=edited_motion[..., :3])
    xx = random.randint(1, 1000)
    fname = render_motion(AIT_RENDERER, [edited_mot_to_render],
                          f"movie_example--{xx}",
                          pose_repr='aa',
                          color=[color_map['generated']],
                          smpl_layer=SMPL_LAYER)
    return fname
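
# Retrieval is not implemented yet: retrieve_video below is a stub, so the
# "Retrieve" button currently returns nothing. For local debugging one could
# exercise the editing path directly, e.g. (sketch, assuming a CUDA device):
#     print(show_video("A person is walking and then waves."))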
def retrieve_video(retrieve_text):
    pass


from huggingface_hub import hf_hub_download, snapshot_download


def download_models():
    REPO_ID = 'athn-nik/example-model'
    return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")


def download_tmr():
    REPO_ID = 'athn-nik/example-model'
    return snapshot_download(repo_id=REPO_ID, allow_patterns="tmr*",
                             token=access_token_smpl)
def random_number():
    # Placeholder for the "Random" button: returns fixed text for now.
    return "Random text"
with gr.Blocks() as demo:
    gr.Markdown(WEBSITE)
    with gr.Row():
        with gr.Column(scale=8):
            retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to retrieve:",
                                       show_label=True, label="Retrieval Text", value=DEFAULT_TEXT)
        with gr.Column(scale=1):
            clear_button_retrieval = gr.Button("Clear Retrieval Text")
    with gr.Row():
        with gr.Column(scale=8):
            input_text = gr.Textbox(placeholder="Type the edit text you want:",
                                    show_label=True, label="Input Text", value=DEFAULT_TEXT)
        with gr.Column(scale=1):
            clear_button_edit = gr.Button("Clear Edit Text")
    with gr.Row():
        video_output = gr.Video(label="Generated Video", height=240, width=320)
        retrieved_video_output = gr.Video(label="Retrieved Motion", height=240, width=320)
    with gr.Row():
        edit_button = gr.Button("Edit")
        retrieve_button = gr.Button("Retrieve")
        random_button = gr.Button("Random")

    def process_and_show_video(input_text):
        fname = show_video(input_text)
        return fname

    def process_and_retrieve_video(input_text):
        fname = retrieve_video(input_text)
        return fname

    from gen_utils import read_config
    from retrieval_loader import load_model_from_cfg
    from retrieval_loader import get_tmr_model

    # Load the TMR retrieval model once at startup.
    tmr = get_tmr_model(download_tmr())

    edit_button.click(process_and_show_video, inputs=input_text, outputs=video_output)
    retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=retrieved_video_output)
    clear_button_edit.click(clear, outputs=input_text)
    clear_button_retrieval.click(clear, outputs=retrieve_text)
    random_button.click(random_number, outputs=input_text)

demo.launch()