Spaces:
Running
Running
File size: 8,789 Bytes
d8530c7 8dc009f f4be66d 6837c8b 8dc009f 6837c8b 7d87cc1 6837c8b d8530c7 6837c8b d8530c7 6837c8b 8dc009f f66aca9 8dc009f f4be66d 6837c8b d8530c7 6837c8b f4be66d d8530c7 f4be66d 6837c8b f4be66d 6837c8b f4be66d 6837c8b f4be66d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
from geometry_utils import diffout2motion
import gradio as gr
import spaces
import torch
import random
import os
from pathlib import Path
from aitviewer.headless import HeadlessRenderer
from aitviewer.configuration import CONFIG as AITVIEWER_CONFIG
# import cv2
# import moderngl
# ctx = moderngl.create_context(standalone=True)
# print(ctx)
# Token for the gated SMPL body-model repo; expected as a Space secret.
access_token_smpl = os.environ.get('HF_SMPL_TOKEN')
# Module-level CUDA tensor: doubles as a startup GPU sanity check and is
# reused by greet() below.
# NOTE(review): torch.Tensor([0]) yields a float32 tensor -- presumably
# intentional, since greet() adds a float to it.
zero = torch.Tensor([0]).cuda()
print(zero.device) # <-- 'cuda:0' 🤗
DEFAULT_TEXT = "A person is "  # initial value of both textboxes in the UI
from aitviewer.models.smpl import SMPLLayer
def get_smpl_models():
    """Download the SMPL-H model files and return the local snapshot path."""
    from huggingface_hub import snapshot_download

    return snapshot_download(
        repo_id='athn-nik/smpl_models',
        allow_patterns="smplh*",
        token=access_token_smpl,
    )
def get_renderer():
    """Configure aitviewer and return a HeadlessRenderer for offscreen use.

    Uses the module-level ``HeadlessRenderer`` / ``AITVIEWER_CONFIG`` imports;
    the previous function-local re-imports duplicated them and were removed.
    The config must be updated *before* constructing the renderer, since it
    is read at init time.
    """
    smpl_models_path = str(Path(get_smpl_models()))
    AITVIEWER_CONFIG.update_conf({
        'playback_fps': 30,
        'auto_set_floor': True,
        'smplx_models': smpl_models_path,
        'z_up': True,
    })
    return HeadlessRenderer()
# Static HTML header for the demo page: title, author list, affiliations and
# badge links (arXiv / PDF / project page / YouTube). Rendered via
# gr.Markdown(WEBSITE) in the Blocks UI below.
WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
<h1>MotionFix: Text-Driven 3D Human Motion Editing</h1>
<h3>
<a href="https://is.mpg.de/person/~nathanasiou" target="_blank" rel="noopener noreferrer">Nikos Athanasiou</a><sup>1</sup>,
<a href="https://is.mpg.de/person/acseke" target="_blank" rel="noopener noreferrer">Alpar Cseke</a><sup>1</sup>,
<br>
<a href="https://ps.is.mpg.de/person/mdiomataris" target="_blank" rel="noopener noreferrer">Markos Diomataris</a><sup>1, 3</sup>,
<a href="https://is.mpg.de/person/black" target="_blank" rel="noopener noreferrer">Michael J. Black</a><sup>1</sup>,
<a href="https://imagine.enpc.fr/~varolg/" target="_blank" rel="noopener noreferrer">Gül Varol</a><sup>2</sup>,
</h3>
<h3>
<sup>1</sup>Max Planck Institute for Intelligent Systems, Tübingen, Germany;
<sup>2</sup>LIGM, École des Ponts, Univ Gustave Eiffel, CNRS, France,
<sup>3</sup>ETH Zürich, Switzerland
</h3>
</div>
<div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
<a href='https://arxiv.org/abs/'><img src='https://img.shields.io/badge/Arxiv-2405.20340-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
<a href='https://arxiv.org/pdf/'><img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'></a>
<a href='https://motionfix.is.tue.mpg.de'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
<a href='https://youtube.com/'><img src='https://img.shields.io/badge/YouTube-red?style=flat&logo=youtube&logoColor=white'></a>
</div>
""")
@spaces.GPU
def greet(n):
    """GPU smoke test: print the tensor device and greet with ``zero + n``."""
    print(zero.device)  # expected to report 'cuda:0' under @spaces.GPU 🤗
    try:
        value = float(n)
    except ValueError:
        return "Invalid input. Please enter a number."
    return f"Hello {zero + value} Tensor"
def clear():
    """Callback that blanks a Gradio textbox by returning an empty string."""
    return ""
def show_video(input_text):
    """Generate a motion from `input_text` with the TMED diffusion denoiser,
    render it with SMPL-H via aitviewer, and return the output video path.
    """
    from normalization import Normalizer
    normalizer = Normalizer()
    from diffusion import create_diffusion
    from text_encoder import ClipTextEncoder
    from tmed_denoiser import TMED_denoiser
    # Checkpoint keys carry a 'denoiser.' prefix; strip it so they match the
    # bare TMED_denoiser state dict (strict=False tolerates leftovers).
    model_ckpt = download_models()
    checkpoint = torch.load(model_ckpt)
    checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
    tmed_denoiser = TMED_denoiser().to('cuda')
    tmed_denoiser.load_state_dict(checkpoint, strict=False)
    tmed_denoiser.eval()
    text_encoder = ClipTextEncoder()
    texts_cond = [input_text]
    diffusion_process = create_diffusion(timestep_respacing=None,
                                         learn_sigma=False, sigma_small=True,
                                         diffusion_steps=300,
                                         noise_schedule='squaredcos_cap_v2',
                                         predict_xstart=True)
    bsz = 1
    # Target sequence length in frames; at the renderer's 30 fps this is
    # presumably a 6-second clip -- TODO confirm.
    seqlen_tgt = 180
    no_of_texts = len(texts_cond)
    # Empty strings are prepended twice, yielding [uncond, uncond, cond].
    # NOTE(review): this looks like it builds the batch for the two guidance
    # scales below (gd_text / gd_motion) -- confirm the duplicated line is
    # intentional and not a copy-paste error.
    texts_cond = ['']*no_of_texts + texts_cond
    texts_cond = ['']*no_of_texts + texts_cond
    text_emb, text_mask = text_encoder(texts_cond)
    # Source-motion conditioning is all zeros: no source motion is supplied,
    # so generation is driven by text alone.
    cond_emb_motion = torch.zeros(seqlen_tgt, bsz,
                                  512,
                                  device='cuda')
    cond_motion_mask = torch.ones((bsz, seqlen_tgt),
                                  dtype=bool, device='cuda')
    mask_target = torch.ones((bsz, seqlen_tgt),
                             dtype=bool, device='cuda')
    diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
                                                text_mask.to(cond_emb_motion.device),
                                                cond_emb_motion,
                                                cond_motion_mask,
                                                mask_target,
                                                diffusion_process,
                                                init_vec=None,
                                                init_from='noise',
                                                gd_text=4.0,
                                                gd_motion=2.0,
                                                steps_num=300)
    # Un-normalize the diffusion output and convert features back to motion.
    edited_motion = diffout2motion(diff_out, normalizer).squeeze()
    from renderer import render_motion, color_map, pack_to_render
    # NOTE(review): a fresh renderer and SMPL layer are built on every call;
    # caching them at module level would cut per-request latency.
    AIT_RENDERER = get_renderer()
    SMPL_LAYER = SMPLLayer(model_type='smplh', ext='npz', gender='neutral')
    # Feature layout assumption: first 3 dims are translation, the remainder
    # axis-angle rotations (matches pose_repr='aa' below) -- TODO confirm.
    edited_mot_to_render = pack_to_render(rots=edited_motion[..., 3:],
                                          trans=edited_motion[..., :3])
    import random
    # Random suffix keeps concurrent requests from clobbering each other's file.
    xx = random.randint(1, 1000)
    fname = render_motion(AIT_RENDERER, [edited_mot_to_render],
                          f"movie_example--{str(xx)}",
                          pose_repr='aa',
                          color=[color_map['generated']],
                          smpl_layer=SMPL_LAYER)
    return fname
def retrieve_video(retrieve_text):
    """Placeholder for TMR-based motion retrieval; not implemented yet."""
    return None
from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
def download_models():
    """Fetch the denoiser checkpoint file and return its local path."""
    return hf_hub_download('athn-nik/example-model',
                           filename="min_checkpoint.ckpt")
def download_tmr():
    """Download the TMR retrieval-model files and return the snapshot path."""
    from huggingface_hub import snapshot_download

    return snapshot_download(
        repo_id='athn-nik/example-model',
        allow_patterns="tmr*",
        token=access_token_smpl,
    )
import gradio as gr
def clear():
    # NOTE(review): duplicate of the `clear` defined earlier in this file;
    # this redefinition shadows it (same behavior, so harmless) -- consider
    # removing one of the two.
    return ""
def random_number():
    """Return the fixed placeholder string for the "Random" button.

    NOTE(review): despite the name, nothing random is produced -- the
    callback always yields the same text.
    """
    return "Random text"
# Gradio UI: two text inputs (retrieval / edit), two video outputs, and the
# button wiring that connects them to the generation and retrieval callbacks.
with gr.Blocks() as demo:
    gr.Markdown(WEBSITE)
    # Row 1: retrieval text + its clear button.
    with gr.Row():
        with gr.Column(scale=8):
            retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to Retrieve:",
                                       show_label=True, label="Retrieval Text", value=DEFAULT_TEXT)
        with gr.Column(scale=1):
            clear_button_retrieval = gr.Button("Clear Retrieval Text")
    # Row 2: edit text + its clear button.
    with gr.Row():
        with gr.Column(scale=8):
            input_text = gr.Textbox(placeholder="Type the edit text you want:",
                                    show_label=True, label="Input Text", value=DEFAULT_TEXT)
        with gr.Column(scale=1):
            clear_button_edit = gr.Button("Clear Edit Text")
    # Row 3: side-by-side players for generated and retrieved motion.
    with gr.Row():
        video_output = gr.Video(label="Generated Video", height=240, width=320)
        retrieved_video_output = gr.Video(label="Retrieved Motion", height=240, width=320)
    with gr.Row():
        edit_button = gr.Button("Edit")
        retrieve_button = gr.Button("Retrieve")
        random_button = gr.Button("Random")

    # Thin wrappers so the click handlers have a single textbox input and a
    # single video output.
    def process_and_show_video(input_text):
        fname = show_video(input_text)
        return fname

    def process_and_retrieve_video(input_text):
        fname = retrieve_video(input_text)
        return fname

    from gen_utils import read_config
    from retrieval_loader import load_model_from_cfg
    from retrieval_loader import get_tmr_model
    # Load the TMR retrieval model once at app start-up.
    # NOTE(review): `tmr` is never used by the handlers -- retrieval is still
    # a stub (see retrieve_video).
    tmr = get_tmr_model(download_tmr())
    edit_button.click(process_and_show_video, inputs=input_text, outputs=video_output)
    retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=retrieved_video_output)
    clear_button_edit.click(clear, outputs=input_text)
    clear_button_retrieval.click(clear, outputs=retrieve_text)
    random_button.click(random_number, outputs=input_text)
demo.launch() |