Spaces:
Running
Running
| import torch | |
| from transform3d import transform_body_pose, apply_rot_delta, get_z_rot, change_for | |
def diffout2motion(diffout, normalizer):
    """Turn normalized diffusion output into an unnormalized motion sequence.

    The features are unnormalized with ``normalizer``. Per-frame z-orientation
    deltas and pelvis-translation deltas are then integrated, starting from an
    identity z rotation and the unnormalized zero translation. The result is
    ``[full translation, global orientation (6D), body pose]`` concatenated
    along the last dimension.

    Args:
        diffout: Normalized model output. Presumably (B, T, D) with
            ``D == sum(normalizer.input_feats_dims)`` -- TODO confirm.
        normalizer: Object exposing ``input_feats`` (feature names),
            ``input_feats_dims`` (their sizes) and the ``uncat_inputs``,
            ``unnorm_inputs`` and ``cat_inputs`` helpers.

    Returns:
        Tensor concatenating, along the last dimension, the integrated
        translation, the global orientation converted with
        ``transform_body_pose(..., 'rot->6d')`` and the last ``21 * 6``
        unnormalized features (body pose).

    Raises:
        ValueError: if ``normalizer.input_feats`` lacks ``'z_orient_delta'`` or
            ``'body_transl_delta_pelv'``; both are required to rebuild the
            motion.

    Note:
        The slices below are hard-coded. They assume the unnormalized features
        (after dropping ``body_joints_local_wo_z_rot``) are laid out as:
        ``[0:3]`` pelvis translation delta, ``[3:9]`` xy orientation (6D),
        ``[9:15]`` z-orientation delta (6D), last ``21 * 6`` body pose.
        NOTE(review): confirm against the normalizer's feature configuration.
    """
    input_feats = normalizer.input_feats
    # Validate up front: every output tensor depends on these two features.
    if 'z_orient_delta' not in input_feats:
        raise ValueError("diffout2motion requires 'z_orient_delta' in "
                         "normalizer.input_feats")
    if 'body_transl_delta_pelv' not in input_feats:
        raise ValueError("diffout2motion requires 'body_transl_delta_pelv' "
                         "in normalizer.input_feats")

    # Undo normalization: split per feature, unnormalize each, re-concatenate.
    feats_unnorm = normalizer.cat_inputs(
        normalizer.unnorm_inputs(
            normalizer.uncat_inputs(diffout, normalizer.input_feats_dims),
            input_feats,
        )
    )[0]

    # Local joint features are not used to rebuild the motion. The code
    # assumes they are stored last and drops that many trailing entries.
    # An explicit stop index avoids the `[..., :-0]` empty-slice pitfall.
    if 'body_joints_local_wo_z_rot' in input_feats:
        idx = input_feats.index('body_joints_local_wo_z_rot')
        n_keep = feats_unnorm.shape[-1] - normalizer.input_feats_dims[idx]
        feats_unnorm = feats_unnorm[..., :n_keep]

    # Allocate on the same device/dtype as the features instead of 'cuda'.
    device, dtype = feats_unnorm.device, feats_unnorm.dtype
    batch_size = feats_unnorm.shape[0]

    # --- Global orientation -------------------------------------------------
    # Start from the identity rotation, shape (B, 3, 3), converted to 6D.
    first_orient_z = torch.eye(3, device=device, dtype=dtype)
    first_orient_z = first_orient_z.unsqueeze(0).repeat(batch_size, 1, 1)
    first_orient_z = transform_body_pose(first_orient_z, 'rot->6d')

    # Integrate z-orientation deltas frame by frame; the delta at frame 0 is
    # not used because frame 0 is the identity start.
    z_orient_delta = feats_unnorm[..., 9:15]
    prev_z = first_orient_z
    z_orient_steps = [first_orient_z[:, None]]
    for t in range(1, z_orient_delta.shape[1]):
        curr_z = apply_rot_delta(prev_z, z_orient_delta[:, t])
        # Clone so later calls cannot alias the tensor already stored above.
        prev_z = curr_z.clone()
        z_orient_steps.append(curr_z[:, None])
    full_z_angle = torch.cat(z_orient_steps, dim=1)
    full_z_angle_rotmat = get_z_rot(full_z_angle)

    xy_orient_rotmat = transform_body_pose(feats_unnorm[..., 3:9], '6d->rot')
    # Global orientation = R_z @ R_xy.
    full_global_orient_rotmat = full_z_angle_rotmat @ xy_orient_rotmat
    full_global_orient = transform_body_pose(full_global_orient_rotmat,
                                             'rot->6d')

    # --- Translation --------------------------------------------------------
    # First-frame translation: a zero (normalized) translation, unnormalized
    # with the 'body_transl' statistics. Shape mirrors diffout's leading dims
    # with a single frame.
    first_trans = torch.zeros(batch_size, 1, *diffout.shape[2:-1], 3,
                              device=device, dtype=dtype)
    first_trans = normalizer.cat_inputs(
        normalizer.unnorm_inputs([first_trans], ['body_transl'])
    )[0]

    # Map the delta of frame t (t >= 1) to the global frame using the
    # orientation of frame t-1, then accumulate onto the first translation.
    pelvis_delta = feats_unnorm[..., :3]
    trans_vel_pelv = change_for(pelvis_delta[:, 1:],
                                full_global_orient_rotmat[:, :-1],
                                forward=False)
    full_trans = torch.cumsum(trans_vel_pelv, dim=1) + first_trans
    full_trans = torch.cat([first_trans, full_trans], dim=1)

    # --- Assemble -------------------------------------------------------------
    full_rots = torch.cat([full_global_orient, feats_unnorm[..., -21 * 6:]],
                          dim=-1)
    return torch.cat([full_trans, full_rots], dim=-1)