# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
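
# Inference / evaluation script for SkeletonGPT: generates a skeleton for each
# input mesh, post-processes it (merging duplicate joints and bones), scores it
# against the ground truth with J2J/J2B/B2B Chamfer distances, and saves
# skeletons (.obj/.txt), meshes, and optional renderings.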

import os
import math
import time
import argparse

import torch
import numpy as np
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed, DistributedDataParallelKwargs

from skeleton_models.skeletongen import SkeletonGPT
from utils.skeleton_data_loader import SkeletonData
from utils.save_utils import (
    save_mesh,
    pred_joints_and_bones,
    save_skeleton_to_txt,
    save_skeleton_to_txt_joint,
    save_args,
    merge_duplicate_joints_and_fix_bones,
    save_skeleton_obj,
    render_mesh_with_skeleton,
)
from utils.eval_utils import chamfer_dist, joint2bone_chamfer_dist, bone2bone_chamfer_dist


def get_args():
    parser = argparse.ArgumentParser("SkeletonGPT", add_help=False)
    parser.add_argument("--input_pc_num", default=8192, type=int, help="number of input point-cloud points")
    parser.add_argument("--num_beams", default=1, type=int, help="beam width for generation")
    parser.add_argument("--llm", default="facebook/opt-350m", type=str, help="the LLM backend")
    parser.add_argument("--pad_id", default=-1, type=int, help="padding id")
    parser.add_argument("--n_discrete_size", default=128, type=int, help="size of the discretized 3D space")
    parser.add_argument("--n_max_bones", default=100, type=int, help="max number of bones")
    parser.add_argument("--dataset_path", default="Articulation_xlv2.npz", type=str, help="data path")
    parser.add_argument("--output_dir", default="outputs", type=str)
    parser.add_argument("--save_name", default="infer_results", type=str)
    parser.add_argument("--save_render", default=False, action="store_true", help="save renderings of the mesh with its skeleton")
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--precision", default="fp16", type=str)
    parser.add_argument("--batchsize_per_gpu", default=1, type=int)
    parser.add_argument("--pretrained_weights", default=None, type=str, help="path to pretrained model weights")
    parser.add_argument("--hier_order", default=False, action="store_true", help="use hierarchical bone order")
    parser.add_argument("--joint_token", default=False, action="store_true", help="use joint-based tokenization")
    parser.add_argument("--seq_shuffle", default=False, action="store_true", help="shuffle the skeleton sequence")
    args = parser.parse_args()
    return args
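
# Example invocation (the script name and checkpoint path are illustrative,
# not fixed by this repo):
#   python infer.py --dataset_path Articulation_xlv2.npz \
#       --pretrained_weights ckpts/skeletongpt.pth --hier_order --save_render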


if __name__ == "__main__":
    args = get_args()

    dataset = SkeletonData.load(args, is_training=False)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        drop_last=False,
        shuffle=False,
    )

    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        kwargs_handlers=[kwargs],
        mixed_precision=args.precision,
    )
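    # Accelerator handles device placement and autocast; --precision (fp16 by
    # default) selects the mixed-precision mode used during generation.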

    model = SkeletonGPT(args).cuda()
    if args.pretrained_weights is not None:
        pkg = torch.load(args.pretrained_weights, map_location=torch.device("cpu"))
        model.load_state_dict(pkg["model"])
    else:
        raise ValueError("Pretrained weights must be provided.")

    set_seed(args.seed)
    dataloader, model = accelerator.prepare(
        dataloader,
        model,
    )
    model.eval()

    output_dir = f'{args.output_dir}/{args.save_name}'
    print(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    save_args(args, output_dir)

    gt_samples, pred_samples = [], []
    avg_j2j_cd, avg_j2b_cd, avg_b2b_cd = 0.0, 0.0, 0.0
    infer_all_time = []
    num_valid = 0
    results_file = f'{output_dir}/evaluate_results.txt'
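
    # One mesh per batch (batch_size=1): generate a skeleton, post-process it,
    # score it against the ground truth, and log per-mesh metrics.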
    for curr_iter, batch_data_label in tqdm(enumerate(dataloader), total=len(dataloader)):
        start_time = time.time()
        with accelerator.autocast():
            pred_bone_coords = model.generate(batch_data_label)
        infer_time_per_mesh = time.time() - start_time
        infer_all_time.append(infer_time_per_mesh)

        if pred_bone_coords is None:
            continue
        print(pred_bone_coords.shape)  # debug: shape of the generated bone-coordinate tensor

        if pred_bone_coords.shape[1] > 0:
            gt_joints = batch_data_label['joints'].squeeze(0).cpu().numpy()
            gt_bones = batch_data_label['bones'].squeeze(0).cpu().numpy()
            pred_joints, pred_bones = pred_joints_and_bones(pred_bone_coords.cpu().numpy().squeeze(0))
            if pred_bones.shape[0] == 0:
                continue

            # Post-process: merge duplicate or nearby joints and deduplicate bones.
            if args.hier_order:  # for MagicArticulate hierarchical order
                pred_root_index = pred_bones[0][0]
                pred_joints, pred_bones, pred_root_index = merge_duplicate_joints_and_fix_bones(
                    pred_joints, pred_bones, root_index=pred_root_index)
            else:  # for Puppeteer or MagicArticulate spatial order
                pred_joints, pred_bones = merge_duplicate_joints_and_fix_bones(pred_joints, pred_bones)
                pred_root_index = None

            gt_root_index = int(batch_data_label['root_index'][0])
            # Also merge duplicate joints/bones for the GT, to prevent NaNs in the CD computation.
            gt_joints, gt_bones, gt_root_index = merge_duplicate_joints_and_fix_bones(
                gt_joints, gt_bones, root_index=gt_root_index)
            if gt_bones.shape[0] == 0 or pred_bones.shape[0] == 0:
                continue

            # Calculate Chamfer distances (CD).
            j2j_cd = chamfer_dist(pred_joints, gt_joints)
            j2b_cd = joint2bone_chamfer_dist(pred_joints, pred_bones, gt_joints, gt_bones)
            b2b_cd = bone2bone_chamfer_dist(pred_joints, pred_bones, gt_joints, gt_bones)
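            # J2J compares the joint sets directly; J2B and B2B additionally take
            # the bone segments into account (see utils/eval_utils for the exact
            # definitions).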

            if math.isnan(j2j_cd) or math.isnan(j2b_cd) or math.isnan(b2b_cd):
                print("NaN cd")
            else:
                num_valid += 1
                avg_j2j_cd += j2j_cd
                avg_j2b_cd += j2b_cd
                avg_b2b_cd += b2b_cd

            print(f"For {batch_data_label['uuid'][0]}, J2J Chamfer Distance: {j2j_cd:.7f}, J2B Chamfer Distance: {j2b_cd:.7f}, B2B Chamfer Distance: {b2b_cd:.7f}, infer time: {infer_time_per_mesh:.7f}")
            with open(results_file, 'a') as f:
                f.write(f"For {batch_data_label['uuid'][0]}, J2J Chamfer Distance: {j2j_cd:.7f}, J2B Chamfer Distance: {j2b_cd:.7f}, B2B Chamfer Distance: {b2b_cd:.7f}, infer time: {infer_time_per_mesh:.7f}\n")

            if len(gt_samples) < 30:  # only save the first 30 results; raise the cap to save more
                pred_samples.append((pred_joints, pred_bones, pred_root_index))
                gt_samples.append((gt_joints, gt_bones, batch_data_label['vertices'][0], batch_data_label['faces'][0],
                                   batch_data_label['transform_params'][0], batch_data_label['uuid'][0], gt_root_index))

    # Report averages over all valid (non-NaN) generations.
    if num_valid > 0:
        with open(results_file, 'a') as f:
            f.write(f"Average J2J Chamfer Distance: {avg_j2j_cd/num_valid:.7f}\n")
            f.write(f"Average J2B Chamfer Distance: {avg_j2b_cd/num_valid:.7f}\n")
            f.write(f"Average B2B Chamfer Distance: {avg_b2b_cd/num_valid:.7f}\n")
            f.write(f"Average inference time: {np.mean(infer_all_time):.7f}\n")
        print(f"Valid generation: {num_valid}, Average J2J Chamfer Distance: {avg_j2j_cd/num_valid:.7f}, "
              f"average J2B Chamfer Distance: {avg_j2b_cd/num_valid:.7f}, "
              f"average B2B Chamfer Distance: {avg_b2b_cd/num_valid:.7f}, "
              f"average infer time: {np.mean(infer_all_time):.7f}")
    else:
        print("No valid generations; skipping average metrics.")

    # save results
    for i, ((pred_joints, pred_bones, pred_root_index),
            (gt_joints, gt_bones, vertices, faces, transform_params, file_name, gt_root_index)) in enumerate(
                zip(pred_samples, gt_samples)):
        pred_skel_filename = f'{output_dir}/{file_name}_skel_pred.obj'
        gt_skel_filename = f'{output_dir}/{file_name}_skel_gt.obj'
        mesh_filename = f'{output_dir}/{file_name}.obj'
        pred_rig_filename = f'{output_dir}/{file_name}_pred.txt'

        vertices = vertices.cpu().numpy()
        faces = faces.cpu().numpy()
        trans = transform_params[:3].cpu().numpy()
        scale = transform_params[3].cpu().numpy()
        pc_trans = transform_params[4:7].cpu().numpy()
        pc_scale = transform_params[7].cpu().numpy()
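        # transform_params packs [mesh translation (3), mesh scale (1),
        # point-cloud translation (3), point-cloud scale (1)]; the two
        # (trans, scale) pairs invert the mesh and point-cloud normalizations.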

        # Save the skeleton to .txt, denormalizing it so it aligns with the input mesh.
        pred_joints_denorm = pred_joints * pc_scale + pc_trans  # first align with the point cloud
        pred_joints_denorm = pred_joints_denorm / scale + trans  # then align with the original mesh
        if args.joint_token:
            pred_root_index = save_skeleton_to_txt_joint(pred_joints_denorm, pred_bones, pred_rig_filename)
        else:
            save_skeleton_to_txt(pred_joints_denorm, pred_bones, pred_root_index, args.hier_order, vertices, pred_rig_filename)

        # Save skeletons as .obj.
        if args.hier_order or args.joint_token:
            save_skeleton_obj(pred_joints, pred_bones, pred_skel_filename, pred_root_index, use_cone=True)
        else:
            save_skeleton_obj(pred_joints, pred_bones, pred_skel_filename, use_cone=False)
        save_skeleton_obj(gt_joints, gt_bones, gt_skel_filename, gt_root_index, use_cone=True)

        # Save the mesh; when saving and rendering, use normalized vertices in (-0.5, 0.5).
        vertices_norm = (vertices - trans) * scale
        vertices_norm = (vertices_norm - pc_trans) / pc_scale
        save_mesh(vertices_norm, faces, mesh_filename)

        # Render the mesh with its skeleton.
        if args.save_render:
            if args.hier_order or args.joint_token:
                render_mesh_with_skeleton(pred_joints, pred_bones, vertices_norm, faces, output_dir, file_name, prefix='pred', root_idx=pred_root_index)
            else:
                render_mesh_with_skeleton(pred_joints, pred_bones, vertices_norm, faces, output_dir, file_name, prefix='pred')
            render_mesh_with_skeleton(gt_joints, gt_bones, vertices_norm, faces, output_dir, file_name, prefix='gt', root_idx=gt_root_index)