import os
import random
from pathlib import Path

import gradio as gr
import spaces
import torch
from aitviewer.configuration import CONFIG as AITVIEWER_CONFIG
from aitviewer.headless import HeadlessRenderer

from geometry_utils import diffout2motion
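# Access token for the gated SMPL model repo; set HF_SMPL_TOKEN in the Space secrets.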
access_token_smpl = os.environ.get('HF_SMPL_TOKEN')

# Sanity check that CUDA is reachable at startup.
zero = torch.Tensor([0]).cuda()
print(zero.device)  # expected: 'cuda:0'

DEFAULT_TEXT = "A person is "
from aitviewer.models.smpl import SMPLLayer
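
# Download the SMPL-H body model files (allow_patterns='smplh*') using the token above.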
def get_smpl_models():
    REPO_ID = 'athn-nik/smpl_models'
    from huggingface_hub import snapshot_download
    return snapshot_download(repo_id=REPO_ID, allow_patterns="smplh*",
                             token=access_token_smpl)

# Configure aitviewer for headless rendering with the downloaded SMPL models.
def get_renderer():
    smpl_models_path = str(Path(get_smpl_models()))
    AITVIEWER_CONFIG.update_conf({'playback_fps': 30,
                                  'auto_set_floor': True,
                                  'smplx_models': smpl_models_path,
                                  'z_up': True})
    return HeadlessRenderer()



WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
    <h1>MotionFix: Text-Driven 3D Human Motion Editing</h1>
    <h3>
        <a href="https://is.mpg.de/person/~nathanasiou" target="_blank" rel="noopener noreferrer">Nikos Athanasiou</a><sup>1</sup>,
        <a href="https://is.mpg.de/person/acseke" target="_blank" rel="noopener noreferrer">Alpar Cseke</a><sup>1</sup>,
        <br>
        <a href="https://ps.is.mpg.de/person/mdiomataris" target="_blank" rel="noopener noreferrer">Markos Diomataris</a><sup>1, 3</sup>,
        <a href="https://is.mpg.de/person/black" target="_blank" rel="noopener noreferrer">Michael J. Black</a><sup>1</sup>,
        <a href="https://imagine.enpc.fr/~varolg/" target="_blank" rel="noopener noreferrer">G&uuml;l Varol</a><sup>2</sup>,
    </h3>
    <h3>
        <sup>1</sup>Max Planck Institute for Intelligent Systems, T&uuml;bingen, Germany;
        <sup>2</sup>LIGM, &Eacute;cole des Ponts, Univ Gustave Eiffel, CNRS, France;
        <sup>3</sup>ETH Z&uuml;rich, Switzerland
    </h3>
</div>
<div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
<a href='https://arxiv.org/abs/'><img src='https://img.shields.io/badge/Arxiv-2405.20340-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a> 
<a href='https://arxiv.org/pdf/'><img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'></a> 
<a href='https://motionfix.is.tue.mpg.de'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a> 
<a href='https://youtube.com/'><img src='https://img.shields.io/badge/YouTube-red?style=flat&logo=youtube&logoColor=white'></a> 
</div>
""")

@spaces.GPU
def greet(n):
    # Template sanity check for GPU access; not wired into the UI below.
    print(zero.device)  # expected: 'cuda:0'
    try:
        number = float(n)
    except ValueError:
        return "Invalid input. Please enter a number."
    return f"Hello {zero + number} Tensor"

def clear():
    return ""

@spaces.GPU
def show_video(input_text):
    # Heavy project imports are deferred so the Space starts up quickly.
    from normalization import Normalizer
    from diffusion import create_diffusion
    from text_encoder import ClipTextEncoder
    from tmed_denoiser import TMED_denoiser
    normalizer = Normalizer()
    model_ckpt = download_models()
    checkpoint = torch.load(model_ckpt)
    # Strip the 'denoiser.' prefix left over from the training-time module hierarchy.
    checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
    tmed_denoiser = TMED_denoiser().to('cuda')
    tmed_denoiser.load_state_dict(checkpoint, strict=False)
    tmed_denoiser.eval()
    text_encoder = ClipTextEncoder()
    texts_cond = [input_text]
    
    # Sampling-time diffusion: 300 steps, cosine noise schedule, network predicts x0.
    diffusion_process = create_diffusion(timestep_respacing=None,
                                         learn_sigma=False, sigma_small=True,
                                         diffusion_steps=300,
                                         noise_schedule='squaredcos_cap_v2',
                                         predict_xstart=True)
    bsz = 1
    seqlen_tgt = 180  # target motion length in frames
    no_of_texts = len(texts_cond)
    # Prepend two empty prompts per sample: the denoiser runs unconditional,
    # text-only, and text+motion passes for two-scale classifier-free guidance.
    texts_cond = ['']*no_of_texts + texts_cond
    texts_cond = ['']*no_of_texts + texts_cond
    text_emb, text_mask = text_encoder(texts_cond)

    # No source motion is wired into the demo yet, so condition on a zero
    # motion embedding with all-ones validity masks.
    cond_emb_motion = torch.zeros(seqlen_tgt, bsz, 512, device='cuda')
    cond_motion_mask = torch.ones((bsz, seqlen_tgt),
                                  dtype=torch.bool, device='cuda')
    mask_target = torch.ones((bsz, seqlen_tgt),
                             dtype=torch.bool, device='cuda')

    # Reverse diffusion with classifier-free guidance on text and source motion.
    diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
                                                text_mask.to(cond_emb_motion.device),
                                                cond_emb_motion,
                                                cond_motion_mask,
                                                mask_target,
                                                diffusion_process,
                                                init_vec=None,
                                                init_from='noise',
                                                gd_text=4.0,
                                                gd_motion=2.0,
                                                steps_num=300)
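    # Map the denoised features back to a motion sequence (undoing normalization).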
    edited_motion = diffout2motion(diff_out, normalizer).squeeze()
    from renderer import render_motion, color_map, pack_to_render
    AIT_RENDERER = get_renderer()
    SMPL_LAYER = SMPLLayer(model_type='smplh', ext='npz', gender='neutral')
    # The first 3 feature dims are root translation; the rest are axis-angle rotations.
    edited_mot_to_render = pack_to_render(rots=edited_motion[..., 3:],
                                          trans=edited_motion[..., :3])
    # Randomize the output filename so concurrent requests don't collide.
    xx = random.randint(1, 1000)
    fname = render_motion(AIT_RENDERER, [edited_mot_to_render],
                          f"movie_example--{str(xx)}",
                          pose_repr='aa',
                          color=[color_map['generated']],
                          smpl_layer=SMPL_LAYER)
    return fname


def retrieve_video(retrieve_text):
    # Retrieval is not wired up yet; the "Retrieve" button currently returns nothing.
    pass

from huggingface_hub import hf_hub_download

def download_models():
    # Fetch the TMED denoiser checkpoint from the Hub.
    REPO_ID = 'athn-nik/example-model'
    return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")

def download_tmr():
    # Fetch the TMR retrieval model files from the same repo.
    REPO_ID = 'athn-nik/example-model'
    from huggingface_hub import snapshot_download
    return snapshot_download(repo_id=REPO_ID, allow_patterns="tmr*",
                             token=access_token_smpl)

def random_text():
    # Placeholder: should eventually return a random example prompt.
    return "Random text"

with gr.Blocks() as demo:
    gr.Markdown(WEBSITE)
    
    with gr.Row():
        with gr.Column(scale=8):
            retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to retrieve:",
                                       show_label=True, label="Retrieval Text", value=DEFAULT_TEXT)
        with gr.Column(scale=1):
            clear_button_retrieval = gr.Button("Clear Retrieval Text")
    
    with gr.Row():
        with gr.Column(scale=8):
            input_text = gr.Textbox(placeholder="Type the edit instruction you want to apply:",
                                    show_label=True, label="Edit Text", value=DEFAULT_TEXT)
        with gr.Column(scale=1):
            clear_button_edit = gr.Button("Clear Edit Text")

    with gr.Row():
        video_output = gr.Video(label="Generated Video", height=240, width=320)
        retrieved_video_output = gr.Video(label="Retrieved Motion", height=240, width=320)

    with gr.Row():
        edit_button = gr.Button("Edit")
        retrieve_button = gr.Button("Retrieve")

        random_button = gr.Button("Random")

    def process_and_show_video(input_text):
        fname = show_video(input_text)
        return fname

    def process_and_retrieve_video(input_text):
        fname = retrieve_video(input_text)
        return fname
    
    # Load the TMR retrieval model once at startup; used by the (stubbed) retrieval path.
    from retrieval_loader import get_tmr_model
    tmr = get_tmr_model(download_tmr())
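
    # Wire the buttons to their callbacks.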
    edit_button.click(process_and_show_video, inputs=input_text, outputs=video_output)
    retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=retrieved_video_output)

    clear_button_edit.click(clear, outputs=input_text)
    clear_button_retrieval.click(clear, outputs=retrieve_text)
    random_button.click(random_text, outputs=input_text)

demo.launch()