from pathlib import Path
import os
import random

import gradio as gr
import spaces
import torch
from aitviewer.headless import HeadlessRenderer
from aitviewer.configuration import CONFIG as AITVIEWER_CONFIG
from aitviewer.models.smpl import SMPLLayer

from geometry_utils import diffout2motion

# Token for the gated SMPL body-model repo on the Hub.
access_token_smpl = os.environ.get('HF_SMPL_TOKEN')

# ZeroGPU sanity check: .cuda() is allowed at import time, but the tensor
# only lands on the GPU inside @spaces.GPU functions.
zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cpu' 🤗

DEFAULT_TEXT = "A person is "


def get_smpl_models():
    """Download the SMPL-H body models (gated repo, needs HF_SMPL_TOKEN)."""
    from huggingface_hub import snapshot_download
    REPO_ID = 'athn-nik/smpl_models'
    return snapshot_download(repo_id=REPO_ID, allow_patterns="smplh*",
                             token=access_token_smpl)


def get_renderer():
    """Point aitviewer at the downloaded body models and render headlessly."""
    smpl_models_path = str(Path(get_smpl_models()))
    AITVIEWER_CONFIG.update_conf({'playback_fps': 30,
                                  'auto_set_floor': True,
                                  'smplx_models': smpl_models_path,
                                  'z_up': True})
    return HeadlessRenderer()


WEBSITE = ("""

# MotionFix: Text-Driven 3D Human Motion Editing

Nikos Athanasiou¹, Alpar Cseke¹, Markos Diomataris¹,³, Michael J. Black¹, Gül Varol²

¹Max Planck Institute for Intelligent Systems, Tübingen, Germany; ²LIGM, École des Ponts, Univ Gustave Eiffel, CNRS, France; ³ETH Zürich, Switzerland

""") @spaces.GPU def greet(n): print(zero.device) # <-- 'cuda:0' 🤗 try: number = float(n) except ValueError: return "Invalid input. Please enter a number." return f"Hello {zero + number} Tensor" def clear(): return "" def show_video(input_text): from normalization import Normalizer normalizer = Normalizer() from diffusion import create_diffusion from text_encoder import ClipTextEncoder from tmed_denoiser import TMED_denoiser model_ckpt = download_models() checkpoint = torch.load(model_ckpt) checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()} tmed_denoiser = TMED_denoiser().to('cuda') tmed_denoiser.load_state_dict(checkpoint, strict=False) tmed_denoiser.eval() text_encoder = ClipTextEncoder() texts_cond = [input_text] diffusion_process = create_diffusion(timestep_respacing=None, learn_sigma=False, sigma_small=True, diffusion_steps=300, noise_schedule='squaredcos_cap_v2', predict_xstart=True) bsz = 1 seqlen_tgt = 180 no_of_texts = len(texts_cond) texts_cond = ['']*no_of_texts + texts_cond texts_cond = ['']*no_of_texts + texts_cond text_emb, text_mask = text_encoder(texts_cond) cond_emb_motion = torch.zeros(seqlen_tgt, bsz, 512, device='cuda') cond_motion_mask = torch.ones((bsz, seqlen_tgt), dtype=bool, device='cuda') mask_target = torch.ones((bsz, seqlen_tgt), dtype=bool, device='cuda') diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device), text_mask.to(cond_emb_motion.device), cond_emb_motion, cond_motion_mask, mask_target, diffusion_process, init_vec=None, init_from='noise', gd_text=4.0, gd_motion=2.0, steps_num=300) edited_motion = diffout2motion(diff_out, normalizer).squeeze() from renderer import render_motion, color_map, pack_to_render # aitrenderer = get_renderer() AIT_RENDERER = get_renderer() SMPL_LAYER = SMPLLayer(model_type='smplh', ext='npz', gender='neutral') edited_mot_to_render = pack_to_render(rots=edited_motion[..., 3:], trans=edited_motion[..., :3]) import random xx = random.randint(1, 1000) fname = render_motion(AIT_RENDERER, [edited_mot_to_render], f"movie_example--{str(xx)}", pose_repr='aa', color=[color_map['generated']], smpl_layer=SMPL_LAYER) return fname def retrieve_video(retrieve_text): pass from huggingface_hub import hf_hub_download, hf_hub_url, cached_download def download_models(): REPO_ID = 'athn-nik/example-model' return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt") def download_tmr(): REPO_ID = 'athn-nik/example-model' # return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt") from huggingface_hub import snapshot_download return snapshot_download(repo_id=REPO_ID, allow_patterns="tmr*", token=access_token_smpl) import gradio as gr def clear(): return "" def random_number(): return "Random text" with gr.Blocks() as demo: gr.Markdown(WEBSITE) with gr.Row(): with gr.Column(scale=8): retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to Retrieve:", show_label=True, label="Retrieval Text", value=DEFAULT_TEXT) with gr.Column(scale=1): clear_button_retrieval = gr.Button("Clear Retrieval Text") with gr.Row(): with gr.Column(scale=8): input_text = gr.Textbox(placeholder="Type the edit text you want:", show_label=True, label="Input Text", value=DEFAULT_TEXT) with gr.Column(scale=1): clear_button_edit = gr.Button("Clear Edit Text") with gr.Row(): video_output = gr.Video(label="Generated Video", height=240, width=320) retrieved_video_output = gr.Video(label="Retrieved Motion", height=240, width=320) with gr.Row(): edit_button = gr.Button("Edit") retrieve_button = 
gr.Button("Retrieve") random_button = gr.Button("Random") def process_and_show_video(input_text): fname = show_video(input_text) return fname def process_and_retrieve_video(input_text): fname = retrieve_video(input_text) return fname from gen_utils import read_config from retrieval_loader import load_model_from_cfg from retrieval_loader import get_tmr_model tmr = get_tmr_model(download_tmr()) edit_button.click(process_and_show_video, inputs=input_text, outputs=video_output) retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=retrieved_video_output) # import ipdb;ipdb.set_trace() clear_button_edit.click(clear, outputs=input_text) clear_button_retrieval.click(clear, outputs=retrieve_text) random_button.click(random_number, outputs=input_text) demo.launch()