| import torch |
| import base64 |
| import os |
| from PIL import Image |
| from io import BytesIO |
| from trellis.pipelines import TrellisImageTo3DPipeline |
| from trellis.utils import postprocessing_utils |
|
|
| class EndpointHandler: |
| def __init__(self, model_dir): |
| |
| self.pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large") |
| self.pipeline.cuda() |
| |
| def __call__(self, data): |
| """ |
| Args: |
| data (:obj:`dict`): |
| - "inputs": The base64 encoded image or URL. |
| - "params": Dictionary of generation parameters (optional). |
| """ |
| inputs = data.pop("inputs", data) |
| params = data.pop("parameters", {}) |
|
|
| |
| image = Image.open(BytesIO(base64.b64decode(inputs))) |
| |
| |
| |
| outputs = self.pipeline( |
| image, |
| num_samples=1, |
| return_flags=["mesh"], |
| **params |
| ) |
|
|
| |
| mesh = outputs['mesh'][0] |
| glb_io = BytesIO() |
| mesh.export(glb_io, file_type='glb') |
| glb_io.seek(0) |
|
|
| |
| return { |
| "mesh_base64": base64.b64encode(glb_io.getvalue()).decode("utf-8"), |
| "format": "glb" |
| } |
|
|