import importlib
import logging
import subprocess
import sys
from functools import partial

import kiui
import numpy as np
import torch
from kiui.op import recenter
from PIL import Image
from rembg import new_session, remove
from torchvision.utils import save_image
from transformers import BaseImageProcessorFast
class LRMImageProcessor(BaseImageProcessorFast):
    """Image preprocessor for LRM-style single-image 3D reconstruction.

    Given one input image, it removes the background with ``rembg``,
    recenters the foreground with ``kiui.op.recenter``, resizes to a square
    ``source_size`` image, and pairs the result with a canonical source
    camera embedding (flattened 3x4 extrinsics + normalized intrinsics).
    """

    def __init__(self, source_size=512, *args, **kwargs):
        """
        Args:
            source_size: Side length (pixels) of the square output image.
                Defaults to 512.
        """
        super().__init__(*args, **kwargs)
        self.source_size = source_size
        # The rembg session is created lazily on first use (see
        # _initialize_session) because new_session() loads the segmentation
        # model, which is expensive at construction time.
        self.session = None
        self.rembg_remove = None

    def _initialize_session(self):
        """Lazily create the rembg session and bind the removal callable."""
        if self.session is None:
            self.session = new_session("isnet-general-use")
            self.rembg_remove = partial(remove, session=self.session)

    def preprocess_image(self, image):
        """Background-remove, recenter, and resize a single image.

        Args:
            image: A PIL image (or anything ``np.array`` accepts) in HWC
                layout.

        Returns:
            Float tensor of shape ``(1, 3, source_size, source_size)`` with
            values clamped to ``[0, 1]``.
        """
        self._initialize_session()
        image = np.array(image)
        image = self.rembg_remove(image)
        # NOTE(review): this runs the segmentation network a second time on
        # the already-cut-out image just to obtain the mask; the alpha
        # channel of `image` could likely serve as the mask instead —
        # confirm equivalence before changing.
        mask = self.rembg_remove(image, only_mask=True)
        image = recenter(image, mask, border_ratio=0.20)
        # HWC uint8 -> NCHW float in [0, 1].
        image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) / 255.0
        if image.shape[1] == 4:
            # Composite RGBA over a white background: rgb * a + (1 - a).
            image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
        image = torch.nn.functional.interpolate(
            image,
            size=(self.source_size, self.source_size),
            mode='bicubic',
            align_corners=True,
        )
        image = torch.clamp(image, 0, 1)
        return image

    def get_normalized_camera_intrinsics(self, intrinsics: torch.Tensor):
        """Split and normalize packed camera intrinsics.

        Args:
            intrinsics: Tensor of shape ``(B, 3, 2)`` packed as
                ``[[fx, fy], [cx, cy], [width, height]]``
                (the layout produced by ``_default_intrinsics``).

        Returns:
            Tuple ``(fx, fy, cx, cy)`` of shape-``(B,)`` tensors, with
            focal lengths and principal point divided by width/height.
        """
        fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
        cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
        width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
        fx, fy = fx / width, fy / height
        cx, cy = cx / width, cy / height
        return fx, fy, cx, cy

    def build_camera_principle(self, RT: torch.Tensor, intrinsics: torch.Tensor):
        """Build the source-camera embedding.

        Args:
            RT: ``(B, 3, 4)`` camera extrinsics, flattened to 12 values.
            intrinsics: ``(B, 3, 2)`` packed intrinsics (see
                ``get_normalized_camera_intrinsics``).

        Returns:
            ``(B, 16)`` tensor: flattened extrinsics followed by
            ``fx, fy, cx, cy``.
        """
        fx, fy, cx, cy = self.get_normalized_camera_intrinsics(intrinsics)
        return torch.cat([
            RT.reshape(-1, 12),
            fx.unsqueeze(-1),
            fy.unsqueeze(-1),
            cx.unsqueeze(-1),
            cy.unsqueeze(-1),
        ], dim=-1)

    def _default_intrinsics(self):
        """Return the default packed intrinsics ``(3, 2)`` tensor.

        Layout: ``[[fx, fy], [cx, cy], [w, h]]`` for a 512x512 image with
        focal length 384 and a centered principal point.
        """
        fx = fy = 384
        cx = cy = 256
        w = h = 512
        intrinsics = torch.tensor([
            [fx, fy],
            [cx, cy],
            [w, h],
        ], dtype=torch.float32)
        return intrinsics

    def _default_source_camera(self, batch_size: int = 1):
        """Build the canonical source-camera embedding for a batch.

        Args:
            batch_size: Number of rows to replicate the embedding into.

        Returns:
            ``(batch_size, 16)`` tensor (see ``build_camera_principle``).
        """
        # 3x4 [R|t] extrinsics of the fixed canonical camera.
        # (A previously unused `dist_to_center = 1.5` local was removed;
        # the translation below is hard-coded and was never derived from it.)
        canonical_camera_extrinsics = torch.tensor([[
            [0, 0, 1, 1],
            [1, 0, 0, 0],
            [0, 1, 0, 0],
        ]], dtype=torch.float32)
        canonical_camera_intrinsics = self._default_intrinsics().unsqueeze(0)
        source_camera = self.build_camera_principle(
            canonical_camera_extrinsics, canonical_camera_intrinsics
        )
        return source_camera.repeat(batch_size, 1)

    def __call__(self, image, *args, **kwargs):
        """Process one image into model inputs.

        Extra positional/keyword arguments are accepted for API
        compatibility but ignored.

        Returns:
            Tuple ``(processed_image, source_camera)`` where
            ``processed_image`` is ``(1, 3, source_size, source_size)`` and
            ``source_camera`` is ``(1, 16)``.
        """
        processed_image = self.preprocess_image(image)
        source_camera = self._default_source_camera(batch_size=1)
        return processed_image, source_camera