# InstantMesh inference script: reconstructs a textured 3D mesh (and an
# optional turntable video) from a single multi-view PNG input.
import os
import argparse
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import v2
from omegaconf import OmegaConf
from einops import rearrange
from tqdm import tqdm
from huggingface_hub import hf_hub_download
import sys

# Make the vendored InstantMesh submodule importable as `src.*`.
script_dir = os.path.dirname(os.path.abspath(__file__))
submodule_path = os.path.join(script_dir, "..", "external", "instant-mesh")
sys.path.insert(0, submodule_path)

# InstantMesh helpers — resolved via the sys.path insertion above.
from src.utils.camera_util import (
    get_circular_camera_poses,
    get_zero123plus_input_cameras,
    FOV_to_intrinsics,
)
from src.utils.train_util import instantiate_from_config
from src.utils.mesh_util import save_obj
from src.utils.infer_util import save_video
def get_render_cameras(
    batch_size=1, M=120, radius=4.0, elevation=20.0, is_flexicubes=False, fovy=30.0
):
    """Build camera tensors for rendering ``M`` views on a circular orbit.

    Args:
        batch_size: Number of scenes rendered in parallel.
        M: Number of camera positions along the orbit.
        radius: Distance of each camera from the origin.
        elevation: Camera elevation angle in degrees.
        is_flexicubes: If True, return 4x4 world-to-camera matrices (the
            format the FlexiCubes geometry path consumes); otherwise return
            flattened extrinsics concatenated with flattened intrinsics.
        fovy: Vertical field of view in degrees used to build the intrinsics.
            Ignored when ``is_flexicubes`` is True. Defaults to 30.0, which
            was previously hard-coded.

    Returns:
        Tensor of shape ``(batch_size, M, 4, 4)`` when ``is_flexicubes`` is
        True, otherwise ``(batch_size, M, 25)`` (16 extrinsic + 9 intrinsic
        values per view).
    """
    c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
    if is_flexicubes:
        # FlexiCubes consumes world-to-camera matrices, so invert the c2w poses.
        cameras = torch.linalg.inv(c2ws)
        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
    else:
        # Synthesizer path: flatten each 4x4 extrinsic and 3x3 intrinsic,
        # then concatenate per view.
        extrinsics = c2ws.flatten(-2)
        intrinsics = (
            FOV_to_intrinsics(fovy).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
        )
        cameras = torch.cat([extrinsics, intrinsics], dim=-1)
        cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
    return cameras
def render_frames(
    model, planes, render_cameras, render_size=512, chunk_size=1, is_flexicubes=False
):
    """Render a sequence of frames from triplane features.

    Cameras are processed in chunks of ``chunk_size`` along dim 1 to bound
    peak memory; the chunks are concatenated and the batch dim is dropped.
    """
    num_views = render_cameras.shape[1]
    chunks = []
    for start in tqdm(range(0, num_views, chunk_size)):
        cams = render_cameras[:, start : start + chunk_size]
        if is_flexicubes:
            # FlexiCubes geometry path returns rendered images under "img".
            rendered = model.forward_geometry(planes, cams, render_size=render_size)["img"]
        else:
            # Triplane synthesizer path returns them under "images_rgb".
            rendered = model.forward_synthesizer(planes, cams, render_size=render_size)["images_rgb"]
        chunks.append(rendered)
    return torch.cat(chunks, dim=1)[0]
def main(args):
    """Run single-image 3D reconstruction end to end.

    Loads the model config, instantiates and restores the reconstruction
    network, preprocesses the multi-view input PNG, extracts a textured mesh,
    and optionally renders a turntable video.

    Args:
        args: Parsed CLI namespace with fields ``config``, ``input_file``,
            ``output_dir``, ``scale``, ``distance`` and ``save_video``.
    """
    # ============================
    # CONFIG
    # ============================
    print("🚀 Starting 3D mesh generation...")
    config = OmegaConf.load(args.config)
    config_name = os.path.basename(args.config).replace(".yaml", "")
    model_config = config.model_config
    infer_config = config.infer_config
    # instant-mesh-* configs use the FlexiCubes geometry pipeline; the other
    # configs use the triplane synthesizer.
    IS_FLEXICUBES = config_name.startswith("instant-mesh")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ============================
    # SETUP OUTPUT DIRECTORY
    # ============================
    os.makedirs(args.output_dir, exist_ok=True)
    mesh_path = os.path.join(args.output_dir, "recon.obj")
    video_path = os.path.join(args.output_dir, "recon.mp4")

    # ============================
    # LOAD RECONSTRUCTION MODEL
    # ============================
    print("Loading reconstruction model...")
    model = instantiate_from_config(model_config)
    # Use the local checkpoint if present, otherwise fetch it from the Hub
    # (e.g. "instant-mesh-large" -> "instant_mesh_large.ckpt").
    model_ckpt_path = (
        infer_config.model_path
        if os.path.exists(infer_config.model_path)
        else hf_hub_download(
            repo_id="TencentARC/InstantMesh",
            filename=f"{config_name.replace('-', '_')}.ckpt",
            repo_type="model",
        )
    )
    # Keep only the LRM generator weights, stripping the "lrm_generator."
    # prefix (14 chars) so keys match the bare model.
    state_dict = torch.load(model_ckpt_path, map_location="cpu")["state_dict"]
    state_dict = {
        k[14:]: v for k, v in state_dict.items() if k.startswith("lrm_generator.")
    }
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device).eval()
    if IS_FLEXICUBES:
        model.init_flexicubes_geometry(device, fovy=30.0)

    # ============================
    # PREPARE DATA
    # ============================
    print(f"Processing input file: {args.input_file}")
    # Load the input image and normalize pixel values to [0, 1].
    input_image = Image.open(args.input_file).convert("RGB")
    images = np.asarray(input_image, dtype=np.float32) / 255.0
    images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()
    # The input PNG is a 3x2 grid of views; split it into 6 separate views:
    # (C, 3h, 2w) -> (6, C, h, w).
    images = rearrange(images, "c (n h) (m w) -> (n m) c h w", n=3, m=2)
    images = images.unsqueeze(0).to(device)
    # interpolation=3 is bicubic resampling.
    images = v2.functional.resize(images, size=320, interpolation=3, antialias=True).clamp(0, 1)
    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0 * args.scale).to(device)

    # ============================
    # RUN INFERENCE AND SAVE OUTPUT
    # ============================
    with torch.no_grad():
        # Lift the views into triplane features, then extract the mesh.
        planes = model.forward_planes(images, input_cameras)
        mesh_out = model.extract_mesh(planes, use_texture_map=False, **infer_config)

        vertices, faces, vertex_colors = mesh_out
        save_obj(vertices, faces, vertex_colors, mesh_path)
        print(f"✅ Mesh saved to {mesh_path}")

        # Render and save a turntable video unless --no_video was passed.
        if args.save_video:
            print("🎥 Rendering video...")
            render_size = infer_config.render_resolution
            # FlexiCubes rendering is cheap enough to batch 20 views at once.
            chunk_size = 20 if IS_FLEXICUBES else 1
            render_cameras = get_render_cameras(
                batch_size=1,
                M=120,
                radius=args.distance,
                elevation=20.0,
                is_flexicubes=IS_FLEXICUBES,
            ).to(device)
            frames = render_frames(
                model=model,
                planes=planes,
                render_cameras=render_cameras,
                render_size=render_size,
                chunk_size=chunk_size,
                is_flexicubes=IS_FLEXICUBES,
            )
            save_video(frames, video_path, fps=30)
            print(f"🎬 Video saved to {video_path}")
    print("✨ Process complete.")
| if __name__ == "__main__": | |
| # ============================ | |
| # SCRIPT ARGUMENTS | |
| # ============================ | |
| parser = argparse.ArgumentParser( | |
| description="Generate a 3D mesh and video from a single multi-view PNG file." | |
| ) | |
| # Positional argument for config file | |
| parser.add_argument( | |
| "config", | |
| type=str, | |
| help="Path to the model config file (.yaml)." | |
| ) | |
| # Required file paths | |
| parser.add_argument( | |
| "--input_file", | |
| type=str, | |
| required=True, | |
| help="Path to the input PNG file." | |
| ) | |
| parser.add_argument( | |
| "--output_dir", | |
| type=str, | |
| default="outputs/", | |
| help="Directory to save the output .obj and .mp4 files. Defaults to 'outputs/'." | |
| ) | |
| # Optional parameters for model and rendering | |
| parser.add_argument( | |
| "--scale", | |
| type=float, | |
| default=1.0, | |
| help="Scale of the input cameras." | |
| ) | |
| parser.add_argument( | |
| "--distance", | |
| type=float, | |
| default=4.5, | |
| help="Camera distance for rendering the output video." | |
| ) | |
| parser.add_argument( | |
| "--no_video", | |
| dest="save_video", | |
| action="store_false", | |
| help="If set, disables saving the output .mp4 video." | |
| ) | |
| parsed_args = parser.parse_args() | |
| main(parsed_args) | |