""" |
|
|
Input processor for Depth Anything 3 (parallelized).

This version removes the square center-crop step for "*crop" methods; images are
cropped only as much as patch-size divisibility requires.

In addition, it parallelizes per-image preprocessing using the provided
`parallel_execution` helper.
"""

from __future__ import annotations

import cv2
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

from depth_anything_3.utils.logger import logger
from depth_anything_3.utils.parallel_utils import parallel_execution


class InputProcessor:
    """Prepares a batch of images for model inference.

    This processor converts a list of images (file paths, NumPy arrays, or PIL
    images) into a single, model-ready tensor. The processing pipeline is executed
    in parallel across multiple workers for efficiency.

    Pipeline:
        1) Load image and convert to RGB
        2) Boundary resize (upper/lower bound, preserving aspect ratio)
        3) Enforce divisibility by PATCH_SIZE:
           - "*resize" methods: each dimension is rounded to the nearest multiple
             (may up/downscale by a few px)
           - "*crop" methods: each dimension is floored to the nearest multiple via center crop
        4) Convert to tensor and apply ImageNet normalization
        5) Stack into (1, N, 3, H, W)

    Parallelization:
        - Each image is processed independently in a worker.
        - Order of outputs matches the input order.
    """

    NORMALIZE = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    PATCH_SIZE = 14

    def __init__(self):
        pass

    def __call__(
        self,
        image: list[np.ndarray | Image.Image | str],
        extrinsics: np.ndarray | None = None,
        intrinsics: np.ndarray | None = None,
        process_res: int = 504,
        process_res_method: str = "upper_bound_resize",
        *,
        num_workers: int = 8,
        print_progress: bool = False,
        sequential: bool | None = None,
        desc: str | None = "Preprocess",
        perform_normalization: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """
        Returns:
            (tensor, extrinsics_list, intrinsics_list)
            tensor shape: (1, N, 3, H, W)
            If perform_normalization is False, tensor is uint8 (0-255).
            If perform_normalization is True, tensor is float32 normalized (ImageNet).
        """
        sequential = self._resolve_sequential(sequential, num_workers)
        exts_list, ixts_list = self._validate_and_pack_meta(image, extrinsics, intrinsics)

        results = self._run_parallel(
            image=image,
            exts_list=exts_list,
            ixts_list=ixts_list,
            process_res=process_res,
            process_res_method=process_res_method,
            num_workers=num_workers,
            print_progress=print_progress,
            sequential=sequential,
            desc=desc,
            perform_normalization=perform_normalization,
        )

        proc_imgs, out_sizes, out_ixts, out_exts = self._unpack_results(results)
        proc_imgs, out_sizes, out_ixts = self._unify_batch_shapes(proc_imgs, out_sizes, out_ixts)

        batch_tensor = self._stack_batch(proc_imgs)
|
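        # Convert camera metadata to float tensors when it was provided; otherwise keep None.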
        out_exts = (
            torch.from_numpy(np.ascontiguousarray(np.asarray(out_exts))).float()
            if out_exts is not None and out_exts[0] is not None
            else None
        )
        out_ixts = (
            torch.from_numpy(np.ascontiguousarray(np.asarray(out_ixts))).float()
            if out_ixts is not None and out_ixts[0] is not None
            else None
        )
        return (batch_tensor, out_exts, out_ixts)

    @staticmethod
    def normalize_tensor(tensor: torch.Tensor, mean: torch.Tensor | list, std: torch.Tensor | list) -> torch.Tensor:
        """Normalize a tensor (C, H, W) or (B, C, H, W) with given mean and std.

        Expects input tensor to be float32 in range [0, 1].
        """
        if isinstance(mean, list):
            mean = torch.tensor(mean, device=tensor.device, dtype=tensor.dtype).view(-1, 1, 1)
        if isinstance(std, list):
            std = torch.tensor(std, device=tensor.device, dtype=tensor.dtype).view(-1, 1, 1)
|
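        # For a batched (B, C, H, W) input, add a leading axis so mean/std broadcast correctly.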
        if tensor.dim() == 4 and mean.dim() == 3:
            mean = mean.unsqueeze(0)
            std = std.unsqueeze(0)

        return (tensor - mean) / std

    def _resolve_sequential(self, sequential: bool | None, num_workers: int) -> bool:
        return (num_workers <= 1) if sequential is None else sequential

    def _validate_and_pack_meta(
        self,
        images: list[np.ndarray | Image.Image | str],
        extrinsics: np.ndarray | None,
        intrinsics: np.ndarray | None,
    ) -> tuple[list[np.ndarray | None] | None, list[np.ndarray | None] | None]:
        if extrinsics is not None and len(extrinsics) != len(images):
            raise ValueError("Length of extrinsics must match images when provided.")
        if intrinsics is not None and len(intrinsics) != len(images):
            raise ValueError("Length of intrinsics must match images when provided.")
        exts_list = [e for e in extrinsics] if extrinsics is not None else None
        ixts_list = [k for k in intrinsics] if intrinsics is not None else None
        return exts_list, ixts_list

    def _run_parallel(
        self,
        *,
        image: list[np.ndarray | Image.Image | str],
        exts_list: list[np.ndarray | None] | None,
        ixts_list: list[np.ndarray | None] | None,
        process_res: int,
        process_res_method: str,
        num_workers: int,
        print_progress: bool,
        sequential: bool,
        desc: str | None,
        perform_normalization: bool,
    ):
        results = parallel_execution(
            image,
            exts_list,
            ixts_list,
            action=self._process_one,
            num_processes=num_workers,
            print_progress=print_progress,
            sequential=sequential,
            desc=desc,
            process_res=process_res,
            process_res_method=process_res_method,
            perform_normalization=perform_normalization,
        )
        if not results:
            raise RuntimeError(
                "No preprocessing results returned. Check inputs and parallel_execution."
            )
        return results

    def _unpack_results(self, results):
        """
        results: List[Tuple[torch.Tensor, Tuple[H, W], Optional[np.ndarray], Optional[np.ndarray]]]
        -> processed_images, out_sizes, out_intrinsics, out_extrinsics
        """
        try:
            processed_images, out_sizes, out_intrinsics, out_extrinsics = zip(*results)
        except Exception as e:
            raise RuntimeError(
                "Unexpected results structure from parallel_execution: "
                f"{type(results)} / sample: {results[0]}"
            ) from e

        return list(processed_images), list(out_sizes), list(out_intrinsics), list(out_extrinsics)

    def _unify_batch_shapes(
        self,
        processed_images: list[torch.Tensor],
        out_sizes: list[tuple[int, int]],
        out_intrinsics: list[np.ndarray | None],
    ) -> tuple[list[torch.Tensor], list[tuple[int, int]], list[np.ndarray | None]]:
        """Center-crop all tensors to the smallest H, W; adjust intrinsics' cx, cy accordingly."""
        if len(set(out_sizes)) <= 1:
            return processed_images, out_sizes, out_intrinsics

        min_h = min(h for h, _ in out_sizes)
        min_w = min(w for _, w in out_sizes)
        logger.warn(
            f"Images in batch have different sizes {out_sizes}; "
            f"center-cropping all to smallest ({min_h},{min_w})"
        )

        center_crop = T.CenterCrop((min_h, min_w))
        new_imgs, new_sizes, new_ixts = [], [], []
        for img_t, (H, W), K in zip(processed_images, out_sizes, out_intrinsics):
            crop_top = max(0, (H - min_h) // 2)
            crop_left = max(0, (W - min_w) // 2)
            new_imgs.append(center_crop(img_t))
            new_sizes.append((min_h, min_w))
            if K is None:
                new_ixts.append(None)
            else:
                K_adj = K.copy()
                K_adj[0, 2] -= crop_left
                K_adj[1, 2] -= crop_top
                new_ixts.append(K_adj)
        return new_imgs, new_sizes, new_ixts

    def _stack_batch(self, processed_images: list[torch.Tensor]) -> torch.Tensor:
        # Stack to (N, 3, H, W) and add the leading batch dimension documented above: (1, N, 3, H, W).
        return torch.stack(processed_images).unsqueeze(0)

    def _process_one(
        self,
        img: np.ndarray | Image.Image | str,
        extrinsic: np.ndarray | None = None,
        intrinsic: np.ndarray | None = None,
        *,
        process_res: int,
        process_res_method: str,
        perform_normalization: bool = True,
    ) -> tuple[torch.Tensor, tuple[int, int], np.ndarray | None, np.ndarray | None]:
        pil_img = self._load_image(img)
        orig_w, orig_h = pil_img.size
|
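        # Boundary resize (upper/lower bound) preserving aspect ratio; rescale intrinsics to the new size.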
        pil_img = self._resize_image(pil_img, process_res, process_res_method)
        w, h = pil_img.size
        intrinsic = self._resize_ixt(intrinsic, orig_w, orig_h, w, h)
|
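        # Enforce divisibility by PATCH_SIZE: "*resize" rounds each side via a small resize, "*crop" floors it via a center crop.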
        if process_res_method.endswith("resize"):
            pil_img = self._make_divisible_by_resize(pil_img, self.PATCH_SIZE)
            new_w, new_h = pil_img.size
            intrinsic = self._resize_ixt(intrinsic, w, h, new_w, new_h)
            w, h = new_w, new_h
        elif process_res_method.endswith("crop"):
            pil_img = self._make_divisible_by_crop(pil_img, self.PATCH_SIZE)
            new_w, new_h = pil_img.size
            intrinsic = self._crop_ixt(intrinsic, w, h, new_w, new_h)
            w, h = new_w, new_h
        else:
            raise ValueError(f"Unsupported process_res_method: {process_res_method}")

        if perform_normalization:
            img_tensor = self._normalize_image(pil_img)
        else:
|
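            # Keep raw uint8 values as a (3, H, W) tensor; normalization is left to the caller.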
            arr = np.array(pil_img)
            img_tensor = torch.from_numpy(arr).permute(2, 0, 1)

        _, H, W = img_tensor.shape
        assert (W, H) == (w, h), "Tensor size mismatch with PIL image size after processing."

        return img_tensor, (H, W), intrinsic, extrinsic

    def _resize_ixt(
        self,
        intrinsic: np.ndarray | None,
        orig_w: int,
        orig_h: int,
        w: int,
        h: int,
    ) -> np.ndarray | None:
        if intrinsic is None:
            return None
        K = intrinsic.copy()
|
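        # Scale row 0 of K (fx, cx) by the width ratio and row 1 (fy, cy) by the height ratio.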
        K[:1] *= w / float(orig_w)
        K[1:2] *= h / float(orig_h)
        return K

    def _crop_ixt(
        self,
        intrinsic: np.ndarray | None,
        orig_w: int,
        orig_h: int,
        w: int,
        h: int,
    ) -> np.ndarray | None:
        if intrinsic is None:
            return None
        K = intrinsic.copy()
        crop_h = (orig_h - h) // 2
        crop_w = (orig_w - w) // 2
        K[0, 2] -= crop_w
        K[1, 2] -= crop_h
        return K

    def _load_image(self, img: np.ndarray | Image.Image | str) -> Image.Image:
        if isinstance(img, str):
            return Image.open(img).convert("RGB")
        elif isinstance(img, np.ndarray):
|
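            # Assumes the array is uint8 and in RGB channel order; BGR arrays (e.g. from cv2.imread) should be converted first.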
            return Image.fromarray(img).convert("RGB")
        elif isinstance(img, Image.Image):
            return img.convert("RGB")
        else:
            raise ValueError(f"Unsupported image type: {type(img)}")

    def _normalize_image(self, img: Image.Image) -> torch.Tensor:
        img_tensor = T.ToTensor()(img)
        return self.NORMALIZE(img_tensor)

    def _resize_image(self, img: Image.Image, target_size: int, method: str) -> Image.Image:
        if method in ("upper_bound_resize", "upper_bound_crop"):
            return self._resize_longest_side(img, target_size)
        elif method in ("lower_bound_resize", "lower_bound_crop"):
            return self._resize_shortest_side(img, target_size)
        else:
            raise ValueError(f"Unsupported resize method: {method}")

    def _resize_longest_side(self, img: Image.Image, target_size: int) -> Image.Image:
        w, h = img.size
        longest = max(w, h)
        if longest == target_size:
            return img
        scale = target_size / float(longest)
        new_w = max(1, int(round(w * scale)))
        new_h = max(1, int(round(h * scale)))
        interpolation = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA
        arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation)
        return Image.fromarray(arr)

    def _resize_shortest_side(self, img: Image.Image, target_size: int) -> Image.Image:
        w, h = img.size
        shortest = min(w, h)
        if shortest == target_size:
            return img
        scale = target_size / float(shortest)
        new_w = max(1, int(round(w * scale)))
        new_h = max(1, int(round(h * scale)))
        interpolation = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA
        arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation)
        return Image.fromarray(arr)

    def _make_divisible_by_crop(self, img: Image.Image, patch: int) -> Image.Image:
        """
        Floor each dimension to the nearest multiple of PATCH_SIZE via center crop.
        Example: 504x377 -> 504x364
        """
        w, h = img.size
        new_w = (w // patch) * patch
        new_h = (h // patch) * patch
        if new_w == w and new_h == h:
            return img
        left = (w - new_w) // 2
        top = (h - new_h) // 2
        return img.crop((left, top, left + new_w, top + new_h))

    def _make_divisible_by_resize(self, img: Image.Image, patch: int) -> Image.Image:
        """
        Round each dimension to the nearest multiple of PATCH_SIZE via a small resize.
        Example: 504x377 -> 504x378
        """
        w, h = img.size

        def nearest_multiple(x: int, p: int) -> int:
            down = (x // p) * p
            up = down + p
            return up if abs(up - x) <= abs(x - down) else down

        new_w = max(1, nearest_multiple(w, patch))
        new_h = max(1, nearest_multiple(h, patch))
        if new_w == w and new_h == h:
            return img
        upscale = (new_w > w) or (new_h > h)
        interpolation = cv2.INTER_CUBIC if upscale else cv2.INTER_AREA
        arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation)
        return Image.fromarray(arr)

InputAdapter = InputProcessor |
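
# Minimal usage sketch (illustrative only; the file names below are placeholders):
#
#     processor = InputProcessor()
#     images, exts, ixts = processor(
#         ["view_000.jpg", "view_001.jpg"],  # paths, NumPy arrays, or PIL images
#         process_res=504,
#         process_res_method="upper_bound_resize",
#         num_workers=4,
#     )
#     # images: (1, N, 3, H, W) float32 tensor, ImageNet-normalized
#     # exts / ixts: None here because no camera metadata was provided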
|
|
|