atnikos committed on
Commit 517683d · 1 Parent(s): 11d93b7

basic functionalities working

.gitignore CHANGED
@@ -1,3 +1,5 @@
+.cursorignore
+#*.mp4
 .err
 *.out
 /cluster_scripts
app.py CHANGED
@@ -5,37 +5,27 @@ import torch
 import random
 import os
 from pathlib import Path
-from aitviewer.headless import HeadlessRenderer
-from aitviewer.configuration import CONFIG as AITVIEWER_CONFIG
+import smplx
+from body_renderer import get_render
+import joblib
 # import cv2
 # import moderngl
 # ctx = moderngl.create_context(standalone=True)
 # print(ctx)
 access_token_smpl = os.environ.get('HF_SMPL_TOKEN')
+os.environ["PYOPENGL_PLATFORM"] = "egl"
 
 zero = torch.Tensor([0]).cuda()
 print(zero.device) # <-- 'cuda:0' 🤗
 
-DEFAULT_TEXT = "A person is "
+DEFAULT_TEXT = "do it slower "
 
-from aitviewer.models.smpl import SMPLLayer
 def get_smpl_models():
     REPO_ID = 'athn-nik/smpl_models'
     from huggingface_hub import snapshot_download
     return snapshot_download(repo_id=REPO_ID, allow_patterns="smplh*",
                              token=access_token_smpl)
 
-def get_renderer():
-    from aitviewer.headless import HeadlessRenderer
-    from aitviewer.configuration import CONFIG as AITVIEWER_CONFIG
-    smpl_models_path = str(Path(get_smpl_models()))
-    AITVIEWER_CONFIG.update_conf({'playback_fps': 30,
-                                  'auto_set_floor': True,
-                                  'smplx_models': smpl_models_path,
-                                  'z_up': True})
-    return HeadlessRenderer()
-
-
 WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
 <h1>MotionFix: Text-Driven 3D Human Motion Editing</h1>
 <h3>
@@ -53,13 +43,19 @@ WEBSITE = ("""<div class="embed_hidden" style="text-align: center;">
 </h3>
 </div>
 <div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
-<a href='https://arxiv.org/abs/'><img src='https://img.shields.io/badge/Arxiv-2405.20340-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
-<a href='https://arxiv.org/pdf/'><img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'></a>
+<a href='https://arxiv.org/pdf/2408.00712'><img src='https://img.shields.io/badge/Arxiv-2405.20340-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
 <a href='https://motionfix.is.tue.mpg.de'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
-<a href='https://youtube.com/'><img src='https://img.shields.io/badge/YouTube-red?style=flat&logo=youtube&logoColor=white'></a>
+<a href='https://www.youtube.com/watch?v=cFa6V6Ua-TY'><img src='https://img.shields.io/badge/YouTube-red?style=flat&logo=youtube&logoColor=white'></a>
 </div>
 """)
 
+CREDITS = ("""<div class="embed_hidden" style="text-align: center;">
+<h3>
+The renderer of this demo is adapted from the render of
+<a href="https://geometry.stanford.edu/projects/humor/" target="_blank" rel="noopener noreferrer">HuMoR</a>
+with the help of <a href="https://ps.is.mpg.de/person/trakshit" target="_blank" rel="noopener noreferrer">Tithi Rakshit</a> :)
+</h3>
+""")
 
 WEB_source = ("""<div class="embed_hidden" style="text-align: center;">
 <h1>Pick a motion to edit!</h1>
@@ -90,7 +86,7 @@ def greet(n):
 def clear():
     return ""
 
-def show_video(input_text):
+def show_video(input_text, key_to_use):
     from normalization import Normalizer
     normalizer = Normalizer()
     from diffusion import create_diffusion
@@ -98,7 +94,12 @@ def show_video(input_text):
     from tmed_denoiser import TMED_denoiser
     model_ckpt = download_models()
     checkpoint = torch.load(model_ckpt)
-
+    motion_to_edit = download_motion_from_dataset(key_to_use)
+    ds_sample = joblib.load(motion_to_edit)
+    source_motion_norm = ds_sample['source_feats_norm'].to('cuda')
+    seqlen_tgt = ds_sample['target_feats_norm'].shape[0]
+    seqlen_src = ds_sample['source_feats_norm'].shape[0]
+    # import ipdb; ipdb.set_trace()
     checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
     tmed_denoiser = TMED_denoiser().to('cuda')
     tmed_denoiser.load_state_dict(checkpoint, strict=False)
@@ -112,20 +113,16 @@ def show_video(input_text):
                                          noise_schedule='squaredcos_cap_v2',
                                          predict_xstart=True)
     bsz = 1
-    seqlen_tgt = 180
     no_of_texts = len(texts_cond)
    texts_cond = ['']*no_of_texts + texts_cond
     texts_cond = ['']*no_of_texts + texts_cond
     text_emb, text_mask = text_encoder(texts_cond)
 
-    cond_emb_motion = torch.zeros(seqlen_tgt, bsz,
-                                  512,
-                                  device='cuda')
-    cond_motion_mask = torch.ones((bsz, seqlen_tgt),
+    cond_emb_motion = source_motion_norm.unsqueeze(0).permute(1, 0, 2)
+    cond_motion_mask = torch.ones((bsz, seqlen_src),
                                   dtype=bool, device='cuda')
     mask_target = torch.ones((bsz, seqlen_tgt),
                              dtype=bool, device='cuda')
-
     diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
                                                 text_mask.to(cond_emb_motion.device),
                                                 cond_emb_motion,
@@ -134,31 +131,68 @@ def show_video(input_text):
                                                 diffusion_process,
                                                 init_vec=None,
                                                 init_from='noise',
-                                                gd_text=4.0,
+                                                gd_text=2.0,
                                                 gd_motion=2.0,
                                                 steps_num=300)
     edited_motion = diffout2motion(diff_out, normalizer).squeeze()
-    from renderer import render_motion, color_map, pack_to_render
+    import ipdb; ipdb.set_trace()
     # aitrenderer = get_renderer()
-    AIT_RENDERER = get_renderer()
-    SMPL_LAYER = SMPLLayer(model_type='smplh', ext='npz', gender='neutral')
-    edited_mot_to_render = pack_to_render(rots=edited_motion[..., 3:],
-                                          trans=edited_motion[..., :3])
+    # SMPL_LAYER = SMPLLayer(model_type='smplh', ext='npz', gender='neutral')
+    # edited_mot_to_render = pack_to_render(rots=edited_motion[..., 3:],
+    #                                       trans=edited_motion[..., :3])
+
+    SMPL_MODELS_PATH = str(Path(get_smpl_models()))
+    body_model = smplx.SMPLHLayer(f"{SMPL_MODELS_PATH}/smplh",
+                                  model_type='smplh',
+                                  gender='neutral', ext='npz')
+
+    # run_smpl_fwd_verticesbody_model, body_transl, body_orient, body_pose,
     import random
     xx = random.randint(1, 1000)
-    fname = render_motion(AIT_RENDERER, [edited_mot_to_render],
-                          f"movie_example--{str(xx)}",
-                          pose_repr='aa',
-                          color=[color_map['generated']],
-                          smpl_layer=SMPL_LAYER)
+    # edited_mot_to_render
+    from body_renderer import get_render
+    from transform3d import transform_body_pose
+    edited_motion_aa = transform_body_pose(edited_motion[:, 3:],
+                                           '6d->aa')
+    if os.path.exists('./output_movie.mp4'):
+        os.remove('./output_movie.mp4')
+    fname = get_render(body_model,
+                       [edited_motion[..., :3].detach().cpu()],
+                       [edited_motion_aa[..., :3].detach().cpu()],
+                       [edited_motion_aa[..., 3:].detach().cpu()],
+                       output_path='./output_movie.mp4',
+                       text='', colors=['sky blue'])
+
+    # import ipdb; ipdb.set_trace()
+
+
+
+    # fname = render_motion(AIT_RENDERER, [edited_mot_to_render],
+    #                       f"movie_example--{str(xx)}",
+    #                       pose_repr='aa',
+    #                       color=[color_map['generated']],
+    #                       smpl_layer=SMPL_LAYER)
+    print(fname)
+    print(os.path.abspath(fname))
     return fname
 
-from huggingface_hub import hf_hub_download, hf_hub_url, cached_download
+from huggingface_hub import hf_hub_download
 
 def download_models():
     REPO_ID = 'athn-nik/example-model'
     return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")
 
+def download_motion_from_dataset(key_to_dl):
+    REPO_ID = 'athn-nik/example-model'
+    from huggingface_hub import snapshot_download
+    keytodl = key_to_dl
+    keytodl = '000008'
+    path_for_ds = snapshot_download(repo_id=REPO_ID,
+                                    allow_patterns=f"dataset_inputs/{keytodl}",
+                                    token=access_token_smpl)
+    path_for_ds_sample = path_for_ds + f'/dataset_inputs/{keytodl}.pth.tar'
+    return path_for_ds_sample
+
 def download_tmr():
     REPO_ID = 'athn-nik/example-model'
     # return hf_hub_download(REPO_ID, filename="min_checkpoint.ckpt")
@@ -202,7 +236,7 @@ def random_source_motion(set_to_pick):
     random_key = random.choice(list(current_set.keys()))
     curvid = current_set[random_key]['motion_a']
     text_annot = current_set[random_key]['annotation']
-    return curvid, text_annot
+    return curvid, text_annot, random_key, text_annot
 
 def retrieve_video(retrieve_text):
     tmr_text_encoder = get_tmr_model(download_tmr())
@@ -227,6 +261,7 @@ def retrieve_video(retrieve_text):
     text_annot = all_mots[top_mot]['annotation']
     return curvid, text_annot
 
+
 with gr.Blocks(css="""
     .gradio-row {
         display: flex;
@@ -262,20 +297,21 @@ with gr.Blocks(css="""
     }
     """) as demo:
     gr.Markdown(WEBSITE)
-
+    random_key_state = gr.State()
+
     with gr.Row(elem_id="gradio-row"):
-        with gr.Column(scale=7, elem_id="gradio-column"):
+        with gr.Column(scale=5, elem_id="gradio-column"):
            gr.Markdown(WEB_source)
            with gr.Row(elem_id="gradio-button-row"):
-                iterative_button = gr.Button("Iterative")
-                retrieve_button = gr.Button("TMRetrieve")
+                # iterative_button = gr.Button("Iterative")
+                # retrieve_button = gr.Button("TMRetrieve")
                random_button = gr.Button("Random")
 
            with gr.Row(elem_id="gradio-textbox-row"):
                with gr.Column(scale=5, elem_id="gradio-textbox-with-button"):
-                    retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to Retrieve:",
-                                               show_label=True, label="Retrieval Text",
-                                               value=DEFAULT_TEXT)
+                    # retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to Retrieve:",
+                    #                            show_label=True, label="Retrieval Text",
+                    #                            value=DEFAULT_TEXT)
                    clear_button_retrieval = gr.Button("Clear", scale=0)
 
            with gr.Row(elem_id="gradio-textbox-row"):
@@ -288,10 +324,11 @@ with gr.Blocks(css="""
                                      value='all',
                                      label="Set to pick from",
                                      info="Motion will be picked from whole dataset or test or train data.")
-
+            # import ipdb; ipdb.set_trace()
            retrieved_video_output = gr.Video(label="Retrieved Motion",
-                                              value=xxx,
+                                              # value=xxx,
                                              height=360, width=480)
+
 
        with gr.Column(scale=5, elem_id="gradio-column"):
            gr.Markdown(WEB_target)
@@ -307,8 +344,8 @@ with gr.Blocks(css="""
            video_output = gr.Video(label="Generated Video", height=360,
                                    width=480)
 
    def process_and_show_video(input_text):
-    def process_and_show_video(input_text):
-        fname = show_video(input_text)
+    def process_and_show_video(input_text, random_key_state):
+        fname = show_video(input_text, random_key_state)
        return fname
 
    def process_and_retrieve_video(input_text):
@@ -318,10 +355,15 @@ with gr.Blocks(css="""
    from retrieval_loader import get_tmr_model
    from dataset_utils import load_motionfix
 
-    edit_button.click(process_and_show_video, inputs=input_text, outputs=video_output)
-    retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=[retrieved_video_output, suggested_edit_text])
-    random_button.click(random_source_motion, inputs=set_to_pick, outputs=[retrieved_video_output, suggested_edit_text])
+    edit_button.click(process_and_show_video, inputs=[input_text, random_key_state], outputs=video_output)
+    # retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text, outputs=[retrieved_video_output, suggested_edit_text])
+    random_button.click(random_source_motion, inputs=set_to_pick,
+                        outputs=[retrieved_video_output,
+                                 suggested_edit_text,
+                                 random_key_state,
+                                 input_text])
    clear_button_edit.click(clear, outputs=input_text)
-    clear_button_retrieval.click(clear, outputs=retrieve_text)
+    # clear_button_retrieval.click(clear, outputs=retrieve_text)
+    gr.Markdown(CREDITS)
 
-demo.launch()
+demo.launch(share=True)
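Note (not part of this commit): a minimal sketch of the gr.State hand-off that the hunks above introduce. The pick/edit functions are hypothetical stand-ins for random_source_motion and show_video; only standard Gradio calls are used, and the video paths are placeholders.

import gradio as gr

def pick_random():
    key = "000008"                      # hypothetical sample key
    return f"{key}.mp4", key            # placeholder video path + key stored in the state

def edit(text, key):
    return f"edited_{key}.mp4"          # placeholder output video path

with gr.Blocks() as demo:
    key_state = gr.State()              # carries the picked sample key between clicks
    video = gr.Video()
    text = gr.Textbox(label="Edit text")
    pick_btn = gr.Button("Random")
    edit_btn = gr.Button("Edit")
    pick_btn.click(pick_random, outputs=[video, key_state])
    edit_btn.click(edit, inputs=[text, key_state], outputs=video)

demo.launch()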
body_renderer.py ADDED
@@ -0,0 +1,40 @@
+def get_render(body_model_loaded,
+               body_trans,
+               body_orient, body_pose,
+               output_path, text='',
+               colors=[]):
+    from renderer.utils import run_smpl_fwd_vertices
+
+    vertices_list = []
+    if not isinstance(body_trans, list):
+        body_trans = [body_trans]
+    if not isinstance(body_orient, list):
+        body_orient = [body_orient]
+    if not isinstance(body_pose, list):
+        body_pose = [body_pose]
+
+    for trans, orient, pose in zip(body_trans, body_orient, body_pose):
+
+        vertices = run_smpl_fwd_vertices(body_model_loaded,
+                                         trans,
+                                         orient,
+                                         pose)
+
+        vertices = vertices.vertices
+        vertices = vertices.detach().cpu().numpy()
+        vertices_list.append(vertices)
+
+    # Initialising the renderer
+    from renderer.humor import HumorRenderer
+    fps = 30.0
+    imw = 720  # 480
+    imh = 720  # 360
+    renderer = HumorRenderer(fps=fps, imw=imw, imh=imh)
+
+    if len(vertices_list) == 2:
+        renderer(vertices_list, output_path, render_pair=True,
+                 fps=fps, colors=colors)
+    else:
+        renderer(vertices_list[0], output_path, render_pair=False,
+                 fps=fps, colors=colors)
+    return output_path
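Note (not part of this commit): get_render also has a pair path — when two motions are passed as lists it overlays them via render_pair=True. A hedged sketch with assumed shapes (T x 3 translation, T x 3 global orient, T x 63 body pose, axis-angle) and an assumed local SMPL-H model path; it relies on renderer.utils.run_smpl_fwd_vertices existing in the repo.

import torch
import smplx
from body_renderer import get_render

# assumed local SMPL-H model folder; zero tensors are placeholders for real motions
body_model = smplx.SMPLHLayer("deps/smplh", model_type='smplh',
                              gender='neutral', ext='npz')
T = 120                                                        # assumed sequence length
src_transl, edit_transl = torch.zeros(T, 3), torch.zeros(T, 3)
src_orient, edit_orient = torch.zeros(T, 3), torch.zeros(T, 3)
src_pose, edit_pose = torch.zeros(T, 63), torch.zeros(T, 63)

fname = get_render(body_model,
                   [src_transl, edit_transl],
                   [src_orient, edit_orient],
                   [src_pose, edit_pose],
                   output_path='./pair_movie.mp4',
                   text='',
                   colors=['grey', 'sky blue'])   # color names defined in renderer/humor_render_tools/parameters.py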
renderer.py DELETED
@@ -1,138 +0,0 @@
-import os
-import torch
-from transform3d import transform_body_pose
-from aitviewer.headless import HeadlessRenderer
-from gen_utils import rgb, rgba
-
-color_map = {
-    'source_motion': rgba('darkred'),
-    'source': rgba('darkred'),
-    'target_motion': rgba('olivedrab'),
-    'input': rgba('olivedrab'),
-    'target': rgba('olivedrab'),
-    'generation': rgba('purple'),
-    'generated': rgba('steelblue'),
-    'denoised': rgba('purple'),
-    'noised': rgba('darkgrey'),
-}
-
-
-def pack_to_render(rots, trans, pose_repr='6d'):
-    # make axis-angle
-    # global_orient = transform_body_pose(rots, f"{pose_repr}->aa")
-
-    if rots.is_cuda:
-        rots = rots.detach().cpu()
-    if trans.is_cuda:
-        trans = trans.detach().cpu()
-
-    if pose_repr != 'aa':
-        body_pose = transform_body_pose(rots, f"{pose_repr}->aa")
-    else:
-        body_pose = rots
-    if trans is None:
-        trans = torch.zeros((rots.shape[0], rots.shape[1], 3),
-                            device=rots.device)
-    render_d = {'body_transl': trans,
-                'body_orient': body_pose[..., :3],
-                'body_pose': body_pose[..., 3:]}
-    return render_d
-
-
-def render_motion(renderer: HeadlessRenderer, datum: dict,
-                  filename: str, pose_repr='6d',
-                  color=(160 / 255, 160 / 255, 160 / 255, 1.0),
-                  return_verts=False, smpl_layer=None) -> None:
-    """
-    Function to render a video of a motion sequence
-    renderer: aitviewer renderer
-    datum: dictionary containing sequence of poses, body translations and body orientations
-        data could be numpy or pytorch tensors
-    filename: the absolute path you want the video to be saved at
-
-    """
-    from aitviewer.headless import HeadlessRenderer
-    from aitviewer.renderables.smpl import SMPLSequence
-
-    if isinstance(datum, dict): datum = [datum]
-    if not isinstance(color, list):
-        colors = [color]
-    else:
-        colors = color
-    # assert {'body_transl', 'body_orient', 'body_pose'}.issubset(set(datum[0].keys()))
-    # os.environ['DISPLAY'] = ":11"
-    gender = 'neutral'
-    only_skel = False
-    import sys
-    seqs_of_human_motions = []
-    if smpl_layer is None:
-        from aitviewer.models.smpl import SMPLLayer
-        smpl_layer = SMPLLayer(model_type='smplh',
-                               ext='npz',
-                               gender=gender)
-
-    for iid, mesh_seq in enumerate(datum):
-
-        if pose_repr != 'aa':
-            global_orient = transform_body_pose(mesh_seq['body_orient'],
-                                                f"{pose_repr}->aa")
-            body_pose = transform_body_pose(mesh_seq['body_pose'],
-                                            f"{pose_repr}->aa")
-        else:
-            global_orient = mesh_seq['body_orient']
-            body_pose = mesh_seq['body_pose']
-
-        body_transl = mesh_seq['body_transl']
-        sys.stdout.flush()
-
-        old = os.dup(1)
-        os.close(1)
-        os.open(os.devnull, os.O_WRONLY)
-        print(body_pose.shape)
-        print('\n')
-        smpl_template = SMPLSequence(body_pose,
-                                     smpl_layer,
-                                     poses_root=global_orient,
-                                     trans=body_transl,
-                                     color=colors[iid],
-                                     z_up=True)
-        if only_skel:
-            smpl_template.remove(smpl_template.mesh_seq)
-
-        seqs_of_human_motions.append(smpl_template)
-        renderer.scene.add(smpl_template)
-    # camera follows smpl sequence
-    # FIX CAMERA
-    from transform3d import get_z_rot
-    R_z = get_z_rot(global_orient[0], in_format='aa')
-    heading = -R_z[:, 1]
-    xy_facing = body_transl[0] + heading*2.5
-    camera = renderer.lock_to_node(seqs_of_human_motions[0],
-                                   (xy_facing[0], xy_facing[1], 1.5), smooth_sigma=5.0)
-
-    # /FIX CAMERA
-    if len(mesh_seq['body_pose']) == 1:
-        renderer.save_frame(file_path=str(filename) + '.png')
-        sfx = 'png'
-    else:
-        renderer.save_video(video_dir=str(filename), output_fps=30)
-        sfx = 'mp4'
-
-    # aitviewer adds a counter to the filename, we remove it
-    # filename.split('_')[-1].replace('.mp4', '')
-    # os.rename(filename + '_0.mp4', filename[:-4] + '.mp4')
-    if sfx == 'mp4':
-        os.rename(str(filename) + f'_0.{sfx}', str(filename) + f'.{sfx}')
-
-    # empty scene for the next rendering
-    for mesh in seqs_of_human_motions:
-        renderer.scene.remove(mesh)
-    renderer.scene.remove(camera)
-
-    sys.stdout.flush()
-    os.close(1)
-    os.dup(old)
-    os.close(old)
-    renderer.reset()
-    fname = f'{filename}.{sfx}'
-    return fname
renderer/__init__.py ADDED
File without changes
renderer/humor.py ADDED
@@ -0,0 +1,161 @@
+import numpy as np
+import torch
+from .humor_render_tools.tools import viz_smpl_seq
+from smplx.utils import Struct
+from .video import Video
+import os
+from multiprocessing import Pool
+from tqdm import tqdm
+from multiprocessing import Process
+
+THIS_FOLDER = os.path.dirname(os.path.abspath(__file__))
+FACE_PATH = os.path.join(THIS_FOLDER, "humor_render_tools/smplh.faces")
+FACES = torch.from_numpy(np.int32(np.load(FACE_PATH)))
+
+
+class HumorRenderer:
+    def __init__(self, fps=20.0, **kwargs):
+        self.kwargs = kwargs
+        self.fps = fps
+
+    def __call__(self, vertices, output, render_pair=False, text=None, colors=[], **kwargs):
+        params = self.kwargs | kwargs
+        fps = self.fps
+        if "fps" in params:
+            fps = params.pop("fps")
+
+        if render_pair:
+            render_overlaid(vertices, output, fps, colors, **params)
+        else:
+            render(vertices, output, fps, colors, **params)
+
+        fname = f'{output}.mp4'
+        return fname
+
+
+def render_overlaid(vertices, out_path, fps, colors=[], progress_bar=tqdm, **kwargs):
+    assert isinstance(vertices, list) and len(vertices) == 2
+    # Put the vertices at the floor level
+    # F X N x 3 ===> [F X N x 3, F X N x 3]
+    ground = vertices[0][..., 2].min()
+    vertices[0][..., 2] -= ground
+    vertices[1][..., 2] -= ground
+    verts0 = vertices[0]
+    verts1 = vertices[1]
+
+    import pyrender
+
+    # remove title if it exists
+    kwargs.pop("title", None)
+
+    # vertices: SMPL-H vertices
+    # verts = np.load("interval_2_verts.npy")
+    out_folder = os.path.splitext(out_path)[0]
+
+    verts0 = torch.from_numpy(verts0)
+    body_pred0 = Struct(v=verts0, f=FACES)
+    verts1 = torch.from_numpy(verts1)
+    body_pred1 = Struct(v=verts1, f=FACES)
+
+
+    # out_folder, body_pred, start, end, fps, kwargs = args
+    viz_smpl_seq(
+        pyrender, out_folder, [body_pred0, body_pred1], fps=fps, progress_bar=progress_bar, vertex_color_list=colors, **kwargs
+    )
+
+    video = Video(out_folder, fps=fps)
+    video.save(out_path)
+
+def render(vertices, out_path, fps, colors=[], progress_bar=tqdm, **kwargs):
+    # Put the vertices at the floor level
+    ground = vertices[..., 2].min()
+    vertices[..., 2] -= ground
+
+    import pyrender
+
+    # remove title if it exists
+    kwargs.pop("title", None)
+
+    # vertices: SMPL-H vertices
+    # verts = np.load("interval_2_verts.npy")
+    out_folder = os.path.splitext(out_path)[0]
+
+    verts = torch.from_numpy(vertices)
+    body_pred = Struct(v=verts, f=FACES)
+
+    # out_folder, body_pred, start, end, fps, kwargs = args
+    viz_smpl_seq(
+        pyrender, out_folder, body_pred, fps=fps, progress_bar=progress_bar, vertex_color=colors, **kwargs
+    )
+
+    video = Video(out_folder, fps=fps)
+    video.save(out_path)
+    import shutil
+    shutil.rmtree(out_folder)
+
+def render_offset(args):
+    import pyrender
+
+    out_folder, body_pred, start, end, fps, kwargs = args
+    viz_smpl_seq(
+        pyrender, out_folder, body_pred, start=start, end=end, fps=fps, **kwargs
+    )
+    return 0
+
+
+def render_multiprocess(vertices, out_path, fps, **kwargs):
+    # WIP: does not work yet
+    import ipdb
+
+    ipdb.set_trace()
+    # remove title if it exists
+    kwargs.pop("title", None)
+
+    # vertices: SMPL-H vertices
+    # verts = np.load("interval_2_verts.npy")
+    out_folder = os.path.splitext(out_path)[0]
+
+    verts = torch.from_numpy(vertices)
+    body_pred = Struct(v=verts, f=FACES)
+
+    # faster rendering
+    # by rendering part of the sequence in parallel
+    # still work in progress, use one process for now
+    n_processes = 1
+
+    verts_lst = np.array_split(verts, n_processes)
+    len_split = [len(x) for x in verts_lst]
+    starts = [0] + np.cumsum([x for x in len_split[:-1]]).tolist()
+    ends = np.cumsum([x for x in len_split]).tolist()
+    out_folders = [out_folder for _ in range(n_processes)]
+    fps_s = [fps for _ in range(n_processes)]
+    kwargs_s = [kwargs for _ in range(n_processes)]
+    body_pred_s = [body_pred for _ in range(n_processes)]
+
+    arguments = [out_folders, body_pred_s, starts, ends, fps_s, kwargs_s]
+    # sanity
+    # lst = [verts[start:end] for start, end in zip(starts, ends)]
+    # assert (torch.cat(lst) == verts).all()
+
+    processes = []
+    for _, args in zip(range(n_processes), zip(*arguments)):
+        process = Process(target=render_offset, args=(args,))
+        process.start()
+        processes.append(process)
+
+    for process in processes:
+        process.join()
+
+    if False:
+        # start 4 worker processes
+        with Pool(processes=n_processes) as pool:
+            # print "[0, 1, 4,..., 81]"
+            # print same numbers in arbitrary order
+            print(f"0/{n_processes} rendered")
+            i = 0
+            for _ in pool.imap_unordered(render_offset, zip(*arguments)):
+                i += 1
+                print(f"i/{n_processes} rendered")
+
+    video = Video(out_folder, fps=fps)
+    video.save(out_path)
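Note (not part of this commit): a hedged usage sketch of HumorRenderer on its own. It expects an F x N x 3 numpy array of SMPL-H vertices (N = 6890) for a single clip, or a list of two such arrays with render_pair=True for an overlaid comparison; the zero array below is only a placeholder to illustrate the call shape.

import numpy as np
from renderer.humor import HumorRenderer

vertices = np.zeros((60, 6890, 3), dtype=np.float32)   # placeholder 2-second clip at 30 fps
renderer = HumorRenderer(fps=30.0, imw=720, imh=720)
renderer(vertices, "demo_motion.mp4", render_pair=False, colors=["sky blue"])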
renderer/humor_render_tools/mesh_viewer.py ADDED
@@ -0,0 +1,837 @@
1
+ # adapted from HuMoR
2
+ import os
3
+ import time
4
+
5
+ # import math
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import trimesh
9
+
10
+ # import pyrender
11
+ import sys
12
+ import cv2
13
+
14
+ from .parameters import colors
15
+
16
+ # import pyglet
17
+
18
+ __all__ = ["MeshViewer"]
19
+
20
+ COMPRESS_PARAMS = [cv2.IMWRITE_PNG_COMPRESSION, 9]
21
+
22
+
23
+ def pause_play_callback(pyrender_viewer, mesh_viewer):
24
+ mesh_viewer.is_paused = not mesh_viewer.is_paused
25
+
26
+
27
+ def step_callback(pyrender_viewer, mesh_viewer, step_size):
28
+ mesh_viewer.animation_frame_idx = (
29
+ mesh_viewer.animation_frame_idx + step_size
30
+ ) % mesh_viewer.animation_len
31
+
32
+
33
+ class MeshViewer(object):
34
+ def __init__(
35
+ self,
36
+ pyrender,
37
+ width=1200,
38
+ height=800,
39
+ use_offscreen=False,
40
+ follow_camera=False,
41
+ camera_intrinsics=None,
42
+ img_extn="png",
43
+ default_cam_offset=[0.0, 4.0, 1.25],
44
+ default_cam_rot=None,
45
+ ):
46
+ super().__init__()
47
+
48
+ self.pyrender = pyrender
49
+ self.use_offscreen = use_offscreen
50
+ self.follow_camera = follow_camera
51
+ # render settings for offscreen
52
+ self.render_wireframe = False
53
+ self.render_RGBA = False
54
+ self.render_path = "./render_out"
55
+ self.img_extn = img_extn
56
+
57
+ # mesh sequences to animate
58
+ self.animated_seqs = [] # the actual sequence of pyrender meshes
59
+ self.animated_seqs_type = []
60
+ self.animated_nodes = [] # the nodes corresponding to each sequence
61
+ self.light_nodes = []
62
+ # they must all be the same length (set based on first given sequence)
63
+ self.animation_len = -1
64
+ # current index in the animation sequence
65
+ self.animation_frame_idx = 0
66
+ # track render time to keep steady framerate
67
+ self.animation_render_time = time.time()
68
+ # background image sequence
69
+ self.img_seq = None
70
+ self.cur_bg_img = None
71
+ # person mask sequence
72
+ self.mask_seq = None
73
+ self.cur_mask = None
74
+
75
+ self.single_frame = False
76
+
77
+ self.mat_constructor = self.pyrender.MetallicRoughnessMaterial
78
+ self.trimesh_to_pymesh = self.pyrender.Mesh.from_trimesh
79
+
80
+ self.scene = self.pyrender.Scene(
81
+ bg_color=colors["white"], ambient_light=(0.3, 0.3, 0.3)
82
+ )
83
+
84
+ self.default_cam_offset = np.array(default_cam_offset)
85
+ self.default_cam_rot = np.array(default_cam_rot)
86
+
87
+ self.default_cam_pose = np.eye(4)
88
+ if default_cam_rot is None:
89
+ self.default_cam_pose = trimesh.transformations.rotation_matrix(
90
+ np.radians(180), (0, 0, 1)
91
+ )
92
+ self.default_cam_pose = np.dot(
93
+ trimesh.transformations.rotation_matrix(np.radians(-90), (1, 0, 0)),
94
+ self.default_cam_pose,
95
+ )
96
+ else:
97
+ self.default_cam_pose[:3, :3] = self.default_cam_rot
98
+ self.default_cam_pose[:3, 3] = self.default_cam_offset
99
+
100
+ self.use_intrins = False
101
+ if camera_intrinsics is None:
102
+ pc = self.pyrender.PerspectiveCamera(
103
+ yfov=np.pi / 3.0, aspectRatio=float(width) / height
104
+ )
105
+ camera_pose = self.get_init_cam_pose()
106
+ self.camera_node = self.scene.add(pc, pose=camera_pose, name="pc-camera")
107
+
108
+ light = self.pyrender.DirectionalLight(color=np.ones(3), intensity=1.0)
109
+ self.scene.add(light, pose=self.default_cam_pose)
110
+ else:
111
+ self.use_intrins = True
112
+ fx, fy, cx, cy = camera_intrinsics
113
+ camera_pose = np.eye(4)
114
+ camera_pose = np.array([1.0, -1.0, -1.0, 1.0]).reshape(-1, 1) * camera_pose
115
+ camera = self.pyrender.camera.IntrinsicsCamera(fx=fx, fy=fy, cx=cx, cy=cy)
116
+ self.camera_node = self.scene.add(
117
+ camera, pose=camera_pose, name="pc-camera"
118
+ )
119
+
120
+ light = self.pyrender.DirectionalLight(color=np.ones(3), intensity=1.0)
121
+ self.scene.add(light, pose=camera_pose)
122
+
123
+ self.set_background_color([1.0, 1.0, 1.0, 0.0])
124
+
125
+ self.figsize = (width, height)
126
+
127
+ # key callbacks
128
+ self.is_paused = False
129
+ registered_keys = dict()
130
+ registered_keys["p"] = (pause_play_callback, [self])
131
+ registered_keys["."] = (step_callback, [self, 1])
132
+ registered_keys[","] = (step_callback, [self, -1])
133
+
134
+ if self.use_offscreen:
135
+ self.viewer = self.pyrender.OffscreenRenderer(
136
+ *self.figsize, point_size=2.75
137
+ )
138
+ self.use_raymond_lighting(3.5)
139
+ else:
140
+ self.viewer = self.pyrender.Viewer(
141
+ self.scene,
142
+ use_raymond_lighting=(not camera_intrinsics),
143
+ viewport_size=self.figsize,
144
+ cull_faces=False,
145
+ run_in_thread=True,
146
+ registered_keys=registered_keys,
147
+ )
148
+
149
+ def get_init_cam_pose(self):
150
+ camera_pose = self.default_cam_pose.copy()
151
+ return camera_pose
152
+
153
+ def set_background_color(self, color=colors["white"]):
154
+ self.scene.bg_color = color
155
+
156
+ def update_camera_pose(self, camera_pose):
157
+ self.scene.set_pose(self.camera_node, pose=camera_pose)
158
+
159
+ def close_viewer(self):
160
+ if self.viewer.is_active:
161
+ self.viewer.close_external()
162
+
163
+ def set_meshes(self, meshes, group_name="static"):
164
+ for node in self.scene.get_nodes():
165
+ if node.name is not None and "%s-mesh" % group_name in node.name:
166
+ self.scene.remove_node(node)
167
+
168
+ for mid, mesh in enumerate(meshes):
169
+ if isinstance(mesh, trimesh.Trimesh):
170
+ mesh = self.pyrender.Mesh.from_trimesh(mesh.copy())
171
+ self.acquire_render_lock()
172
+ self.scene.add(mesh, "%s-mesh-%2d" % (group_name, mid))
173
+ self.release_render_lock()
174
+
175
+ def set_static_meshes(self, meshes):
176
+ self.set_meshes(meshes, group_name="static")
177
+
178
+ def add_static_meshes(self, meshes):
179
+ for mid, mesh in enumerate(meshes):
180
+ if isinstance(mesh, trimesh.Trimesh):
181
+ mesh = self.pyrender.Mesh.from_trimesh(mesh.copy())
182
+ self.acquire_render_lock()
183
+ self.scene.add(mesh, "%s-mesh-%2d" % ("staticadd", mid))
184
+ self.release_render_lock()
185
+
186
+ def add_smpl_vtx_list_seq(
187
+ self, body_mesh_seq, vtx_list, color=[0.0, 0.0, 1.0], radius=0.015
188
+ ):
189
+ vtx_point_seq = []
190
+ for mesh in body_mesh_seq:
191
+ vtx_point_seq.append(mesh.vertices[vtx_list])
192
+ self.add_point_seq(vtx_point_seq, color=color, radius=radius)
193
+
194
+ def set_img_seq(self, img_seq):
195
+ """
196
+ np array of BG images to be rendered in background.
197
+ """
198
+ if not self.use_offscreen:
199
+ print("Cannot render background image if not rendering offscreen")
200
+ return
201
+ # ensure same length as other sequences
202
+ cur_seq_len = len(img_seq)
203
+ if self.animation_len != -1:
204
+ if cur_seq_len != self.animation_len:
205
+ print(
206
+ "Unexpected imgage sequence length, all sequences must be the same length!"
207
+ )
208
+ return
209
+ else:
210
+ if cur_seq_len > 0:
211
+ self.animation_len = cur_seq_len
212
+ else:
213
+ print("Warning: imge sequence is length 0!")
214
+ return
215
+
216
+ self.img_seq = img_seq
217
+ # must have alpha to render background
218
+ self.set_render_settings(RGBA=True)
219
+
220
+ def set_mask_seq(self, mask_seq):
221
+ """
222
+ np array of masked images to be rendered in background.
223
+ """
224
+ if not self.use_offscreen:
225
+ print("Cannot render background image if not rendering offscreen")
226
+ return
227
+ # ensure same length as other sequences
228
+ cur_seq_len = len(mask_seq)
229
+ if self.animation_len != -1:
230
+ if cur_seq_len != self.animation_len:
231
+ print(
232
+ "Unexpected imgage sequence length, all sequences must be the same length!"
233
+ )
234
+ return
235
+ else:
236
+ if cur_seq_len > 0:
237
+ self.animation_len = cur_seq_len
238
+ else:
239
+ print("Warning: imge sequence is length 0!")
240
+ return
241
+
242
+ self.mask_seq = mask_seq
243
+
244
+ def add_point_seq(
245
+ self,
246
+ point_seq,
247
+ color=[1.0, 0.0, 0.0],
248
+ radius=0.015,
249
+ contact_seq=None,
250
+ contact_color=[0.0, 1.0, 0.0],
251
+ connections=None,
252
+ connect_color=[0.0, 0.0, 1.0],
253
+ vel=None,
254
+ render_static=None,
255
+ ):
256
+ """
257
+ Add a sequence of points that will be visualized as spheres.
258
+
259
+ - points : List of Nx3 numpy arrays of point locations to visualize as sequence.
260
+ - color : list of 3 RGB values
261
+ - radius : radius of each point
262
+ - contact_seq : an array of num_frames x num_points indicatin "contacts" i.e. points that should be colored
263
+ differently at different time steps.
264
+ - connections : array of point index pairs, draws a cylinder between each pair to create skeleton
265
+ - vel : list of Nx3 numpy arrays for the velocities of corresponding sequence points
266
+ """
267
+ # ensure same length as other sequences
268
+ cur_seq_len = len(point_seq)
269
+ if self.animation_len != -1:
270
+ if cur_seq_len != self.animation_len:
271
+ print(
272
+ "Unexpected sequence length, all sequences must be the same length!"
273
+ )
274
+ return
275
+ else:
276
+ if cur_seq_len > 0:
277
+ self.animation_len = cur_seq_len
278
+ else:
279
+ print("Warning: points sequence is length 0!")
280
+ return
281
+
282
+ num_joints = point_seq[0].shape[0]
283
+ if contact_seq is not None and contact_seq.shape[1] != num_joints:
284
+ print(num_joints)
285
+ print(contact_seq.shape)
286
+ print(
287
+ "Contact sequence must have the same number of points as the input joints!"
288
+ )
289
+ return
290
+ if contact_seq is not None and contact_seq.shape[0] != cur_seq_len:
291
+ print(
292
+ "Contact sequence must have the same number of frames as the input sequence!"
293
+ )
294
+ return
295
+
296
+ # add skeleton
297
+ if connections is not None:
298
+ pyrender_skeleton_seq = []
299
+ for pid, points in enumerate(point_seq):
300
+ if pid % 200 == 0:
301
+ print(
302
+ "Caching pyrender connections mesh %d/%d..."
303
+ % (pid, len(point_seq))
304
+ )
305
+
306
+ cyl_mesh_list = []
307
+ for point_pair in connections:
308
+ # print(point_pair)
309
+ p1 = points[point_pair[0]]
310
+ p2 = points[point_pair[1]]
311
+ if np.linalg.norm(p1 - p2) < 1e-6:
312
+ segment = np.array([[-1.0, -1.0, -1.0], [-1.01, -1.01, -1.01]])
313
+ else:
314
+ segment = np.array([p1, p2])
315
+ # print(segment)
316
+
317
+ cyl_mesh = trimesh.creation.cylinder(
318
+ radius * 0.35, height=None, segment=segment
319
+ )
320
+ cyl_mesh.visual.vertex_colors = connect_color
321
+ cyl_mesh_list.append(cyl_mesh.copy())
322
+
323
+ # combine
324
+ m = self.pyrender.Mesh.from_trimesh(cyl_mesh_list)
325
+ pyrender_skeleton_seq.append(m)
326
+
327
+ if render_static is None:
328
+ self.add_pyrender_mesh_seq(pyrender_skeleton_seq)
329
+ else:
330
+ self.add_static_meshes(
331
+ [
332
+ pyrender_skeleton_seq[i]
333
+ for i in range(len(pyrender_skeleton_seq))
334
+ if i % render_static == 0
335
+ ]
336
+ )
337
+
338
+ # add velocities
339
+ if vel is not None:
340
+ print("Caching pyrender velocities mesh...")
341
+ pyrender_vel_seq = []
342
+
343
+ point_vel_pairs = zip(point_seq, vel)
344
+ for pid, point_vel_pair in enumerate(point_vel_pairs):
345
+ cur_point_seq, cur_vel_seq = point_vel_pair
346
+
347
+ cyl_mesh_list = []
348
+ for cur_point, cur_vel in zip(cur_point_seq, cur_vel_seq):
349
+ p1 = cur_point
350
+ p2 = cur_point + cur_vel * 0.1
351
+ segment = np.array([p1, p2])
352
+ if np.linalg.norm(p1 - p2) < 1e-6:
353
+ continue
354
+ cyl_mesh = trimesh.creation.cylinder(
355
+ radius * 0.1, height=None, segment=segment
356
+ )
357
+ cyl_mesh.visual.vertex_colors = [0.0, 0.0, 1.0]
358
+ cyl_mesh_list.append(cyl_mesh.copy())
359
+
360
+ # combine
361
+ m = self.pyrender.Mesh.from_trimesh(cyl_mesh_list)
362
+ pyrender_vel_seq.append(m)
363
+
364
+ if render_static is None:
365
+ self.add_pyrender_mesh_seq(pyrender_vel_seq)
366
+ else:
367
+ self.add_static_meshes(
368
+ [
369
+ pyrender_vel_seq[i]
370
+ for i in range(len(pyrender_vel_seq))
371
+ if i % render_static == 0
372
+ ]
373
+ )
374
+
375
+ # create spheres with trimesh
376
+ if contact_seq is None:
377
+ contact_seq = [
378
+ np.zeros((point_seq[t].shape[0])) for t in range(cur_seq_len)
379
+ ]
380
+ pyrender_non_contact_point_seq = []
381
+ pyrender_contact_point_seq = []
382
+ for pid, points in enumerate(point_seq):
383
+ if pid % 200 == 0:
384
+ print("Caching pyrender points mesh %d/%d..." % (pid, len(point_seq)))
385
+
386
+ # first non-contacting points
387
+ if len(color) > 3:
388
+ pyrender_non_contact_point_seq.append(
389
+ self.pyrender.Mesh.from_points(points, color[pid])
390
+ )
391
+ else:
392
+ sm = trimesh.creation.uv_sphere(radius=radius)
393
+ sm.visual.vertex_colors = color
394
+ non_contact_points = points[contact_seq[pid] == 0]
395
+ if len(non_contact_points) > 0:
396
+ tfs = np.tile(np.eye(4), (len(non_contact_points), 1, 1))
397
+ tfs[:, :3, 3] = non_contact_points.copy()
398
+ m = self.pyrender.Mesh.from_trimesh(sm.copy(), poses=tfs)
399
+ pyrender_non_contact_point_seq.append(m)
400
+ else:
401
+ tfs = np.eye(4).reshape((1, 4, 4))
402
+ tfs[0, :3, 3] = np.array([0, 0, 30.0])
403
+ pyrender_non_contact_point_seq.append(
404
+ self.pyrender.Mesh.from_trimesh(sm.copy(), poses=tfs)
405
+ )
406
+ # then contacting points
407
+ sm = trimesh.creation.uv_sphere(radius=radius)
408
+ sm.visual.vertex_colors = contact_color
409
+ contact_points = points[contact_seq[pid] == 1]
410
+ if len(contact_points) > 0:
411
+ tfs = np.tile(np.eye(4), (len(contact_points), 1, 1))
412
+ tfs[:, :3, 3] = contact_points.copy()
413
+ m = self.pyrender.Mesh.from_trimesh(sm.copy(), poses=tfs)
414
+ pyrender_contact_point_seq.append(m)
415
+ else:
416
+ tfs = np.eye(4).reshape((1, 4, 4))
417
+ tfs[0, :3, 3] = np.array([0, 0, 30.0])
418
+ pyrender_contact_point_seq.append(
419
+ self.pyrender.Mesh.from_trimesh(sm.copy(), poses=tfs)
420
+ )
421
+
422
+ if len(pyrender_non_contact_point_seq) > 0:
423
+ if render_static is None:
424
+ self.add_pyrender_mesh_seq(
425
+ pyrender_non_contact_point_seq, seq_type="point"
426
+ )
427
+ else:
428
+ self.add_static_meshes(
429
+ [
430
+ pyrender_non_contact_point_seq[i]
431
+ for i in range(len(pyrender_non_contact_point_seq))
432
+ if i % render_static == 0
433
+ ]
434
+ )
435
+ if len(pyrender_contact_point_seq) > 0:
436
+ if render_static is None:
437
+ self.add_pyrender_mesh_seq(pyrender_contact_point_seq, seq_type="point")
438
+ else:
439
+ self.add_static_meshes(
440
+ [
441
+ pyrender_contact_point_seq[i]
442
+ for i in range(len(pyrender_contact_point_seq))
443
+ if i % render_static == 0
444
+ ]
445
+ )
446
+
447
+ def add_mesh_seq(self, mesh_seq, progress_bar=tqdm):
448
+ """
449
+ Add a sequence of trimeshes to render.
450
+
451
+ - meshes : List of trimesh.trimesh objects giving each frame of the sequence.
452
+ """
453
+
454
+ # ensure same length as other sequences
455
+ cur_seq_len = len(mesh_seq)
456
+ if self.animation_len != -1:
457
+ if cur_seq_len != self.animation_len:
458
+ print(
459
+ "Unexpected sequence length, all sequences must be the same length!"
460
+ )
461
+ return
462
+ else:
463
+ if cur_seq_len > 0:
464
+ self.animation_len = cur_seq_len
465
+ else:
466
+ print("Warning: mesh sequence is length 0!")
467
+ return
468
+
469
+ # print("Adding mesh sequence with %d frames..." % (cur_seq_len))
470
+
471
+ # create sequence of pyrender meshes and save
472
+ pyrender_mesh_seq = []
473
+
474
+ iterator = enumerate(mesh_seq)
475
+ if progress_bar is not None:
476
+ iterator = progress_bar(list(iterator), desc="Import meshes in pyrender")
477
+
478
+ for mid, mesh in iterator:
479
+ if isinstance(mesh, trimesh.Trimesh):
480
+ mesh = self.pyrender.Mesh.from_trimesh(mesh.copy())
481
+ pyrender_mesh_seq.append(mesh)
482
+ else:
483
+ print("Meshes must be from trimesh!")
484
+ return
485
+
486
+ self.add_pyrender_mesh_seq(pyrender_mesh_seq, seq_type="mesh")
487
+
488
+ def add_pyrender_mesh_seq(self, pyrender_mesh_seq, seq_type="default"):
489
+ # add to the list of sequences to render
490
+ seq_id = len(self.animated_seqs)
491
+ self.animated_seqs.append(pyrender_mesh_seq)
492
+ self.animated_seqs_type.append(seq_type)
493
+
494
+ # create the corresponding node in the scene
495
+ self.acquire_render_lock()
496
+ anim_node = self.scene.add(pyrender_mesh_seq[0], "anim-mesh-%2d" % (seq_id))
497
+ self.animated_nodes.append(anim_node)
498
+ self.release_render_lock()
499
+
500
+ def add_ground(
501
+ self,
502
+ ground_plane=None,
503
+ length=25.0,
504
+ color0=[0.8, 0.9, 0.9],
505
+ color1=[0.6, 0.7, 0.7],
506
+ tile_width=0.5,
507
+ xyz_orig=None,
508
+ alpha=1.0,
509
+ ):
510
+ """
511
+ If ground_plane is none just places at origin with +z up.
512
+ If ground_plane is given (a, b, c, d) where a,b,c is the normal, then this is rendered. To more accurately place the floor
513
+ provid an xyz_orig = [x,y,z] that we expect to be near the point of focus.
514
+ """
515
+ color0 = np.array(color0 + [alpha])
516
+ color1 = np.array(color1 + [alpha])
517
+ # make checkerboard
518
+ radius = length / 2.0
519
+ num_rows = num_cols = int(length / tile_width)
520
+ vertices = []
521
+ faces = []
522
+ face_colors = []
523
+ for i in range(num_rows):
524
+ for j in range(num_cols):
525
+ start_loc = [-radius + j * tile_width, radius - i * tile_width]
526
+ cur_verts = np.array(
527
+ [
528
+ [start_loc[0], start_loc[1], 0.0],
529
+ [start_loc[0], start_loc[1] - tile_width, 0.0],
530
+ [start_loc[0] + tile_width, start_loc[1] - tile_width, 0.0],
531
+ [start_loc[0] + tile_width, start_loc[1], 0.0],
532
+ ]
533
+ )
534
+ cur_faces = np.array([[0, 1, 3], [1, 2, 3]], dtype=int)
535
+ cur_faces += 4 * (
536
+ i * num_cols + j
537
+ ) # the number of previously added verts
538
+ use_color0 = (i % 2 == 0 and j % 2 == 0) or (i % 2 == 1 and j % 2 == 1)
539
+ cur_color = color0 if use_color0 else color1
540
+ cur_face_colors = np.array([cur_color, cur_color])
541
+
542
+ vertices.append(cur_verts)
543
+ faces.append(cur_faces)
544
+ face_colors.append(cur_face_colors)
545
+
546
+ vertices = np.concatenate(vertices, axis=0)
547
+ faces = np.concatenate(faces, axis=0)
548
+ face_colors = np.concatenate(face_colors, axis=0)
549
+
550
+ if ground_plane is not None:
551
+ # compute transform between identity floor and passed in floor
552
+ a, b, c, d = ground_plane
553
+ # rotation
554
+ old_normal = np.array([0.0, 0.0, 1.0])
555
+ new_normal = np.array([a, b, c])
556
+ new_normal = new_normal / np.linalg.norm(new_normal)
557
+ v = np.cross(old_normal, new_normal)
558
+ ang_sin = np.linalg.norm(v)
559
+ ang_cos = np.dot(old_normal, new_normal)
560
+ skew_v = np.array(
561
+ [[0.0, -v[2], v[1]], [v[2], 0.0, -v[0]], [-v[1], v[0], 0.0]]
562
+ )
563
+ R = (
564
+ np.eye(3)
565
+ + skew_v
566
+ + np.matmul(skew_v, skew_v) * ((1.0 - ang_cos) / (ang_sin**2))
567
+ )
568
+ # translation
569
+ # project point of focus onto plane
570
+ if xyz_orig is None:
571
+ xyz_orig = np.array([0.0, 0.0, 0.0])
572
+ # project origin onto plane
573
+ plane_normal = np.array([a, b, c])
574
+ plane_off = d
575
+ direction = -plane_normal
576
+ s = (plane_off - np.dot(plane_normal, xyz_orig)) / np.dot(
577
+ plane_normal, direction
578
+ )
579
+ itsct_pt = xyz_orig + s * direction
580
+ t = itsct_pt
581
+
582
+ # transform floor
583
+ vertices = np.dot(R, vertices.T).T + t.reshape((1, 3))
584
+
585
+ ground_tri = trimesh.creation.Trimesh(
586
+ vertices=vertices, faces=faces, face_colors=face_colors, process=False
587
+ )
588
+ ground_mesh = self.pyrender.Mesh.from_trimesh(ground_tri, smooth=False)
589
+
590
+ self.acquire_render_lock()
591
+ anim_node = self.scene.add(ground_mesh, "ground-mesh")
592
+ self.release_render_lock()
593
+
594
+ # update light nodes (if using raymond lighting) to be in this frame
595
+ if ground_plane is not None:
596
+ for lnode in self.light_nodes:
597
+ new_lpose = np.eye(4)
598
+ new_lrot = np.dot(R, lnode.matrix[:3, :3])
599
+ new_ltrans = t
600
+ new_lpose[:3, :3] = new_lrot
601
+ new_lpose[:3, 3] = new_ltrans
602
+ self.acquire_render_lock()
603
+ self.scene.set_pose(lnode, new_lpose)
604
+ self.release_render_lock()
605
+
606
+ def update_frame(self):
607
+ """
608
+ Update frame to show the current self.animation_frame_idx
609
+ """
610
+ for seq_idx in range(len(self.animated_seqs)):
611
+ mesh_mean_list=[]
612
+ # for 2 meshes this is a list with 2 lists with all the frames inside
613
+ # import ipdb; ipdb.set_trace()
614
+ cur_mesh = self.animated_seqs[seq_idx][self.animation_frame_idx]
615
+
616
+ # render the current frame of eqch sequence
617
+ self.acquire_render_lock()
618
+
619
+ # replace the old mesh
620
+ anim_node = list(self.scene.get_nodes(name="anim-mesh-%2d" % (seq_idx)))
621
+ anim_node = anim_node[0]
622
+ anim_node.mesh = cur_mesh
623
+ # update camera pc-camera
624
+ if (
625
+ self.follow_camera and not self.use_intrins
626
+ ): # don't want to reset if we're going from camera view
627
+ if self.animated_seqs_type[seq_idx] == "mesh":
628
+ # import ipdb; ipdb.set_trace()
629
+ cam_node = list(self.scene.get_nodes(name="pc-camera"))
630
+ cam_node = cam_node[0]
631
+ if len(self.animated_seqs) > 1:
632
+ mesh_mean_paired=[]
633
+ for mesh_seq in self.animated_seqs:
634
+ mesh_mean_paired.append(mesh_seq[self.animation_frame_idx])
635
+ mesh_mean_paired = [ mesh_obj.primitives[0].positions[np.newaxis] for mesh_obj in mesh_mean_paired]
636
+ all_meshes_curr_frame = np.concatenate(mesh_mean_paired)
637
+ coord_to_add_camera = np.mean(all_meshes_curr_frame, axis=(0, 1))
638
+ else:
639
+ coord_to_add_camera = np.mean(cur_mesh.primitives[0].positions, axis=0)
640
+
641
+ camera_pose = self.get_init_cam_pose()
642
+ camera_pose[:3, 3] = camera_pose[:3, 3] + np.array(
643
+ [coord_to_add_camera[0], coord_to_add_camera[1]+1.5, 0.5]
644
+ )
645
+ self.scene.set_pose(cam_node, camera_pose)
646
+
647
+ self.release_render_lock()
648
+
649
+ # update background img
650
+ if self.img_seq is not None:
651
+ self.acquire_render_lock()
652
+ self.cur_bg_img = self.img_seq[self.animation_frame_idx]
653
+ self.release_render_lock
654
+
655
+ # update mask
656
+ if self.mask_seq is not None:
657
+ self.acquire_render_lock()
658
+ self.cur_mask = self.mask_seq[self.animation_frame_idx]
659
+ self.release_render_lock
660
+
661
+ def animate(self, fps=30, start=None, end=None, progress_bar=tqdm):
662
+ """
663
+ Starts animating any given mesh sequences. This should be called last after adding
664
+ all desired components to the scene as it is a blocking operation and will run
665
+ until the user exits (or the full video is rendered if offline).
666
+ """
667
+ if not self.use_offscreen:
668
+ print("=================================")
669
+ print("VIEWER CONTROLS")
670
+ print("p - pause/play")
671
+ print('"," and "." - step back/forward one frame')
672
+ print("w - wireframe")
673
+ print("h - render shadows")
674
+ print("q - quit")
675
+ print("=================================")
676
+
677
+ # print("Animating...")
678
+ # frame_dur = 1.0 / float(fps)
679
+
680
+ # set up init frame
681
+ self.update_frame()
682
+
683
+ # only support offscreen in this script
684
+ assert self.use_offscreen
685
+
686
+ assert self.animation_frame_idx == 0
687
+
688
+ iterator = range(self.animation_len)
689
+ if progress_bar is not None:
690
+ iterator = progress_bar(list(iterator), desc="SMPL rendering")
691
+
692
+ for frame_idx in iterator:
693
+ self.animation_frame_idx = frame_idx
694
+ # render frame
695
+ if not os.path.exists(self.render_path):
696
+ os.mkdir(self.render_path)
697
+ # print("Rendering frames to %s!" % (self.render_path))
698
+ cur_file_path = os.path.join(
699
+ self.render_path,
700
+ "frame_%08d.%s" % (self.animation_frame_idx, self.img_extn),
701
+ )
702
+ if start is None or end is None:
703
+ self.save_snapshot(cur_file_path)
704
+ else:
705
+ # only do it for a crop
706
+ if start <= self.animation_frame_idx < end:
707
+ self.save_snapshot(cur_file_path)
708
+
709
+ if self.animation_frame_idx + 1 >= self.animation_len:
710
+ continue # last iteration anyway
711
+
712
+ self.animation_render_time = time.time()
713
+ # if self.is_paused:
714
+ # self.update_frame() # just in case there's a single frame update
715
+ # continue
716
+
717
+ self.animation_frame_idx = self.animation_frame_idx + 1
718
+ # % self.animation_len
719
+ self.update_frame()
720
+
721
+ # if self.single_frame:
722
+ # break
723
+
724
+ self.animation_frame_idx = 0
725
+ return True
726
+
727
+ def _add_raymond_light(self):
728
+ DirectionalLight = self.pyrender.light.DirectionalLight
729
+ Node = self.pyrender.node.Node
730
+ # from pyrender.light import DirectionalLight
731
+ # from pyrender.node import Node
732
+
733
+ thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
734
+ phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])
735
+
736
+ nodes = []
737
+
738
+ for phi, theta in zip(phis, thetas):
739
+ xp = np.sin(theta) * np.cos(phi)
740
+ yp = np.sin(theta) * np.sin(phi)
741
+ zp = np.cos(theta)
742
+
743
+ z = np.array([xp, yp, zp])
744
+ z = z / np.linalg.norm(z)
745
+ x = np.array([-z[1], z[0], 0.0])
746
+ if np.linalg.norm(x) == 0:
747
+ x = np.array([1.0, 0.0, 0.0])
748
+ x = x / np.linalg.norm(x)
749
+ y = np.cross(z, x)
750
+
751
+ matrix = np.eye(4)
752
+ matrix[:3, :3] = np.c_[x, y, z]
753
+ nodes.append(
754
+ Node(
755
+ light=DirectionalLight(color=np.ones(3), intensity=1.0),
756
+ matrix=matrix,
757
+ )
758
+ )
759
+ return nodes
760
+
761
+ def use_raymond_lighting(self, intensity=1.0):
762
+ if not self.use_offscreen:
763
+ sys.stderr.write("Interactive viewer already uses raymond lighting!\n")
764
+ return
765
+ for n in self._add_raymond_light():
766
+ n.light.intensity = intensity / 3.0
767
+ if not self.scene.has_node(n):
768
+ self.scene.add_node(n) # , parent_node=pc)
769
+
770
+ self.light_nodes.append(n)
771
+
772
+ def set_render_settings(
773
+ self, wireframe=None, RGBA=None, out_path=None, single_frame=None
774
+ ):
775
+ if wireframe is not None and wireframe == True:
776
+ self.render_wireframe = True
777
+ if RGBA is not None and RGBA == True:
778
+ self.render_RGBA = True
779
+ if out_path is not None:
780
+ self.render_path = out_path
781
+ if single_frame is not None:
782
+ self.single_frame = single_frame
783
+
784
+ def render(self):
785
+ RenderFlags = self.pyrender.constants.RenderFlags
786
+ # from pyrender.constants import RenderFlags
787
+
788
+ flags = RenderFlags.SHADOWS_DIRECTIONAL
789
+ if self.render_RGBA:
790
+ flags |= RenderFlags.RGBA
791
+ if self.render_wireframe:
792
+ flags |= RenderFlags.ALL_WIREFRAME
793
+ color_img, depth_img = self.viewer.render(self.scene, flags=flags)
794
+
795
+ output_img = color_img
796
+ if self.cur_bg_img is not None:
797
+ color_img = color_img.astype(np.float32) / 255.0
798
+ person_mask = None
799
+ if self.cur_mask is not None:
800
+ person_mask = self.cur_mask[:, :, np.newaxis]
801
+ color_img = color_img * (1.0 - person_mask)
802
+ valid_mask = (color_img[:, :, -1] > 0)[:, :, np.newaxis]
803
+ input_img = self.cur_bg_img
804
+ if color_img.shape[2] == 4:
805
+ output_img = (
806
+ color_img[:, :, :-1] * color_img[:, :, 3:]
807
+ + (1.0 - color_img[:, :, 3:]) * input_img
808
+ )
809
+ else:
810
+ output_img = (
811
+ color_img[:, :, :-1] * valid_mask + (1 - valid_mask) * input_img
812
+ )
813
+
814
+ output_img = (output_img * 255.0).astype(np.uint8)
815
+
816
+ return output_img
817
+
818
+ def save_snapshot(self, fname):
819
+ if not self.use_offscreen:
820
+ sys.stderr.write(
821
+ "Currently saving snapshots only works with off-screen renderer!\n"
822
+ )
823
+ return
824
+ color_img = self.render()
825
+ if color_img.shape[-1] == 4:
826
+ img_bgr = cv2.cvtColor(color_img, cv2.COLOR_RGBA2BGRA)
827
+ else:
828
+ img_bgr = cv2.cvtColor(color_img, cv2.COLOR_RGB2BGR)
829
+ cv2.imwrite(fname, img_bgr, COMPRESS_PARAMS)
830
+
831
+ def acquire_render_lock(self):
832
+ if not self.use_offscreen:
833
+ self.viewer.render_lock.acquire()
834
+
835
+ def release_render_lock(self):
836
+ if not self.use_offscreen:
837
+ self.viewer.render_lock.release()
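A minimal offscreen sketch (not part of this commit) of the pyrender flow this viewer wraps: one directional light, an RGBA render with directional shadows, and a BGRA write through OpenCV. The mesh, camera pose and filename are placeholders, and on a headless machine PYOPENGL_PLATFORM=egl (or osmesa) is typically required:

import numpy as np
import cv2
import trimesh
import pyrender
from pyrender.constants import RenderFlags

scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=[0.3, 0.3, 0.3])
scene.add(pyrender.Mesh.from_trimesh(trimesh.creation.icosphere(radius=0.3)))

cam_pose = np.eye(4)
cam_pose[2, 3] = 2.0  # camera 2m back along +z, looking at the origin
scene.add(pyrender.PerspectiveCamera(yfov=np.pi / 3.0), pose=cam_pose)
scene.add(pyrender.DirectionalLight(color=np.ones(3), intensity=2.0), pose=cam_pose)

renderer = pyrender.OffscreenRenderer(viewport_width=720, viewport_height=720)
color, _ = renderer.render(scene, flags=RenderFlags.SHADOWS_DIRECTIONAL | RenderFlags.RGBA)
cv2.imwrite("frame_00000000.png", cv2.cvtColor(color, cv2.COLOR_RGBA2BGRA))
renderer.delete()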
renderer/humor_render_tools/parameters.py ADDED
@@ -0,0 +1,32 @@
1
+ # adapted from HuMoR
2
+ # fmt: off
3
+ smpl_connections = [[11, 8], [8, 5], [5, 2], [2, 0], [10, 7], [7, 4], [4, 1], [1, 0],
4
+ [0, 3], [3, 6], [6, 9], [9, 12], [12, 15], [12, 13], [13, 16], [16, 18],
5
+ [18, 20], [12, 14], [14, 17], [17, 19], [19, 21]]
6
+
7
+ VERTEX_COLOR = [0.0390625, 0.4140625, 0.796875]
8
+
9
+ colors = {
10
+ 'pink': [.7, .7, .9],
11
+ 'purple': [.9, .7, .7],
12
+ 'cyan': [.7, .75, .5],
13
+ 'red': [1.0, 0.0, 0.0],
14
+ 'green': [.0, 1., .0],
15
+ 'yellow': [1., 1., 0],
16
+ 'brown': [.5, .7, .7],
17
+ 'blue': [.0, .0, 1.],
18
+ 'offwhite': [.8, .9, .9],
19
+ 'orange': [.5, .65, .9],
20
+ 'grey': [.7, .7, .7],
21
+ 'black': [0.0, 0.0, 0.0],
22
+ 'white': [1.0, 1.0, 1.0],
23
+ 'yellowg': [0.83, 1.0, 0.0],
24
+ # 'teal':[0.0, 0.7, 0.6],
25
+ # 'neon pink': (1.0, 0.07, 0.58),
26
+ 'sky blue': [0.1, 0.5, 0.75],
27
+ 'teal': [0.0, 0.6, 0.5],
28
+ 'neon pink': [0.8, 0.05, 0.45],
29
+
30
+ "vertex": VERTEX_COLOR
31
+ }
32
+ # fmt: on
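A small sketch of how these constants are consumed elsewhere in the renderer: each entry of smpl_connections is a pair of joint indices, so one frame of 22 SMPL body joints yields 21 bone segments. The joint array below is a placeholder, and the import assumes the renderer directory is importable as a package:

import numpy as np
from renderer.humor_render_tools.parameters import smpl_connections, colors

joints = np.zeros((22, 3))                                           # placeholder joint positions
segments = np.stack([joints[[a, b]] for a, b in smpl_connections])  # (21, 2, 3) bone endpoints
bone_color = colors["sky blue"]                                      # RGB in [0, 1]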
renderer/humor_render_tools/smplh.faces ADDED
Binary file (165 kB).
 
renderer/humor_render_tools/tools.py ADDED
@@ -0,0 +1,280 @@
1
+ # adapted from HuMoR
2
+ import torch
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ import trimesh
6
+ from .parameters import colors, smpl_connections
7
+ from .mesh_viewer import MeshViewer
8
+
9
+ c2c = lambda tensor: tensor.detach().cpu().numpy() # noqa
10
+
11
+
12
+ def viz_smpl_seq(
13
+ pyrender,
14
+ out_path,
15
+ body,
16
+ #
17
+ start=None,
18
+ end=None,
19
+ #
20
+ imw=720,
21
+ imh=720,
22
+ fps=20,
23
+ use_offscreen=True,
24
+ follow_camera=True,
25
+ progress_bar=tqdm,
26
+ #
27
+ contacts=None,
28
+ render_body=True,
29
+ render_joints=False,
30
+ render_skeleton=False,
31
+ render_ground=True,
32
+ ground_plane=None,
33
+ wireframe=False,
34
+ RGBA=False,
35
+ joints_seq=None,
36
+ joints_vel=None,
37
+ vtx_list=None,
38
+ points_seq=None,
39
+ points_vel=None,
40
+ static_meshes=None,
41
+ camera_intrinsics=None,
42
+ img_seq=None,
43
+ point_rad=0.015,
44
+ skel_connections=smpl_connections,
45
+ img_extn="png",
46
+ ground_alpha=1.0,
47
+ body_alpha=None,
48
+ mask_seq=None,
49
+ cam_offset=[0.0, 2.2, 0.9], # [0.0, 4.0, 1.25],
50
+ ground_color0=[0.8, 0.9, 0.9],
51
+ ground_color1=[0.6, 0.7, 0.7],
52
+ # ground_color0=[1.0, 1.0, 1.0],
53
+ # ground_color1=[0.0, 0.0, 0.0],
54
+ skel_color=[0.5, 0.5, 0.5], # [0.0, 0.0, 1.0],
55
+ joint_rad=0.015,
56
+ point_color=[0.0, 0.0, 1.0],
57
+ joint_color=[0.0, 1.0, 0.0],
58
+ contact_color=[1.0, 0.0, 0.0],
59
+ vertex_color=[],
60
+ vertex_color_list = [],
61
+ # color_vtx0 = colors['neon pink']
62
+ # color_vtx1 = colors['sky blue']
63
+ render_bodies_static=None,
64
+ render_points_static=None,
65
+ cam_rot=None,
66
+ ):
67
+ """
68
+ Visualizes the body model output of a smpl sequence.
69
+ - body : body model output from SMPL forward pass (where the sequence is the batch)
70
+ - joints_seq : list of torch/numpy tensors/arrays
71
+ - points_seq : list of torch/numpy tensors
72
+ - camera_intrinsics : (fx, fy, cx, cy)
73
+ - ground_plane : [a, b, c, d]
74
+ - render_bodies_static : integer; if given, all bodies are rendered in one static scene, keeping only every x-th frame
75
+ """
76
+
77
+ if contacts is not None and torch.is_tensor(contacts):
78
+ contacts = c2c(contacts)
79
+ # import ipdb; ipdb.set_trace()
80
+ if isinstance(body, list):
81
+ render_pair = True
82
+ else:
83
+ render_pair = False
84
+
85
+ if render_body or vtx_list is not None:
86
+ if render_pair:
87
+ nv = body[0].v.size(1)
88
+ else:
89
+ nv = body.v.size(1)
90
+ vertex_colors = np.tile(vertex_color, (nv, 1))
91
+
92
+ if body_alpha is not None:
93
+ vtx_alpha = np.ones((vertex_colors.shape[0], 1)) * body_alpha
94
+ vertex_colors = np.concatenate([vertex_colors, vtx_alpha], axis=1)
95
+ if render_pair:
96
+ faces = c2c(body[0].f)
97
+ else:
98
+ faces = c2c(body.f)
99
+
100
+ if not vertex_color_list:
101
+ vertex_color_list= ['neon pink','sky blue']
102
+
103
+ if not vertex_color:
104
+ vertex_color=colors['sky blue']
105
+ else:
106
+ vertex_color=colors[vertex_color[0]]
107
+
108
+ if render_pair:
109
+
110
+ color_vtx0=colors[vertex_color_list[0]]
111
+ color_vtx1=colors[vertex_color_list[1]]
112
+
113
+
114
+ body_mesh_seq0 = [
115
+ trimesh.Trimesh(
116
+ vertices=c2c(body[0].v[i]),
117
+ faces=faces,
118
+ vertex_colors=color_vtx0,
119
+ process=False,
120
+ )
121
+ for i in range(body[0].v.size(0))
122
+ ]
123
+ body_mesh_seq1 = [
124
+ trimesh.Trimesh(
125
+ vertices=c2c(body[1].v[i]),
126
+ faces=faces,
127
+ vertex_colors=color_vtx1,
128
+ process=False,
129
+ )
130
+ for i in range(body[1].v.size(0))
131
+ ]
132
+ else:
133
+ body_mesh_seq = [
134
+ trimesh.Trimesh(
135
+ vertices=c2c(body.v[i]),
136
+ faces=faces,
137
+ vertex_colors=vertex_color,
138
+ process=False,
139
+ )
140
+ for i in range(body.v.size(0))
141
+ ]
142
+
143
+ if render_joints and joints_seq is None:
144
+ # only body joints
145
+ joints_seq = [c2c(body.Jtr[i, :22]) for i in range(body.Jtr.size(0))]
146
+ elif render_joints and torch.is_tensor(joints_seq[0]):
147
+ joints_seq = [c2c(joint_frame) for joint_frame in joints_seq]
148
+
149
+ if joints_vel is not None and torch.is_tensor(joints_vel[0]):
150
+ joints_vel = [c2c(joint_frame) for joint_frame in joints_vel]
151
+ if points_vel is not None and torch.is_tensor(points_vel[0]):
152
+ points_vel = [c2c(joint_frame) for joint_frame in points_vel]
153
+ # cam_offset = [0.0, 2.2, 0.2]
154
+ mv = MeshViewer(
155
+ pyrender,
156
+ width=imw,
157
+ height=imh,
158
+ use_offscreen=use_offscreen,
159
+ follow_camera=follow_camera,
160
+ camera_intrinsics=camera_intrinsics,
161
+ img_extn=img_extn,
162
+ default_cam_offset=cam_offset,
163
+ default_cam_rot=cam_rot,
164
+ )
165
+ if render_body and render_bodies_static is None:
166
+ if render_pair:
167
+ mv.add_mesh_seq(body_mesh_seq0, progress_bar=progress_bar)
168
+ mv.add_mesh_seq(body_mesh_seq1, progress_bar=progress_bar)
169
+ else:
170
+ mv.add_mesh_seq(body_mesh_seq, progress_bar=progress_bar)
171
+ elif render_body and render_bodies_static is not None:
172
+ if render_pair:
173
+ mv.add_static_meshes(
174
+ [
175
+ body_mesh_seq0[i]
176
+ for i in range(len(body_mesh_seq0))
177
+ if i % render_bodies_static == 0
178
+ ]
179
+ )
180
+ mv.add_static_meshes(
181
+ [
182
+ body_mesh_seq1[i]
183
+ for i in range(len(body_mesh_seq1))
184
+ if i % render_bodies_static == 0
185
+ ]
186
+ )
187
+
188
+ else:
189
+ mv.add_static_meshes(
190
+ [
191
+ body_mesh_seq[i]
192
+ for i in range(len(body_mesh_seq))
193
+ if i % render_bodies_static == 0
194
+ ]
195
+ )
196
+ if render_joints and render_skeleton:
197
+ mv.add_point_seq(
198
+ joints_seq,
199
+ color=joint_color,
200
+ radius=joint_rad,
201
+ contact_seq=contacts,
202
+ connections=skel_connections,
203
+ connect_color=skel_color,
204
+ vel=joints_vel,
205
+ contact_color=contact_color,
206
+ render_static=render_points_static,
207
+ )
208
+ elif render_joints:
209
+ mv.add_point_seq(
210
+ joints_seq,
211
+ color=joint_color,
212
+ radius=joint_rad,
213
+ contact_seq=contacts,
214
+ vel=joints_vel,
215
+ contact_color=contact_color,
216
+ render_static=render_points_static,
217
+ )
218
+
219
+ if vtx_list is not None:
220
+
221
+ mv.add_smpl_vtx_list_seq(
222
+ body_mesh_seq, vtx_list, color=[0.0, 0.0, 1.0], radius=0.015
223
+ )
224
+
225
+ if points_seq is not None:
226
+ if torch.is_tensor(points_seq[0]):
227
+ points_seq = [c2c(point_frame) for point_frame in points_seq]
228
+ mv.add_point_seq(
229
+ points_seq,
230
+ color=point_color,
231
+ radius=point_rad,
232
+ vel=points_vel,
233
+ render_static=render_points_static,
234
+ )
235
+
236
+ if static_meshes is not None:
237
+ mv.set_static_meshes(static_meshes)
238
+
239
+ if img_seq is not None:
240
+ mv.set_img_seq(img_seq)
241
+
242
+ if mask_seq is not None:
243
+ mv.set_mask_seq(mask_seq)
244
+
245
+ if render_ground:
246
+ xyz_orig = None
247
+ if ground_plane is not None:
248
+ if render_body:
249
+ if render_pair:
250
+ xyz_orig = (body_mesh_seq0[0].vertices[0, :] + body_mesh_seq1[0].vertices[0, :]) / 2
251
+ else:
252
+ xyz_orig = body_mesh_seq[0].vertices[0, :]
253
+
254
+ elif render_joints:
255
+ xyz_orig = joints_seq[0][0, :]
256
+ elif points_seq is not None:
257
+ xyz_orig = points_seq[0][0, :]
258
+
259
+ mv.add_ground(
260
+ ground_plane=ground_plane,
261
+ xyz_orig=xyz_orig,
262
+ color0=ground_color0,
263
+ color1=ground_color1,
264
+ alpha=ground_alpha,
265
+ )
266
+
267
+ mv.set_render_settings(
268
+ out_path=out_path,
269
+ wireframe=wireframe,
270
+ RGBA=RGBA,
271
+ single_frame=(
272
+ render_points_static is not None or render_bodies_static is not None
273
+ ),
274
+ ) # only does anything for offscreen rendering
275
+ try:
276
+ mv.animate(fps=fps, start=start, end=end, progress_bar=progress_bar)
277
+ except RuntimeError as err:
278
+ print("Could not render properly with the error: %s" % (str(err)))
279
+
280
+ del mv
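A hedged usage sketch for viz_smpl_seq (not part of the commit): it expects HuMoR-style body-model outputs exposing .v (T x V x 3 vertices), .f (faces) and .Jtr (joints); passing a two-element list switches on the paired source/edited mode coloured by vertex_color_list. The fake_body helper below is only a stand-in for real SMPL outputs:

import torch
import pyrender
from types import SimpleNamespace
from renderer.humor_render_tools.tools import viz_smpl_seq

def fake_body(T=30, V=6890):
    # stand-in for a SMPL-H forward-pass output; real callers pass the model output instead
    return SimpleNamespace(v=torch.rand(T, V, 3),
                           f=torch.zeros(13776, 3, dtype=torch.long),
                           Jtr=torch.rand(T, 22, 3))

viz_smpl_seq(
    pyrender,
    out_path="renders/sample",                    # frames are written as frame_%08d.png
    body=[fake_body(), fake_body()],              # a list triggers the paired rendering branch
    fps=30,
    use_offscreen=True,
    vertex_color_list=["sky blue", "neon pink"],
)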
renderer/matplotlib.py ADDED
@@ -0,0 +1,259 @@
1
+ # From TEMOS: temos/render/anim.py
2
+ # Inspired by
3
+ # - https://github.com/anindita127/Complextext2animation/blob/main/src/utils/visualization.py
4
+ # - https://github.com/facebookresearch/QuaterNet/blob/main/common/visualization.py
5
+
6
+ import os
7
+ import logging
8
+
9
+ from dataclasses import dataclass
10
+ from typing import List, Tuple, Optional
11
+ import numpy as np
12
+ from src.tools.rifke import canonicalize_rotation
13
+
14
+ logger = logging.getLogger("matplotlib.animation")
15
+ logger.setLevel(logging.ERROR)
16
+
17
+ colors = ("black", "magenta", "red", "green", "blue")
18
+
19
+ KINEMATIC_TREES = {
20
+ "smpljoints": [
21
+ [0, 3, 6, 9, 12, 15],
22
+ [9, 13, 16, 18, 20],
23
+ [9, 14, 17, 19, 21],
24
+ [0, 1, 4, 7, 10],
25
+ [0, 2, 5, 8, 11],
26
+ ],
27
+ "guoh3djoints": [ # no hands
28
+ [0, 3, 6, 9, 12, 15],
29
+ [9, 13, 16, 18, 20],
30
+ [9, 14, 17, 19, 21],
31
+ [0, 1, 4, 7, 10],
32
+ [0, 2, 5, 8, 11],
33
+ ],
34
+ }
35
+
36
+
37
+ @dataclass
38
+ class MatplotlibRender:
39
+ jointstype: str = "smpljoints"
40
+ fps: float = 20.0
41
+ colors: List[str] = colors
42
+ figsize: int = 4
43
+ fontsize: int = 15
44
+ canonicalize: bool = False
45
+
46
+ def __call__(
47
+ self,
48
+ joints,
49
+ output,
50
+ fps=None,
51
+ highlights=None,
52
+ title: str = "",
53
+ canonicalize=None,
54
+ ):
55
+ canonicalize = canonicalize if canonicalize is not None else self.canonicalize
56
+ fps = fps if fps is not None else self.fps
57
+ if joints.shape[1] == 24:
58
+ # remove the hands
59
+ joints = joints[:, :22]
60
+
61
+ render_animation(
62
+ joints,
63
+ title=title,
64
+ highlights=highlights,
65
+ output=output,
66
+ jointstype=self.jointstype,
67
+ fps=fps,
68
+ colors=self.colors,
69
+ figsize=(self.figsize, self.figsize),
70
+ fontsize=self.fontsize,
71
+ canonicalize=canonicalize,
72
+ )
73
+
74
+
75
+ def init_axis(fig, title, radius=1.5):
76
+ ax = fig.add_subplot(1, 1, 1, projection="3d")
77
+ ax.view_init(elev=20.0, azim=-60)
78
+
79
+ fact = 2
80
+ ax.set_xlim3d([-radius / fact, radius / fact])
81
+ ax.set_ylim3d([-radius / fact, radius / fact])
82
+ ax.set_zlim3d([0, radius])
83
+
84
+ ax.set_aspect("auto")
85
+ ax.set_xticklabels([])
86
+ ax.set_yticklabels([])
87
+ ax.set_zticklabels([])
88
+
89
+ ax.set_axis_off()
90
+ ax.grid(b=False)
91
+
92
+ ax.set_title(title, loc="center", wrap=True)
93
+ return ax
94
+
95
+
96
+ def plot_floor(ax, minx, maxx, miny, maxy, minz):
97
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
98
+
99
+ # Plot a plane XZ
100
+ verts = [
101
+ [minx, miny, minz],
102
+ [minx, maxy, minz],
103
+ [maxx, maxy, minz],
104
+ [maxx, miny, minz],
105
+ ]
106
+ xz_plane = Poly3DCollection([verts], zorder=1)
107
+ xz_plane.set_facecolor((0.5, 0.5, 0.5, 1))
108
+ ax.add_collection3d(xz_plane)
109
+
110
+ # Plot a bigger square plane XZ
111
+ radius = max((maxx - minx), (maxy - miny))
112
+
113
+ # center +- radius
114
+ minx_all = (maxx + minx) / 2 - radius
115
+ maxx_all = (maxx + minx) / 2 + radius
116
+
117
+ miny_all = (maxy + miny) / 2 - radius
118
+ maxy_all = (maxy + miny) / 2 + radius
119
+
120
+ verts = [
121
+ [minx_all, miny_all, minz],
122
+ [minx_all, maxy_all, minz],
123
+ [maxx_all, maxy_all, minz],
124
+ [maxx_all, miny_all, minz],
125
+ ]
126
+ xz_plane = Poly3DCollection([verts], zorder=1)
127
+ xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
128
+ ax.add_collection3d(xz_plane)
129
+ return ax
130
+
131
+
132
+ def update_camera(ax, root, radius=1.5):
133
+ fact = 2
134
+ ax.set_xlim3d([-radius / fact + root[0], radius / fact + root[0]])
135
+ ax.set_ylim3d([-radius / fact + root[1], radius / fact + root[1]])
136
+
137
+
138
+ def render_animation(
139
+ joints: np.ndarray,
140
+ output: str = "notebook",
141
+ highlights: Optional[np.ndarray] = None,
142
+ jointstype: str = "smpljoints",
143
+ title: str = "",
144
+ fps: float = 20.0,
145
+ colors: List[str] = colors,
146
+ figsize: Tuple[int] = (4, 4),
147
+ fontsize: int = 15,
148
+ canonicalize: bool = False,
149
+ agg=True,
150
+ ):
151
+ if agg:
152
+ import matplotlib
153
+
154
+ matplotlib.use("Agg")
155
+
156
+ if highlights is not None:
157
+ assert len(highlights) == len(joints)
158
+
159
+ assert jointstype in KINEMATIC_TREES
160
+ kinematic_tree = KINEMATIC_TREES[jointstype]
161
+
162
+ import matplotlib.pyplot as plt
163
+ from matplotlib.animation import FuncAnimation
164
+ import matplotlib.patheffects as pe
165
+
166
+ mean_fontsize = fontsize
167
+
168
+ # heuristic to change fontsize
169
+ fontsize = mean_fontsize - (len(title) - 30) / 20
170
+ plt.rcParams.update({"font.size": fontsize})
171
+
172
+ # Z is gravity here
173
+ x, y, z = 0, 1, 2
174
+
175
+ joints = joints.copy()
176
+
177
+ if canonicalize:
178
+ joints = canonicalize_rotation(joints, jointstype=jointstype)
179
+
180
+ # Create a figure and initialize 3d plot
181
+ fig = plt.figure(figsize=figsize)
182
+ ax = init_axis(fig, title)
183
+
184
+ # Create spline line
185
+ trajectory = joints[:, 0, [x, y]]
186
+ avg_segment_length = (
187
+ np.mean(np.linalg.norm(np.diff(trajectory, axis=0), axis=1)) + 1e-3
188
+ )
189
+ draw_offset = int(25 / avg_segment_length)
190
+ (spline_line,) = ax.plot(*trajectory.T, zorder=10, color="white")
191
+
192
+ # Create a floor
193
+ minx, miny, _ = joints.min(axis=(0, 1))
194
+ maxx, maxy, _ = joints.max(axis=(0, 1))
195
+ plot_floor(ax, minx, maxx, miny, maxy, 0)
196
+
197
+ # Put the character on the floor
198
+ height_offset = np.min(joints[:, :, z]) # Min height
199
+ joints = joints.copy()
200
+ joints[:, :, z] -= height_offset
201
+
202
+ # Initialization for redrawing
203
+ lines = []
204
+ initialized = False
205
+
206
+ def update(frame):
207
+ nonlocal initialized
208
+ skeleton = joints[frame]
209
+
210
+ root = skeleton[0]
211
+ update_camera(ax, root)
212
+
213
+ hcolors = colors
214
+ if highlights is not None and highlights[frame]:
215
+ hcolors = ("red", "red", "red", "red", "red")
216
+
217
+ for index, (chain, color) in enumerate(
218
+ zip(reversed(kinematic_tree), reversed(hcolors))
219
+ ):
220
+ if not initialized:
221
+ lines.append(
222
+ ax.plot(
223
+ skeleton[chain, x],
224
+ skeleton[chain, y],
225
+ skeleton[chain, z],
226
+ linewidth=6.0,
227
+ color=color,
228
+ zorder=20,
229
+ path_effects=[pe.SimpleLineShadow(), pe.Normal()],
230
+ )
231
+ )
232
+
233
+ else:
234
+ lines[index][0].set_xdata(skeleton[chain, x])
235
+ lines[index][0].set_ydata(skeleton[chain, y])
236
+ lines[index][0].set_3d_properties(skeleton[chain, z])
237
+ lines[index][0].set_color(color)
238
+
239
+ left = max(frame - draw_offset, 0)
240
+ right = min(frame + draw_offset, trajectory.shape[0])
241
+
242
+ spline_line.set_xdata(trajectory[left:right, 0])
243
+ spline_line.set_ydata(trajectory[left:right, 1])
244
+ spline_line.set_3d_properties(np.zeros_like(trajectory[left:right, 0]))
245
+ initialized = True
246
+
247
+ fig.tight_layout()
248
+ frames = joints.shape[0]
249
+ anim = FuncAnimation(fig, update, frames=frames, interval=1000 / fps, repeat=False)
250
+
251
+ if output == "notebook":
252
+ from IPython.display import HTML
253
+
254
+ HTML(anim.to_jshtml())
255
+ else:
256
+ # anim.save(output, writer='ffmpeg', fps=fps)
257
+ anim.save(output, fps=fps)
258
+
259
+ plt.close()
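A quick way to exercise render_animation on synthetic data (the joints are random placeholders; saving to .mp4 assumes ffmpeg is available to matplotlib and that the repository's src.tools.rifke module is importable):

import numpy as np
from renderer.matplotlib import render_animation

T = 60
joints = 0.1 * np.random.randn(T, 22, 3)   # 22 SMPL body joints, z is up
joints[:, :, 2] += 1.0                     # keep the skeleton above the z=0 floor
render_animation(joints, output="skeleton.mp4", title="random jitter", fps=20.0)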
renderer/utils.py ADDED
@@ -0,0 +1,33 @@
1
+ from transform3d import transform_body_pose
2
+
3
+ def run_smpl_fwd_vertices(body_model,
4
+ body_transl,
5
+ body_orient, body_pose):
6
+ """
7
+ Standalone function to run the SMPL forward pass.
8
+
9
+ Parameters:
10
+ - body_model: The SMPL model instance
11
+ - body_transl: Translation tensor
12
+ - body_orient: Orientation tensor
13
+ - body_pose: Pose tensor
14
15
+
16
+ Returns:
17
+ - Vertices from the SMPL forward pass
18
+ """
19
+
20
+ if len(body_transl.shape) > 2:
21
+ body_transl = body_transl.flatten(0, 1)
22
+ body_orient = body_orient.flatten(0, 1)
23
+ body_pose = body_pose.flatten(0, 1)
24
+
25
+ batch_size = body_transl.shape[0]
26
+ body_model.batch_size = batch_size
27
+
28
+ return body_model(
29
+ transl=body_transl,
30
+ body_pose=transform_body_pose(body_pose, 'aa->rot'),
31
+ global_orient=transform_body_pose(body_orient, 'aa->rot')
32
+ )
33
+
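A hedged sketch of calling run_smpl_fwd_vertices, assuming body_model is an SMPL-H layer that accepts rotation-matrix poses (e.g. smplx.SMPLHLayer); the model path and tensor shapes below are illustrative only:

import torch
import smplx
from renderer.utils import run_smpl_fwd_vertices

body_model = smplx.SMPLHLayer(model_path="data/smplh", gender="neutral")  # placeholder path
T = 60
transl = torch.zeros(T, 3)
global_orient = torch.zeros(T, 3)   # axis-angle; converted to rotations inside the helper
body_pose = torch.zeros(T, 63)      # 21 body joints x 3, axis-angle
out = run_smpl_fwd_vertices(body_model, transl, global_orient, body_pose)
vertices = out.vertices             # (T, 6890, 3) for SMPL-H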
renderer/video.py ADDED
@@ -0,0 +1,39 @@
1
+ import moviepy.editor as mp
2
+
3
+ # import moviepy.video.fx.all as vfx
4
+ import os
5
+
6
+
7
+ class Video:
8
+ def __init__(self, frame_path: str, fps: float = 20.0, res="high"):
9
+ frame_path = str(frame_path)
10
+ self.fps = fps
11
+
12
+ self._conf = {
13
+ "codec": "libx264",
14
+ "fps": self.fps,
15
+ "audio_codec": "aac",
16
+ "temp_audiofile": "temp-audio.m4a",
17
+ "remove_temp": True,
18
+ }
19
+
20
+ if res == "low":
21
+ bitrate = "500k"
22
+ else:
23
+ bitrate = "5000k"
24
+
25
+ self._conf = {"bitrate": bitrate, "fps": self.fps}
26
+
27
+ # Load video
28
+ # video = mp.VideoFileClip(video1_path, audio=False)
29
+ # Load with frames
30
+ frames = [os.path.join(frame_path, x) for x in sorted(os.listdir(frame_path))]
31
+ video = mp.ImageSequenceClip(frames, fps=fps)
32
+ self.video = video
33
+ self.duration = video.duration
34
+
35
+ def save(self, out_path):
36
+ out_path = str(out_path)
37
+ self.video.subclip(0, self.duration).write_videofile(
38
+ out_path, verbose=False, logger=None, **self._conf
39
+ )
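Typical use of the Video helper above (the frame directory is a placeholder; it should contain the frames written by the renderer, in sorted order):

from renderer.video import Video

vid = Video("renders/sample", fps=30.0, res="high")   # builds a clip from renders/sample/*
vid.save("renders/sample.mp4")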
requirements.txt CHANGED
@@ -5,4 +5,5 @@ transformers==4.41.2
5
  hydra-core
6
  einops
7
  roma
8
- aitviewer
 
 
5
  hydra-core
6
  einops
7
  roma
8
+ pyrender
9
+ moviepy