Spaces:
Running
on
Zero
Running
on
Zero
| import spaces | |
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import trimesh | |
| import random | |
| from transformers import AutoModelForImageSegmentation | |
| from torchvision import transforms | |
| from huggingface_hub import hf_hub_download, snapshot_download, login | |
| import subprocess | |
| import shutil | |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on CUDA; CPU kernels for float16 are patchy, so
# fall back to float32 there.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
print("DEVICE: ", DEVICE)
DEFAULT_PART_FACE_NUMBER = 10000
MAX_SEED = np.iinfo(np.int32).max
HOLOPART_REPO_URL = "https://github.com/VAST-AI-Research/HoloPart"
HOLOPART_PRETRAINED_MODEL = "checkpoints/HoloPart"
# Scratch output directory, anchored next to this file rather than the CWD.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)
# NOTE(review): unlike TMP_DIR, this path is relative to the current working directory.
HOLOPART_CODE_DIR = "./holopart"
# Clone the HoloPart sources only if the checkout directory is missing, so a
# restart does not try to re-clone into an existing directory. An argument
# list avoids the shell, and check=True fails fast with a clear error instead
# of a confusing ImportError further down.
if not os.path.exists(HOLOPART_CODE_DIR):
    subprocess.run(["git", "clone", HOLOPART_REPO_URL, HOLOPART_CODE_DIR], check=True)
import sys
# Make the cloned package and its scripts/ helpers importable.
sys.path.append(HOLOPART_CODE_DIR)
sys.path.append(os.path.join(HOLOPART_CODE_DIR, "scripts"))
# Demo inputs shown under the UI: each row pairs a segmented mesh (.glb) with
# its preview image (.png), both taken from the cloned ./holopart checkout.
EXAMPLES = [
    [f"./holopart/assets/example_data/{idx:03d}{ext}" for ext in (".glb", ".png")]
    for idx in range(4)
]
# Markdown rendered at the top of the demo page. The embedded snippet is meant
# to be copied and run, so its loop body must stay indented. The blank lines
# stop CommonMark from treating the paragraph after the list as a lazy
# continuation of the last bullet.
HEADER = """
# 🔮 Decompose a 3D shape into complete parts with [HoloPart](https://github.com/VAST-AI-Research/HoloPart).

### Step 1: Prepare Your Segmented Mesh

Upload a mesh with part segmentation. We recommend using these segmentation tools:

- [SAMPart3D](https://github.com/Pointcept/SAMPart3D)
- [SAMesh](https://github.com/gtangg12/samesh)

For a mesh file `mesh.glb` and corresponding face mask `mask.npy`, prepare your input using this Python code:

```python
import trimesh
import numpy as np

mesh = trimesh.load("mesh.glb", force="mesh")
mask_npy = np.load("mask.npy")

mesh_parts = []
for part_id in np.unique(mask_npy):
    mesh_part = mesh.submesh([mask_npy == part_id], append=True)
    mesh_parts.append(mesh_part)

trimesh.Scene(mesh_parts).export("input_mesh.glb")
```

The resulting **input_mesh.glb** is your prepared input for HoloPart.

### Step 2: Click the Decompose Parts button to begin the decomposition process.
"""
| from inference_holopart import prepare_data, run_holopart | |
| from holopart.pipelines.pipeline_holopart import HoloPartPipeline | |
# Fetch the pretrained HoloPart weights into the local checkpoint folder, then
# build the pipeline once at import time and move it to the configured
# device/dtype. The request handlers reuse this single instance.
snapshot_download("VAST-AI/HoloPart", local_dir=HOLOPART_PRETRAINED_MODEL)
holopart_pipe = HoloPartPipeline.from_pretrained(HOLOPART_PRETRAINED_MODEL)
holopart_pipe = holopart_pipe.to(DEVICE, DTYPE)
def start_session(req: gr.Request):
    """Create the scratch directory for the requesting session under TMP_DIR.

    Args:
        req: Gradio request; its ``session_hash`` names the directory.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)  # harmless if it already exists
    print("start session, mkdir", session_dir)
def end_session(req: gr.Request):
    """Delete the scratch directory of a session that is closing.

    Cleanup is best-effort. The directory may already be gone, for example if
    unload fires twice or the session never created it. Raising here would
    only surface as an error inside Gradio's unload handler.

    Args:
        req: Gradio request; its ``session_hash`` names the directory.
    """
    save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(save_dir, ignore_errors=True)
def get_random_hex():
    """Return a 16-character lowercase hex string built from 8 bytes of OS randomness.

    Used to give exported files collision-resistant names.
    """
    return os.urandom(8).hex()
def get_random_seed(randomize_seed, seed):
    """Pick the seed to use for a run.

    Args:
        randomize_seed: When truthy, draw a new seed uniformly from [0, MAX_SEED].
        seed: Value returned unchanged when ``randomize_seed`` is falsy.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def explode_mesh(mesh: trimesh.Scene, explode_factor: float = 0.5):
    """Build an "exploded view" of a part scene.

    Each placed geometry is moved by ``explode_factor`` (in scene units) along
    the unit vector from the scene centroid to the mean of that part's
    world-space vertices. A part whose mean coincides with the centroid stays put.

    Args:
        mesh: Scene whose geometries are the decomposed parts.
        explode_factor: Distance each part is translated, in scene units.

    Returns:
        A new ``trimesh.Scene``. The input scene's graph is not modified, and
        geometry objects are passed through as-is rather than copied here.
    """
    center = mesh.centroid
    exploded_mesh = trimesh.Scene()
    # Iterate scene-graph nodes rather than geometry keys. Node names need not
    # equal geometry names (e.g. loaded GLBs), and one geometry may be placed
    # by several nodes. graph[node] yields (world transform, geometry name).
    for node_name in mesh.graph.nodes_geometry:
        transform, geometry_name = mesh.graph[node_name]
        geometry = mesh.geometry[geometry_name]
        vertices_global = trimesh.transformations.transform_points(
            geometry.vertices, transform)
        direction = np.mean(vertices_global, axis=0) - center
        direction_length = np.linalg.norm(direction)
        if direction_length > 0:
            direction = direction / direction_length
        new_transform = np.copy(transform)
        # Shift the world-space translation; rotation/scale stay untouched.
        new_transform[:3, 3] += direction * explode_factor
        exploded_mesh.add_geometry(
            geometry,
            transform=new_transform,
            geom_name=geometry_name,
            node_name=node_name,
        )
    return exploded_mesh
def _decompose_and_export(data_path, seed, num_inference_steps, guidance_scale, explode_factor, save_dir):
    """Run HoloPart on one segmented mesh and export the results as GLB files.

    This is the shared body of ``run_full`` and ``run_example``.

    Args:
        data_path: Path of the segmented input mesh handed to ``prepare_data``.
        seed: Sampling seed (cast to int).
        num_inference_steps: Diffusion steps (cast to int).
        guidance_scale: Classifier-free guidance scale.
        explode_factor: Distance parts are pushed apart in the exploded view.
        save_dir: Directory that receives both GLBs (created if missing).

    Returns:
        ``(mesh_path, exploded_mesh_path)`` for the assembled parts and the
        exploded view.

    Raises:
        gr.Error: If no input mesh was provided.
    """
    if not data_path:
        # Show a readable message in the UI instead of a traceback from prepare_data.
        raise gr.Error("Please upload a segmented mesh before decomposing.")
    try:
        parts_data = prepare_data(data_path)
        part_scene = run_holopart(
            holopart_pipe,
            batch=parts_data,
            batch_size=30,
            # Gradio sliders may deliver floats; presumably the pipeline wants ints here.
            seed=int(seed),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            num_chunks=1000000,
        )
        print("mesh extraction done")
        os.makedirs(save_dir, exist_ok=True)
        mesh_path = os.path.join(save_dir, f"holorpart_{get_random_hex()}.glb")
        part_scene.export(mesh_path)
        print("save to ", mesh_path)
        exploded_mesh = explode_mesh(part_scene, explode_factor)
        exploded_mesh_path = os.path.join(save_dir, f"holorpart_exploded_{get_random_hex()}.glb")
        exploded_mesh.export(exploded_mesh_path)
    finally:
        # Release cached GPU memory even if inference or export failed.
        torch.cuda.empty_cache()
    return mesh_path, exploded_mesh_path


# NOTE(review): the default ZeroGPU duration applies; raise it via
# @spaces.GPU(duration=...) if large part counts time out.
@spaces.GPU
def run_full(data_path, seed=42, num_inference_steps=25, guidance_scale=3.5, req: gr.Request = None):
    """Gradio handler for the "Decompose Parts" button.

    Outputs go to the caller's per-session directory ``TMP_DIR/<session_hash>``,
    the same directory ``start_session``/``end_session`` create and delete.
    Calls without a request (e.g. direct Python calls) fall back to
    ``TMP_DIR/examples``. Parts are exploded by 0.7 scene units.

    Returns:
        ``(mesh_path, exploded_mesh_path)``.
    """
    session_hash = getattr(req, "session_hash", None)
    save_dir = os.path.join(TMP_DIR, str(session_hash) if session_hash else "examples")
    return _decompose_and_export(
        data_path, seed, num_inference_steps, guidance_scale, 0.7, save_dir
    )


@spaces.GPU
def run_example(data_path: str, example_image_path, seed=42, num_inference_steps=25, guidance_scale=3.5):
    """Handler used by ``gr.Examples`` (cached examples).

    ``example_image_path`` is accepted only because the examples table has a
    preview-image column; it is not used. Outputs go to the shared
    ``TMP_DIR/examples`` directory, and parts are exploded by 0.5 scene units.

    Returns:
        ``(mesh_path, exploded_mesh_path)``.
    """
    save_dir = os.path.join(TMP_DIR, "examples")
    return _decompose_and_export(
        data_path, seed, num_inference_steps, guidance_scale, 0.5, save_dir
    )
# Page layout: input + settings on the left, decomposed / exploded results on
# the right, cached examples below.
with gr.Blocks(title="HoloPart") as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_mesh = gr.Model3D(label="Input Mesh")
                # Hidden; only carries the preview-image column of EXAMPLES.
                example_image = gr.Image(label="Example Image", type="filepath", interactive=False, visible=False)
            with gr.Accordion("Generation Settings", open=True):
                # Seeds are integers, so the slider must step by 1
                # (a step of 0 is not a valid slider increment).
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=8,
                    maximum=50,
                    step=1,
                    value=25,
                )
                guidance_scale = gr.Slider(
                    label="CFG scale",
                    minimum=0.0,
                    maximum=20.0,
                    step=0.1,
                    value=3.5,
                )
                with gr.Row():
                    # NOTE(review): display-only; this checkbox is not wired to any handler.
                    reduce_face = gr.Checkbox(label="Simplify Mesh", value=True, interactive=False)
            gen_button = gr.Button("Decompose Parts", variant="primary")
        with gr.Column():
            model_output = gr.Model3D(label="Decomposed GLB", interactive=False)
            exploded_parts_output = gr.Model3D(label="Exploded Parts", interactive=False)
    with gr.Row():
        examples = gr.Examples(
            examples=EXAMPLES,
            fn=run_example,
            inputs=[input_mesh, example_image],
            outputs=[model_output, exploded_parts_output],
            cache_examples=True,
        )
    gen_button.click(
        run_full,
        inputs=[
            input_mesh,
            seed,
            num_inference_steps,
            guidance_scale
        ],
        outputs=[model_output, exploded_parts_output],
    )
    # Per-session scratch directory lifecycle.
    demo.load(start_session)
    demo.unload(end_session)
demo.launch()