| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| _PACKAGE_ROOT = Path(__file__).resolve().parent | |
| DA3_BF16_MODEL = "depth/depth_anything_v3_vitl_bf16.safetensors" | |
| DA3_METRIC_BF16_MODEL = "depth/depth_anything_v3_metric_large_bf16.safetensors" | |
| def resolve_da3_chunk_size(chunk_size=-1, device=None): | |
| chunk_size = int(chunk_size if chunk_size is not None else -1) | |
| if chunk_size != -1: | |
| return chunk_size | |
| if not torch.cuda.is_available(): | |
| return 33 | |
| device = torch.device("cuda" if device is None else device) | |
| if device.type != "cuda": | |
| return 33 | |
| device_index = torch.cuda.current_device() if device.index is None else device.index | |
| vram_gb = torch.cuda.get_device_properties(device_index).total_memory / 1_000_000_000 | |
| if vram_gb < 8: | |
| return 33 | |
| if vram_gb < 24: | |
| return 65 | |
| return 97 | |
| def _load_da3(pretrained_model, device, model_name="da3-large"): | |
| from mmgp import offload | |
| from safetensors import safe_open | |
| from .api import DepthAnything3 | |
| model = DepthAnything3(model_name=model_name) | |
| pretrained_model = str(pretrained_model) | |
| if not pretrained_model.endswith(".safetensors"): | |
| raise ValueError(f"Depth Anything 3 now expects the bf16 safetensors checkpoint, got: {pretrained_model}") | |
| model_keys = set(model.state_dict().keys()) | |
| with safe_open(pretrained_model, framework="pt", device="cpu") as f: | |
| checkpoint_keys = set(f.keys()) | |
| missing = sorted(model_keys - checkpoint_keys) | |
| unexpected = sorted(checkpoint_keys - model_keys) | |
| allowed_missing = tuple(f"model.head.scratch.output_conv2_aux.{idx}.2." for idx in range(1, 4)) | |
| unsupported_missing = [key for key in missing if not key.startswith(allowed_missing)] | |
| if unexpected or unsupported_missing: | |
| raise RuntimeError(f"Unexpected DA3 checkpoint keys: unexpected={unexpected}, missing={unsupported_missing}") | |
| offload.load_model_data(model, pretrained_model, writable_tensors=False, default_dtype=torch.bfloat16, ignore_missing_keys=True) | |
| model.requires_grad_(False) | |
| model.to(device=device, dtype=torch.bfloat16) | |
| model.eval() | |
| return model | |
| def _resize_2d(array, height, width, mode="bilinear", inverse=False): | |
| if array.shape[-2:] == (height, width): | |
| return array.copy() | |
| dtype = array.dtype | |
| tensor = torch.from_numpy(array).to(torch.float64) | |
| leading = tensor.shape[:-2] | |
| tensor = tensor.reshape(-1, *tensor.shape[-2:]) | |
| if inverse: | |
| tensor = 1 / tensor | |
| tensor = F.interpolate(tensor[:, None], size=(height, width), mode=mode)[:, 0] | |
| if inverse: | |
| tensor = 1 / tensor | |
| tensor = tensor.reshape(*leading, height, width) | |
| if dtype == np.bool_: | |
| tensor = tensor >= 0.5 | |
| return tensor.numpy().astype(dtype) | |
| def _k_to_intrinsics(k): | |
| intrinsics = np.zeros((k.shape[0], 4), dtype=np.float32) | |
| intrinsics[:, 0] = k[:, 0, 0] | |
| intrinsics[:, 1] = k[:, 1, 1] | |
| intrinsics[:, 2] = k[:, 0, 2] | |
| intrinsics[:, 3] = k[:, 1, 2] | |
| return intrinsics | |
| def _prediction_to_arrays(prediction, height, width): | |
| depths = prediction.depth.astype(np.float32) | |
| sky = getattr(prediction, "sky", None) | |
| if sky is None: | |
| sky = np.zeros_like(depths, dtype=np.bool_) | |
| else: | |
| sky = sky.astype(np.bool_) | |
| cam_w2c = prediction.extrinsics.astype(np.float32) | |
| intrinsics = _k_to_intrinsics(prediction.intrinsics.astype(np.float32)) | |
| processed = prediction.processed_images | |
| proc_h, proc_w = processed.shape[1:3] | |
| depths = _resize_2d(depths, height, width, mode="bilinear", inverse=True) | |
| sky = _resize_2d(sky, height, width, mode="nearest", inverse=False) | |
| intrinsics[:, 0::2] *= width / proc_w | |
| intrinsics[:, 1::2] *= height / proc_h | |
| return depths, sky, cam_w2c, intrinsics | |
| def _camera_w2c_to_c2w(cam_w2c): | |
| cam_w2c_44 = np.zeros((cam_w2c.shape[0], 4, 4), dtype=np.float32) | |
| cam_w2c_44[:, :3, :4] = cam_w2c | |
| cam_w2c_44[:, 3, 3] = 1.0 | |
| cam_c2w = np.linalg.inv(cam_w2c_44) | |
| return (np.linalg.inv(cam_c2w[0])[None] @ cam_c2w).astype(np.float32) | |
| def _w2c_to_pose(cam_w2c): | |
| cam_w2c_44 = np.zeros((cam_w2c.shape[0], 4, 4), dtype=np.float64) | |
| cam_w2c_44[:, :3, :4] = cam_w2c.astype(np.float64) | |
| cam_w2c_44[:, 3, 3] = 1.0 | |
| return np.linalg.inv(cam_w2c_44) | |
| def _closest_rotation(matrix): | |
| u, _, vh = np.linalg.svd(matrix) | |
| rotation = u @ vh | |
| if np.linalg.det(rotation) < 0: | |
| u[:, -1] *= -1 | |
| rotation = u @ vh | |
| return rotation | |
| def _pose_based_chunk_alignment(ref_w2c, est_w2c): | |
| ref_pose = _w2c_to_pose(ref_w2c) | |
| est_pose = _w2c_to_pose(est_w2c) | |
| rotation = _closest_rotation(np.mean(ref_pose[:, :3, :3] @ np.swapaxes(est_pose[:, :3, :3], -1, -2), axis=0)) | |
| ref_centers = ref_pose[:, :3, 3] | |
| est_centers = est_pose[:, :3, 3] | |
| pair_i, pair_j = np.triu_indices(ref_centers.shape[0], k=1) | |
| ref_dists = np.linalg.norm(ref_centers[pair_i] - ref_centers[pair_j], axis=1) | |
| est_dists = np.linalg.norm(est_centers[pair_i] - est_centers[pair_j], axis=1) | |
| valid = est_dists > np.finfo(np.float64).eps | |
| scale = float(np.median(ref_dists[valid] / est_dists[valid])) if valid.any() else 1.0 | |
| est_mean = est_centers.mean(axis=0) | |
| ref_mean = ref_centers.mean(axis=0) | |
| translation = ref_mean - scale * (rotation @ est_mean) | |
| return rotation.astype(np.float32), translation.astype(np.float32), np.float32(scale) | |
| def _apply_sim3_to_w2c(cam_w2c, rotation, translation, scale): | |
| cam_w2c_44 = np.zeros((cam_w2c.shape[0], 4, 4), dtype=np.float32) | |
| cam_w2c_44[:, :3, :4] = cam_w2c | |
| cam_w2c_44[:, 3, 3] = 1.0 | |
| poses = np.linalg.inv(cam_w2c_44) | |
| aligned = poses.copy() | |
| aligned[:, :3, :3] = rotation @ poses[:, :3, :3] | |
| aligned[:, :3, 3] = (rotation @ (scale * poses[:, :3, 3]).T).T + translation | |
| return np.linalg.inv(aligned)[:, :3, :4].astype(np.float32) | |
| def _chunk_ranges(frame_count, chunk_size, overlap): | |
| if chunk_size <= 0 or chunk_size >= frame_count: | |
| return [(0, frame_count)] | |
| if overlap < 8: | |
| raise ValueError("DA3 temporal chunking requires at least 8 overlap frames") | |
| if overlap >= chunk_size: | |
| raise ValueError("DA3 temporal chunk overlap must be smaller than the chunk size") | |
| ranges, start, step = [], 0, chunk_size - overlap | |
| while True: | |
| end = start + chunk_size | |
| if end >= frame_count: | |
| ranges.append((frame_count - chunk_size, frame_count)) | |
| break | |
| ranges.append((start, end)) | |
| next_start = start + step | |
| final_start = frame_count - chunk_size | |
| start = final_start if end - final_start >= overlap else next_start | |
| return ranges | |
| def _infer_da3_prediction(model, video, frame_indices, process_res): | |
| frames = [Image.fromarray(video[i]) for i in frame_indices] | |
| return model.inference(frames, process_res=process_res, export_format="npz") | |
| def _infer_da3_depth_prediction(model, video, frame_indices, process_res): | |
| frames = [Image.fromarray(video[i]) for i in frame_indices] | |
| prediction = model.inference(frames, process_res=process_res, export_format="npz") | |
| return _resize_2d(prediction.depth.astype(np.float32), video.shape[1], video.shape[2], mode="bilinear", inverse=True) | |
| def _run_da3_prediction(model, video, process_res, chunk_size=0, chunk_overlap=8): | |
| frame_count, height, width = video.shape[:3] | |
| chunk_size = resolve_da3_chunk_size(chunk_size) | |
| ranges = _chunk_ranges(frame_count, chunk_size, chunk_overlap) | |
| if len(ranges) == 1: | |
| prediction = _infer_da3_prediction(model, video, range(frame_count), process_res) | |
| depths, sky, cam_w2c, intrinsics = _prediction_to_arrays(prediction, height, width) | |
| return depths, sky, _camera_w2c_to_c2w(cam_w2c), intrinsics | |
| depths_all = np.empty((frame_count, height, width), dtype=np.float32) | |
| sky_all = np.empty((frame_count, height, width), dtype=np.bool_) | |
| cam_w2c_all = np.empty((frame_count, 3, 4), dtype=np.float32) | |
| intrinsics_all = np.empty((frame_count, 4), dtype=np.float32) | |
| filled = np.zeros(frame_count, dtype=np.bool_) | |
| for start, end in ranges: | |
| indices = np.arange(start, end) | |
| prediction = _infer_da3_prediction(model, video, indices, process_res) | |
| depths, sky, cam_w2c, intrinsics = _prediction_to_arrays(prediction, height, width) | |
| overlap_mask = filled[indices] | |
| if overlap_mask.any(): | |
| if int(overlap_mask.sum()) < 3: | |
| raise ValueError("DA3 temporal chunking produced fewer than 3 overlap frames for alignment") | |
| ref_w2c = cam_w2c_all[indices[overlap_mask]] | |
| est_w2c = cam_w2c[overlap_mask] | |
| rotation, translation, scale = _pose_based_chunk_alignment(ref_w2c, est_w2c) | |
| cam_w2c = _apply_sim3_to_w2c(cam_w2c, rotation, translation, scale) | |
| depths *= np.float32(scale) | |
| keep_mask = ~filled[indices] | |
| keep_indices = indices[keep_mask] | |
| depths_all[keep_indices] = depths[keep_mask] | |
| sky_all[keep_indices] = sky[keep_mask] | |
| cam_w2c_all[keep_indices] = cam_w2c[keep_mask] | |
| intrinsics_all[keep_indices] = intrinsics[keep_mask] | |
| filled[keep_indices] = True | |
| del prediction, depths, sky, cam_w2c, intrinsics | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| if not filled.all(): | |
| missing = np.flatnonzero(~filled).tolist() | |
| raise RuntimeError(f"DA3 temporal chunking failed to fill frames: {missing}") | |
| return depths_all, sky_all, _camera_w2c_to_c2w(cam_w2c_all), intrinsics_all | |
| def _run_da3_depth_prediction(model, video, process_res, chunk_size=0): | |
| frame_count, height, width = video.shape[:3] | |
| chunk_size = resolve_da3_chunk_size(chunk_size) | |
| if chunk_size <= 0 or chunk_size >= frame_count: | |
| return _infer_da3_depth_prediction(model, video, range(frame_count), process_res) | |
| depth_all = np.empty((frame_count, height, width), dtype=np.float32) | |
| for start in range(0, frame_count, chunk_size): | |
| end = min(frame_count, start + chunk_size) | |
| depth_all[start:end] = _infer_da3_depth_prediction(model, video, range(start, end), process_res) | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return depth_all | |
| def run_da3_reconstruction(video, pretrained_model=None, process_res=0, device=None, chunk_size=0, chunk_overlap=8): | |
| from shared.utils import files_locator as fl | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device | |
| chunk_size = resolve_da3_chunk_size(chunk_size, device) | |
| pretrained_model = pretrained_model or fl.locate_file(DA3_BF16_MODEL) | |
| model = _load_da3(pretrained_model, device, model_name="da3-large") | |
| height, width = video.shape[1:3] | |
| if process_res <= 0: | |
| process_res = width | |
| depths, sky, cam_c2w, intrinsics = _run_da3_prediction(model, video, process_res, chunk_size=chunk_size, chunk_overlap=chunk_overlap) | |
| model.to("cpu") | |
| del model | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return depths, sky, cam_c2w.astype(np.float32), intrinsics.astype(np.float32) | |
| class DepthV3VideoAnnotator: | |
| def __init__(self, cfg, device=None): | |
| self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device | |
| self.process_res = int(cfg.get("PROCESS_RES", 0) or 0) | |
| self.chunk_size = resolve_da3_chunk_size(cfg.get("CHUNK_SIZE", -1), self.device) | |
| self.chunk_overlap = int(cfg.get("CHUNK_OVERLAP", 8) or 8) | |
| self.model_name = cfg.get("MODEL_NAME", "da3-large") | |
| self.model = _load_da3(cfg["PRETRAINED_MODEL"], self.device, model_name=self.model_name) | |
| def forward(self, frames): | |
| video = np.stack([np.asarray(frame) for frame in frames], axis=0) | |
| if self.model_name == "da3metric-large": | |
| depth = _run_da3_depth_prediction(self.model, video, self.process_res or video.shape[2], chunk_size=self.chunk_size) | |
| else: | |
| depth, _, _, _ = _run_da3_prediction(self.model, video, self.process_res or video.shape[2], chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) | |
| disp = 1.0 / np.maximum(depth, 1e-6) | |
| disp -= disp.min() | |
| disp /= max(float(disp.max()), 1e-6) | |
| depth_video = (disp * 255.0).clip(0, 255).astype(np.uint8) | |
| return [np.repeat(frame[..., None], 3, axis=2) for frame in depth_video] | |
| def close(self): | |
| if self.model is not None: | |
| self.model.to("cpu") | |
| self.model = None | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| def __del__(self): | |
| try: | |
| self.close() | |
| except Exception: | |
| pass | |
Xet Storage Details
- Size:
- 12.9 kB
- Xet hash:
- 6eaabad210d2c7bf0e73f1377948b8fdfb12417b017fbb897443a7e333d00b20
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.