# motionfix-demo / app.py
# Hugging Face Space file view: last commit by atnikos, "fixes for examples" (1277b3a); file size 16.7 kB.
# (Page links: raw · history · blame.)
import os
from pathlib import Path
import gradio as gr
import spaces
import torch
import smplx
import numpy as np
from website import CREDITS, WEB_source, WEB_target, WEBSITE
from download_deps import get_smpl_models, download_models, download_model_config
from download_deps import download_tmr, download_motionfix, download_motionfix_dataset
from download_deps import download_embeddings
import random
# DO NOT initialize CUDA here: this code runs at import time, outside any GPU context.
DEFAULT_TEXT = "do it slower"
import os
# Select PyOpenGL's EGL backend (presumably for offscreen rendering) and put the
# system / NVIDIA library directories ahead of any pre-existing search path.
os.environ['PYOPENGL_PLATFORM'] = 'egl'
os.environ['LD_LIBRARY_PATH'] = ':'.join([
    '/usr/lib/x86_64-linux-gnu',
    '/usr/lib/x86_64-linux-gnu/nvidia/current',
    os.environ.get('LD_LIBRARY_PATH', ''),
])
# Optional debugging: print every `ldconfig -p` entry mentioning EGL.
# Best-effort only -- any failure (e.g. ldconfig missing) is reported, not raised.
import subprocess
try:
    result = subprocess.run(['ldconfig', '-p'], capture_output=True, text=True)
    egl_libs = list(filter(lambda entry: 'EGL' in entry, result.stdout.split('\n')))
    print("Available EGL libraries:", egl_libs)
except Exception as e:
    print(f"Error finding libraries: {e}")
# Hard-coded examples for the "Examples" grid. Entries are matched by index:
# example_videos[i] is the preview clip, example_keys[i] the dataset key that
# gets edited, and example_texts[i] the pre-filled edit instruction.
# NOTE(review): all four slots are still placeholders -- they repeat the same
# clip and key, and the clip name (000652) does not match the key (000091).
example_videos = ["./examples/000652_0_120.mp4"] * 4
example_keys = ["000091"] * 4
example_texts = [
    "need to use the opposite leg",
    "need to use the opposite leg2",
    "need to use the opposite leg3",
    "need to use the opposite leg4",
]
# One preview player per example clip, sized from example_videos rather than a
# hard-coded count (range(4) broke as soon as the example list changed length).
# NOTE(review): these components are created at import time outside any
# gr.Blocks context and appear unused in this module (create_gradio_interface
# builds its own example players) -- confirm before removing.
example_video_outputs = [
    gr.Video(label=f"Example {i + 1}", value=video_path)
    for i, video_path in enumerate(example_videos)
]
class MotionEditor:
    """Text-driven editor for MotionFix source motions.

    ``__init__`` only downloads/locates assets (code, embeddings, dataset,
    checkpoint, feature config) and deliberately creates no CUDA objects.
    All model components are built lazily by ``initialize_if_needed``, which
    the GPU-decorated ``process_motion`` calls first. The remaining helpers
    assume that initialisation already happened: they read ``self.device``,
    ``self.normalizer``, ``self.infeats``, ``self.tmed_denoiser`` and so on.
    """

    def __init__(self):
        # Don't initialize any CUDA components in __init__.
        self.is_initialized = False
        # The download helpers return path strings that are extended with the
        # relevant sub-folder (presumably directories -- confirm in download_deps).
        self.MFIX_p = download_motionfix() + '/motionfix'
        self.SOURCE_MOTS_p = download_embeddings() + '/embeddings'
        # Indexed by sample key in process_motion; each sample is read via
        # 'motion_source' / 'motion_target' in process_features.
        self.MFIX_DATASET_DICT = download_motionfix_dataset()
        self.model_ckpt_path = download_models()
        # Feature names used to build the model input (iterated in process_features).
        self.model_config_feats = download_model_config()

    @spaces.GPU
    def initialize_if_needed(self):
        """Build all model components once; subsequent calls are no-ops.

        Creates a CUDA device and moves the denoiser onto it, so it must run
        inside a GPU context (it is called from ``process_motion``).
        """
        if self.is_initialized:
            return
        from normalization import Normalizer
        from diffusion import create_diffusion
        from text_encoder import ClipTextEncoder
        from tmed_denoiser import TMED_denoiser

        self.device = torch.device('cuda')
        self.normalizer = Normalizer()
        self.text_encoder = ClipTextEncoder()

        self.infeats = self.model_config_feats
        checkpoint = torch.load(self.model_ckpt_path, map_location=self.device)
        # Strip the 'denoiser.' prefix so keys match the bare TMED_denoiser
        # (presumably the checkpoint was saved from a wrapper module).
        checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}

        self.tmed_denoiser = TMED_denoiser().to(self.device)
        # strict=False tolerates key mismatches, but a silent mismatch would
        # leave layers randomly initialised, so report what was not matched.
        load_result = self.tmed_denoiser.load_state_dict(checkpoint, strict=False)
        missing = list(getattr(load_result, 'missing_keys', []))
        unexpected = list(getattr(load_result, 'unexpected_keys', []))
        if missing or unexpected:
            print(f"Checkpoint/model key mismatch: {len(missing)} missing "
                  f"(e.g. {missing[:5]}), {len(unexpected)} unexpected "
                  f"(e.g. {unexpected[:5]})")
        self.tmed_denoiser.eval()

        self.diffusion = create_diffusion(
            timestep_respacing=None,
            learn_sigma=False,
            sigma_small=True,
            diffusion_steps=300,
            noise_schedule='squaredcos_cap_v2',
            predict_xstart=True
        )

        smpl_models_path = str(Path(get_smpl_models()))
        self.body_model = smplx.SMPLHLayer(
            f"{smpl_models_path}/smplh",
            model_type='smplh',
            gender='neutral',
            ext='npz'
        )
        self.is_initialized = True

    @spaces.GPU(duration=360)
    def process_motion(self, input_text, key_to_use):
        """Edit dataset sample ``key_to_use`` according to ``input_text``.

        Args:
            input_text: free-form edit instruction, e.g. "do it slower".
            key_to_use: key into ``self.MFIX_DATASET_DICT``; an unknown key
                raises ``KeyError``.

        Returns:
            Whatever ``body_renderer.get_render`` returns for the edited and
            source motions (presumably the rendered video path -- confirm).
        """
        self.initialize_if_needed()
        ds_sample = self.MFIX_DATASET_DICT[key_to_use]
        data_dict = self.process_features(ds_sample)
        source_motion_norm, target_motion_norm = self.normalize_motions(data_dict)
        # The source is denormalised too so it can be rendered next to the edit.
        source_motion = self.denormalize_motion(source_motion_norm)
        edited_motion = self.generate_edited_motion(
            input_text,
            source_motion_norm,
            target_motion_norm
        )
        return self.render_result(edited_motion, source_motion)

    def process_features(self, ds_sample):
        """Extract every configured feature for the sample's source and target.

        Returns a dict with keys ``'{feat}_source'`` / ``'{feat}_target'``; each
        value gets a new leading axis (``[None]``) and is moved to ``self.device``.
        Must be called after ``initialize_if_needed``.
        """
        from feature_extractor import FEAT_GET_METHODS
        data_dict = {}
        for feat in self.infeats:
            data_dict[f'{feat}_source'] = FEAT_GET_METHODS[feat](
                ds_sample['motion_source']
            )[None].to(self.device)
            data_dict[f'{feat}_target'] = FEAT_GET_METHODS[feat](
                ds_sample['motion_target']
            )[None].to(self.device)
        return data_dict

    def normalize_motions(self, data_dict):
        """Normalise and concatenate the features; return (source, target)."""
        batch = self.normalizer.norm_and_cat(data_dict, self.infeats)
        return batch['source'], batch['target']

    def generate_edited_motion(self, input_text, source_motion, target_motion):
        """Run reverse diffusion conditioned on ``input_text`` and the source.

        ``source_motion`` / ``target_motion`` are normalised tensors whose first
        dimension is used as the sequence length. Only the target's length is
        used (to size the target mask), not its values. Returns the
        denormalised generated motion.
        """
        # Two empty prompts plus the real one; presumably the empty ones feed
        # the unconditional branches of the gd_text / gd_motion guidance.
        texts_cond = [''] * 2 + [input_text]
        text_emb, text_mask = self.text_encoder(texts_cond)

        bsz = 1
        seqlen_src = source_motion.shape[0]
        seqlen_tgt = target_motion.shape[0]
        cond_motion_mask = torch.ones((bsz, seqlen_src), dtype=bool, device=self.device)
        mask_target = torch.ones((bsz, seqlen_tgt), dtype=bool, device=self.device)

        diff_out = self.tmed_denoiser._diffusion_reverse(
            text_emb.to(self.device),
            text_mask.to(self.device),
            source_motion,
            cond_motion_mask,
            mask_target,
            self.diffusion,
            init_vec=None,
            init_from='noise',
            gd_text=2.0,
            gd_motion=2.0,
            steps_num=300
        )
        return self.denormalize_motion(diff_out)

    def denormalize_motion(self, diff_out):
        """Map a normalised 3-D motion tensor back to motion space.

        The first two axes are swapped before ``diffout2motion``; the input is
        assumed to be sequence-first (TODO confirm). Singleton axes of the
        result are squeezed away.
        """
        from geometry_utils import diffout2motion
        return diffout2motion(diff_out.permute(1, 0, 2), self.normalizer).squeeze()

    def render_result(self, edited_motion, source_motion):
        """Render the edited (sky blue) and source (red) motions into one video.

        Returns whatever ``get_render`` returns for ``./output_movie.mp4``.
        """
        from body_renderer import get_render

        edited_motion_transformed = self.transform_motion(edited_motion)
        source_motion_transformed = self.transform_motion(source_motion)

        output_path = './output_movie.mp4'
        # Drop a stale video from an earlier request first (EAFP avoids the
        # exists/remove race). NOTE(review): the path is shared by all
        # requests, so concurrent edits would overwrite each other's output.
        try:
            os.remove(output_path)
        except FileNotFoundError:
            pass
        return get_render(
            self.body_model,
            [edited_motion_transformed['trans'], source_motion_transformed['trans']],
            [edited_motion_transformed['rots_init'], source_motion_transformed['rots_init']],
            [edited_motion_transformed['rots_rest'], source_motion_transformed['rots_rest']],
            output_path=output_path,
            text='',
            colors=['sky blue', 'red']
        )

    def transform_motion(self, motion):
        """Convert a motion tensor into the inputs ``get_render`` expects.

        Columns ``[:3]`` of the last axis are read as translation and the rest as
        6D rotations (``'6d->aa'``). Both are moved to CPU and rotated by
        ``rotate_body_degrees(..., offset=np.pi)``. The rotations are returned
        in axis-angle form, split at index 0 of dim 1: ``rots_init`` is that
        first entry (presumably the root orientation) and ``rots_rest`` is the
        remainder.
        """
        from transform3d import transform_body_pose, rotate_body_degrees
        motion_aa = transform_body_pose(motion[:, 3:], '6d->aa')
        trans = motion[..., :3].detach().cpu()
        rots_aa = motion_aa.detach().cpu()
        rots_rotated, trans_rotated = rotate_body_degrees(
            transform_body_pose(rots_aa, 'aa->rot'),
            trans,
            offset=np.pi
        )
        rots_rotated_aa = transform_body_pose(rots_rotated, 'rot->aa')
        return {
            'trans': trans_rotated,
            'rots_init': rots_rotated_aa[:, 0],
            'rots_rest': rots_rotated_aa[:, 1:]
        }
# Gradio Interface
def create_gradio_interface():
    """Build (but do not launch) the MotionFix editing demo.

    The left column selects a source motion, either a random MotionFix sample
    from the chosen split or one of the hard-coded examples. The right column
    takes an edit instruction and shows the rendered result.

    Returns:
        The assembled ``gr.Blocks`` app.
    """
    editor = MotionEditor()
    # Filled on the first "Random" click with the (train, test) splits, so the
    # dataset is not re-read from disk on every click.
    dataset_cache = {}

    @spaces.GPU(duration=360)
    def _process_on_gpu(input_text, key):
        # NOTE(review): this is the outermost spaces.GPU call for an edit, and
        # nested GPU-decorated calls presumably run inside its allocation. It
        # therefore requests the same 360 s as MotionEditor.process_motion
        # instead of the shorter default -- confirm against the ZeroGPU docs.
        return editor.process_motion(input_text, key)

    def process_and_show_video(input_text, random_key_state):
        """Edit the currently selected source motion with ``input_text``."""
        # The key state stays None until "Random" or an example is used. Fail
        # fast in the UI instead of with a KeyError inside the GPU task.
        if random_key_state is None:
            raise gr.Error("Please pick a source motion first (Random or an example).")
        return _process_on_gpu(input_text, random_key_state)

    def _load_splits():
        """Return the (train, test) MotionFix dicts, loading them only once."""
        if 'splits' not in dataset_cache:
            from dataset_utils import load_motionfix
            dataset_cache['splits'] = load_motionfix(editor.MFIX_p)
        return dataset_cache['splits']

    def random_source_motion(set_to_pick):
        """Pick a random sample from 'all', 'train' or 'test'.

        Returns updates for: the retrieved video (made visible), the suggested
        edit text, the key state, and the input text (both texts are set to
        the sample's annotation). An unknown split name raises ``KeyError``.
        """
        mfix_train, mfix_test = _load_splits()
        if set_to_pick == 'all':
            current_set = mfix_test | mfix_train
        else:
            current_set = {'train': mfix_train, 'test': mfix_test}[set_to_pick]
        random_key = random.choice(list(current_set))
        sample = current_set[random_key]
        text_annot = sample['annotation']
        return (gr.update(value=sample['motion_a'], visible=True),
                text_annot, random_key, text_annot)

    def clear():
        """Return an empty string, used to reset a textbox."""
        return ""

    def load_example(example_video, example_key, example_text):
        """Select a hard-coded example; same four outputs as random_source_motion."""
        return (
            gr.update(value=example_video, visible=True),  # retrieved video
            example_text,  # suggested edit text
            example_key,   # key state
            example_text,  # input text
        )

    def _bind_example(video, key, text):
        # A separate function scope pins these values for each button; a lambda
        # written directly inside the wiring loop would late-bind them.
        return lambda: load_example(video, key, text)

    with gr.Blocks(css=CUSTOM_CSS) as demo:
        gr.HTML(WEBSITE)
        random_key_state = gr.State()
        with gr.Row():
            with gr.Column(scale=5):
                gr.HTML(WEB_source)
                with gr.Row():
                    random_button = gr.Button("Random", scale=0)
                    clear_button_retrieval = gr.Button("Clear", scale=0)
                suggested_edit_text = gr.Textbox(
                    placeholder="Texts likely to edit the motion:",
                    label="Suggested Edit Text",
                    value=''
                )
                set_to_pick = gr.Radio(
                    ['all', 'train', 'test'],
                    value='all',
                    label="Set to pick from"
                )
                retrieved_video_output = gr.Video(
                    label="Retrieved Motion",
                    height=360,
                    width=480,
                    visible=False  # shown once a source motion is selected
                )
                # Example grid: two preview players, each with a select button, per row.
                # All select buttons share size='sm' so the grid looks uniform.
                gr.Markdown("### Examples")
                example_buttons = []
                for row_start in range(0, len(example_videos), 2):
                    with gr.Row():
                        row_videos = example_videos[row_start:row_start + 2]
                        for idx, preview_path in enumerate(row_videos, start=row_start):
                            with gr.Column():
                                gr.Video(value=preview_path,
                                         height=180, width=240,
                                         label=f"Example {idx + 1}")
                                example_buttons.append(gr.Button(
                                    f"Select Example {idx + 1}",
                                    size='sm',
                                    elem_classes=["fit-text"]
                                ))
            with gr.Column(scale=5):
                gr.HTML(WEB_target)
                with gr.Row():
                    clear_button_edit = gr.Button("Clear", scale=0)
                    edit_button = gr.Button("Edit", scale=0)
                input_text = gr.Textbox(
                    placeholder="Type the edit text you want:",
                    label="Input Text",
                    value=DEFAULT_TEXT
                )
                video_output = gr.Video(
                    label="Generated Video",
                    height=360,
                    width=480
                )

        # Both ways of choosing a source motion update the same four outputs.
        selection_outputs = [
            retrieved_video_output,
            suggested_edit_text,
            random_key_state,
            input_text
        ]

        edit_button.click(
            process_and_show_video,
            inputs=[input_text, random_key_state],
            outputs=video_output
        )
        random_button.click(
            random_source_motion,
            inputs=set_to_pick,
            outputs=selection_outputs
        )
        for button, ex_video, ex_key, ex_text in zip(
                example_buttons, example_videos, example_keys, example_texts):
            button.click(
                fn=_bind_example(ex_video, ex_key, ex_text),
                inputs=None,
                outputs=selection_outputs
            )
        clear_button_edit.click(clear, outputs=input_text)
        clear_button_retrieval.click(clear, outputs=suggested_edit_text)
        gr.Markdown(CREDITS)
    return demo
# Page-level CSS handed to gr.Blocks(css=CUSTOM_CSS) in create_gradio_interface.
# Defining it after that function is fine: the name is looked up when the
# function runs (from the __main__ block below), not when it is defined.
# The `button.fit-text` rule is intended for the example buttons created with
# elem_classes=["fit-text"].
CUSTOM_CSS = """
.gradio-row { display: flex; gap: 20px; }
.gradio-column { flex: 1; }
.gradio-container { display: flex; flex-direction: column; gap: 10px; }
.gradio-button-row { display: flex; gap: 10px; }
.gradio-textbox-row { display: flex; gap: 10px; align-items: center; }
.gradio-edit-row { gap: 10px; align-items: center; }
.gradio-textbox-with-button { display: flex; align-items: center; }
.gradio-textbox-with-button input { flex-grow: 1; }
button.fit-text {
width: auto; /* Automatically adjusts to the text length */
padding: 10px 20px; /* Adjust padding for a better look */
font-size: 14px; /* Control font size */
text-align: center; /* Center the text */
margin: 0 auto; /* Center the button horizontally */
display: inline-block; /* Prevent it from stretching */
}
"""
if __name__ == "__main__":
    # Script entry point: assemble the Blocks app, then serve it. The options
    # are forwarded unchanged to Gradio's launch().
    demo = create_gradio_interface()
    launch_options = {"share": True}
    demo.launch(**launch_options)