Daankular's picture
download
raw
12.9 kB
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
# Absolute directory that contains this module file.
_PACKAGE_ROOT = Path(__file__).resolve().parents[0]
# Relative bf16 safetensors checkpoint paths. DA3_BF16_MODEL is the default
# checkpoint that run_da3_reconstruction resolves with files_locator.locate_file.
DA3_BF16_MODEL = "depth/depth_anything_v3_vitl_bf16.safetensors"
DA3_METRIC_BF16_MODEL = "depth/depth_anything_v3_metric_large_bf16.safetensors"
def resolve_da3_chunk_size(chunk_size=-1, device=None):
    """Return the temporal chunk size (in frames) to use for DA3 inference.

    ``None`` is treated as -1 and the value is coerced with ``int``. Any value
    other than -1 is returned unchanged; callers in this module treat values
    <= 0 as "no chunking". For -1 ("auto") the size depends on the total memory
    of the target CUDA device:

    * 33 frames when CUDA is unavailable, the device is not CUDA, or VRAM < 8 GB;
    * 65 frames when VRAM < 24 GB;
    * 97 frames otherwise.
    """
    requested = int(-1 if chunk_size is None else chunk_size)
    if requested != -1:
        return requested
    # CPU-only environments and non-CUDA devices get the smallest preset.
    if not torch.cuda.is_available():
        return 33
    target = torch.device("cuda") if device is None else torch.device(device)
    if target.type != "cuda":
        return 33
    index = target.index if target.index is not None else torch.cuda.current_device()
    total_gb = torch.cuda.get_device_properties(index).total_memory / 1_000_000_000
    for limit_gb, frames in ((8, 33), (24, 65)):
        if total_gb < limit_gb:
            return frames
    return 97
def _load_da3(pretrained_model, device, model_name="da3-large"):
    """Build a Depth Anything 3 model and load a bf16 safetensors checkpoint into it.

    Args:
        pretrained_model: path (``str`` or path-like) to a ``.safetensors`` checkpoint.
        device: target device passed to ``model.to``.
        model_name: architecture name forwarded to ``DepthAnything3``.

    Returns:
        The model with gradients disabled, cast to ``torch.bfloat16`` on ``device``,
        in eval mode.

    Raises:
        ValueError: if the checkpoint path does not end with ``.safetensors``.
        RuntimeError: if the checkpoint contains keys the model does not define,
            or lacks model keys other than the tolerated
            ``model.head.scratch.output_conv2_aux.{1,2,3}.2.`` entries.
    """
    pretrained_model = str(pretrained_model)
    # Validate the path before importing loader dependencies or building the
    # network, so a wrong checkpoint format fails fast and cheaply.
    if not pretrained_model.endswith(".safetensors"):
        raise ValueError(f"Depth Anything 3 now expects the bf16 safetensors checkpoint, got: {pretrained_model}")
    from mmgp import offload
    from safetensors import safe_open
    from .api import DepthAnything3
    model = DepthAnything3(model_name=model_name)
    model_keys = set(model.state_dict().keys())
    # Only read key names here; tensors are loaded below by mmgp.
    with safe_open(pretrained_model, framework="pt", device="cpu") as f:
        checkpoint_keys = set(f.keys())
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    # These auxiliary head parameters may be absent from the checkpoint.
    allowed_missing = tuple(f"model.head.scratch.output_conv2_aux.{idx}.2." for idx in range(1, 4))
    unsupported_missing = [key for key in missing if not key.startswith(allowed_missing)]
    if unexpected or unsupported_missing:
        raise RuntimeError(f"Unexpected DA3 checkpoint keys: unexpected={unexpected}, missing={unsupported_missing}")
    offload.load_model_data(model, pretrained_model, writable_tensors=False, default_dtype=torch.bfloat16, ignore_missing_keys=True)
    model.requires_grad_(False)
    model.to(device=device, dtype=torch.bfloat16)
    model.eval()
    return model
def _resize_2d(array, height, width, mode="bilinear", inverse=False):
    """Resize the last two axes of a NumPy array to ``(height, width)``.

    Leading axes are preserved. Interpolation runs in float64 through
    ``torch.nn.functional.interpolate`` with the given ``mode``. With
    ``inverse=True`` values are reciprocated before and after resizing, so the
    interpolation happens in 1/x space. Boolean inputs are re-binarised with a
    0.5 threshold. The result keeps the input dtype. When the shape already
    matches, a copy is returned.
    """
    original_dtype = array.dtype
    if array.shape[-2:] == (height, width):
        return array.copy()
    source = torch.from_numpy(array).to(torch.float64)
    batch_shape = source.shape[:-2]
    # Collapse leading axes into one batch axis and add a singleton channel axis.
    flat = source.reshape(-1, 1, *source.shape[-2:])
    if inverse:
        flat = 1 / flat
    resized = F.interpolate(flat, size=(height, width), mode=mode)
    if inverse:
        resized = 1 / resized
    resized = resized.reshape(*batch_shape, height, width)
    if original_dtype == np.bool_:
        resized = resized >= 0.5
    return resized.numpy().astype(original_dtype)
def _k_to_intrinsics(k):
    """Flatten a stack of 3x3 camera matrices into float32 rows.

    Each output row is ``[K[0,0], K[1,1], K[0,2], K[1,2]]``, i.e. ``[fx, fy, cx, cy]``
    for a standard pinhole matrix. The output shape is ``(k.shape[0], 4)``.
    """
    columns = (k[:, 0, 0], k[:, 1, 1], k[:, 0, 2], k[:, 1, 2])
    return np.stack(columns, axis=-1).astype(np.float32)
def _prediction_to_arrays(prediction, height, width):
    """Convert a DA3 prediction into NumPy arrays at the requested output resolution.

    Reads ``depth``, optional ``sky``, ``extrinsics``, ``intrinsics`` and
    ``processed_images`` from ``prediction``. Depth is resized bilinearly in
    inverse-depth space. The sky mask uses nearest-neighbour resizing; it is all
    False when the prediction has no ``sky`` or it is None. Intrinsics are
    rescaled from the processed-image resolution to ``(height, width)``.

    Returns:
        ``(depths, sky, cam_w2c, intrinsics)``: float32 depth, bool sky mask,
        float32 extrinsics as stored on the prediction, and float32
        ``[fx, fy, cx, cy]`` rows.
    """
    raw_depth = prediction.depth.astype(np.float32)
    raw_sky = getattr(prediction, "sky", None)
    sky_mask = np.zeros_like(raw_depth, dtype=np.bool_) if raw_sky is None else raw_sky.astype(np.bool_)
    extrinsics = prediction.extrinsics.astype(np.float32)
    intrinsics = _k_to_intrinsics(prediction.intrinsics.astype(np.float32))
    # Axes 1 and 2 of processed_images are taken as the inference resolution
    # (presumably (N, H, W, C) — TODO confirm).
    proc_h, proc_w = prediction.processed_images.shape[1:3]
    # Columns 0/2 (fx, cx) follow the width ratio; columns 1/3 (fy, cy) follow the height ratio.
    intrinsics[:, 0::2] *= width / proc_w
    intrinsics[:, 1::2] *= height / proc_h
    depths = _resize_2d(raw_depth, height, width, mode="bilinear", inverse=True)
    sky_mask = _resize_2d(sky_mask, height, width, mode="nearest", inverse=False)
    return depths, sky_mask, extrinsics, intrinsics
def _camera_w2c_to_c2w(cam_w2c):
    """Convert world-to-camera matrices into camera-to-world poses relative to frame 0.

    Args:
        cam_w2c: array of shape (N, 3, 4) holding world-to-camera matrices.

    Returns:
        float32 array (N, 4, 4) equal to ``W2C[0] @ inv(W2C[i])``. This expresses every
        camera pose in the coordinate frame of the first camera, so frame 0 is the
        identity. An empty input returns an empty (0, 4, 4) array.
    """
    if cam_w2c.shape[0] == 0:
        return np.zeros((0, 4, 4), dtype=np.float32)
    # Work in float64 (like _w2c_to_pose) and cast once at the end.
    cam_w2c_44 = np.zeros((cam_w2c.shape[0], 4, 4), dtype=np.float64)
    cam_w2c_44[:, :3, :4] = cam_w2c
    cam_w2c_44[:, 3, 3] = 1.0
    cam_c2w = np.linalg.inv(cam_w2c_44)
    # inv(cam_c2w[0]) is exactly cam_w2c_44[0]; reuse it instead of inverting twice.
    return (cam_w2c_44[0][None] @ cam_c2w).astype(np.float32)
def _w2c_to_pose(cam_w2c):
    """Invert (N, 3, 4) world-to-camera matrices into (N, 4, 4) float64 camera-to-world poses."""
    count = cam_w2c.shape[0]
    # Start from identity so the homogeneous bottom row is already [0, 0, 0, 1].
    homogeneous = np.tile(np.eye(4), (count, 1, 1))
    homogeneous[:, :3, :4] = cam_w2c.astype(np.float64)
    return np.linalg.inv(homogeneous)
def _closest_rotation(matrix):
    """Project a 3x3 matrix onto the nearest proper rotation using its SVD.

    The orthogonal factor ``U @ Vh`` is used. If it is a reflection (negative
    determinant), the last column of ``U`` is negated so the result has
    determinant +1.
    """
    left, _singular, right_t = np.linalg.svd(matrix)
    if np.linalg.det(left @ right_t) < 0:
        left[:, -1] = -left[:, -1]
    return left @ right_t
def _pose_based_chunk_alignment(ref_w2c, est_w2c):
    """Estimate the similarity transform that maps the ``est`` cameras onto ``ref``.

    Both inputs hold world-to-camera matrices (N, 3, 4) for the same frames from two
    reconstructions. Camera centres are approximately related by
    ``ref_c = scale * rotation @ est_c + translation``, where:

    * rotation is the nearest proper rotation to the mean of ``R_ref @ R_est^T``;
    * scale is the median ratio of pairwise camera-centre distances, or 1.0 when
      every estimated distance is at or below float64 epsilon;
    * translation matches the centroids of the camera centres.

    Returns ``(rotation, translation, scale)``, all float32.
    """
    ref_pose = _w2c_to_pose(ref_w2c)
    est_pose = _w2c_to_pose(est_w2c)
    relative = ref_pose[:, :3, :3] @ np.swapaxes(est_pose[:, :3, :3], -1, -2)
    rotation = _closest_rotation(relative.mean(axis=0))
    ref_c = ref_pose[:, :3, 3]
    est_c = est_pose[:, :3, 3]
    # Every unordered pair of cameras contributes one distance ratio.
    first, second = np.triu_indices(len(ref_c), k=1)
    ref_span = np.linalg.norm(ref_c[first] - ref_c[second], axis=1)
    est_span = np.linalg.norm(est_c[first] - est_c[second], axis=1)
    usable = est_span > np.finfo(np.float64).eps
    if usable.any():
        scale = float(np.median(ref_span[usable] / est_span[usable]))
    else:
        scale = 1.0
    translation = ref_c.mean(axis=0) - scale * (rotation @ est_c.mean(axis=0))
    return rotation.astype(np.float32), translation.astype(np.float32), np.float32(scale)
def _apply_sim3_to_w2c(cam_w2c, rotation, translation, scale):
    """Apply a similarity transform to (N, 3, 4) world-to-camera matrices.

    Each camera-to-world pose has its orientation left-multiplied by ``rotation``.
    Its centre is mapped to ``rotation @ (scale * centre) + translation``. The result
    is inverted back and returned as float32 (N, 3, 4) world-to-camera matrices.
    """
    frames = cam_w2c.shape[0]
    w2c_h = np.zeros((frames, 4, 4), dtype=np.float32)
    w2c_h[:, :3, :4] = cam_w2c
    w2c_h[:, 3, 3] = 1.0
    c2w = np.linalg.inv(w2c_h)
    moved = c2w.copy()
    moved[:, :3, :3] = rotation @ c2w[:, :3, :3]
    scaled_centres = scale * c2w[:, :3, 3]
    moved[:, :3, 3] = (rotation @ scaled_centres.T).T + translation
    back = np.linalg.inv(moved)
    return back[:, :3, :4].astype(np.float32)
def _chunk_ranges(frame_count, chunk_size, overlap):
    """Split ``frame_count`` frames into overlapping ``(start, end)`` windows.

    A single ``(0, frame_count)`` window is returned when ``chunk_size <= 0`` or
    ``chunk_size >= frame_count``. Otherwise windows of ``chunk_size`` frames advance
    by ``chunk_size - overlap``. When the current window already overlaps the last
    possible window (``frame_count - chunk_size``) by at least ``overlap`` frames,
    the walk jumps straight to that final window. The last window always ends at
    ``frame_count``.

    Raises:
        ValueError: when chunking is active and ``overlap < 8`` or
            ``overlap >= chunk_size``.
    """
    if chunk_size <= 0 or chunk_size >= frame_count:
        return [(0, frame_count)]
    if overlap < 8:
        raise ValueError("DA3 temporal chunking requires at least 8 overlap frames")
    if overlap >= chunk_size:
        raise ValueError("DA3 temporal chunk overlap must be smaller than the chunk size")
    stride = chunk_size - overlap
    last_start = frame_count - chunk_size
    spans = []
    cursor = 0
    while cursor + chunk_size < frame_count:
        stop = cursor + chunk_size
        spans.append((cursor, stop))
        cursor = last_start if stop - last_start >= overlap else cursor + stride
    spans.append((last_start, frame_count))
    return spans
def _infer_da3_prediction(model, video, frame_indices, process_res):
    """Run ``model.inference`` on the selected frames of ``video`` and return the raw prediction.

    Each ``video[i]`` is wrapped with ``PIL.Image.fromarray`` (presumably an HxWx3
    uint8 array — TODO confirm). The call uses ``export_format="npz"``.
    """
    pil_frames = [Image.fromarray(video[idx]) for idx in frame_indices]
    result = model.inference(pil_frames, process_res=process_res, export_format="npz")
    return result
def _infer_da3_depth_prediction(model, video, frame_indices, process_res):
    """Run DA3 on the selected frames and return only their depth at video resolution.

    Delegates inference to ``_infer_da3_prediction`` instead of duplicating it. The
    prediction's depth is cast to float32 and resized, in inverse-depth space, to
    ``video.shape[1:3]``.
    """
    prediction = _infer_da3_prediction(model, video, frame_indices, process_res)
    depth = prediction.depth.astype(np.float32)
    return _resize_2d(depth, video.shape[1], video.shape[2], mode="bilinear", inverse=True)
def _run_da3_prediction(model, video, process_res, chunk_size=0, chunk_overlap=8):
    """Run DA3 over a whole video, optionally in overlapping temporal chunks.

    Args:
        model: loaded DA3 model (see ``_load_da3``).
        video: frame array indexed as ``video[i]``; axes 1 and 2 are height and width.
        process_res: inference resolution forwarded to ``model.inference``.
        chunk_size: frames per chunk, resolved with ``resolve_da3_chunk_size``
            (-1 = auto; a resolved value <= 0 or >= the frame count means one pass).
        chunk_overlap: frames shared by consecutive chunks (>= 8 when chunking).

    Returns:
        ``(depths, sky, cam_c2w, intrinsics)`` for every frame. Poses come from
        ``_camera_w2c_to_c2w`` and are relative to the first frame.

    When several chunks are used, each later chunk is aligned to the frames that are
    already stitched. The alignment is a similarity transform estimated from the
    overlapping cameras. The chunk's poses and depths are rescaled by it, and only
    frames not yet filled are kept.

    Raises:
        ValueError: on invalid chunk/overlap settings or fewer than 3 overlap frames.
        RuntimeError: if some frame was never filled.
    """
    total, height, width = video.shape[:3]
    spans = _chunk_ranges(total, resolve_da3_chunk_size(chunk_size), chunk_overlap)
    if len(spans) == 1:
        single = _infer_da3_prediction(model, video, range(total), process_res)
        depths, sky, w2c, intrinsics = _prediction_to_arrays(single, height, width)
        return depths, sky, _camera_w2c_to_c2w(w2c), intrinsics

    depth_out = np.empty((total, height, width), dtype=np.float32)
    sky_out = np.empty((total, height, width), dtype=np.bool_)
    w2c_out = np.empty((total, 3, 4), dtype=np.float32)
    intr_out = np.empty((total, 4), dtype=np.float32)
    done = np.zeros(total, dtype=np.bool_)
    for first, stop in spans:
        frame_ids = np.arange(first, stop)
        chunk_pred = _infer_da3_prediction(model, video, frame_ids, process_res)
        depths, sky, w2c, intrinsics = _prediction_to_arrays(chunk_pred, height, width)
        shared = done[frame_ids]
        if shared.any():
            if int(shared.sum()) < 3:
                raise ValueError("DA3 temporal chunking produced fewer than 3 overlap frames for alignment")
            # Map this chunk's cameras onto the already-stitched trajectory.
            rotation, translation, scale = _pose_based_chunk_alignment(w2c_out[frame_ids[shared]], w2c[shared])
            w2c = _apply_sim3_to_w2c(w2c, rotation, translation, scale)
            depths *= np.float32(scale)
        fresh = ~shared
        fresh_ids = frame_ids[fresh]
        depth_out[fresh_ids] = depths[fresh]
        sky_out[fresh_ids] = sky[fresh]
        w2c_out[fresh_ids] = w2c[fresh]
        intr_out[fresh_ids] = intrinsics[fresh]
        done[fresh_ids] = True
        # Drop per-chunk buffers before the next inference to limit peak memory.
        del chunk_pred, depths, sky, w2c, intrinsics
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    if not done.all():
        missing = np.flatnonzero(~done).tolist()
        raise RuntimeError(f"DA3 temporal chunking failed to fill frames: {missing}")
    return depth_out, sky_out, _camera_w2c_to_c2w(w2c_out), intr_out
def _run_da3_depth_prediction(model, video, process_res, chunk_size=0):
    """Predict per-frame depth only, in non-overlapping chunks and without cross-chunk alignment.

    ``chunk_size`` is resolved with ``resolve_da3_chunk_size``. A resolved value
    <= 0 or >= the frame count runs a single pass. The chunked path fills a float32
    ``(frames, H, W)`` buffer sized from ``video``.
    """
    total, height, width = video.shape[:3]
    step = resolve_da3_chunk_size(chunk_size)
    if step <= 0 or step >= total:
        return _infer_da3_depth_prediction(model, video, range(total), process_res)
    result = np.empty((total, height, width), dtype=np.float32)
    for lo in range(0, total, step):
        hi = min(lo + step, total)
        result[lo:hi] = _infer_da3_depth_prediction(model, video, range(lo, hi), process_res)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return result
@torch.inference_mode()
def run_da3_reconstruction(video, pretrained_model=None, process_res=0, device=None, chunk_size=0, chunk_overlap=8):
    """Load DA3, reconstruct depth, sky masks and cameras for ``video``, then release the model.

    Args:
        video: frame array; axis 2 is used as the width.
        pretrained_model: checkpoint path. When falsy, ``DA3_BF16_MODEL`` is located
            through ``shared.utils.files_locator``.
        process_res: inference resolution; values <= 0 use the video width.
        device: torch device; defaults to CUDA when available, else CPU.
        chunk_size: temporal chunk size (-1 = auto from VRAM, 0 = single pass).
        chunk_overlap: overlap between chunks, forwarded to ``_run_da3_prediction``.

    Returns:
        ``(depths, sky, cam_c2w, intrinsics)``, with ``cam_c2w`` and ``intrinsics`` cast to float32.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    chunk_size = resolve_da3_chunk_size(chunk_size, device)
    if not pretrained_model:
        # Project-local locator is only needed when no explicit checkpoint is given.
        from shared.utils import files_locator as fl
        pretrained_model = fl.locate_file(DA3_BF16_MODEL)
    if process_res <= 0:
        process_res = video.shape[2]
    model = _load_da3(pretrained_model, device, model_name="da3-large")
    try:
        depths, sky, cam_c2w, intrinsics = _run_da3_prediction(model, video, process_res, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    finally:
        # Release the model and cached GPU memory even when prediction raises
        # (e.g. invalid chunk settings or out-of-memory).
        model.to("cpu")
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return depths, sky, cam_c2w.astype(np.float32), intrinsics.astype(np.float32)
class DepthV3VideoAnnotator:
    """Video annotator that renders DA3 depth as normalised 3-channel disparity frames.

    Config keys read from ``cfg``:
        PRETRAINED_MODEL (required): checkpoint path passed to ``_load_da3``.
        PROCESS_RES: inference resolution; 0 or missing uses the frame width.
        CHUNK_SIZE: temporal chunk size; missing means -1 (auto).
        CHUNK_OVERLAP: chunk overlap; 0 or missing falls back to 8.
        MODEL_NAME: model name, default ``"da3-large"``. ``"da3metric-large"``
            uses the depth-only prediction path.
    """

    def __init__(self, cfg, device=None):
        # Set first so close()/__del__ work even if model loading below raises.
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
        self.process_res = int(cfg.get("PROCESS_RES", 0) or 0)
        self.chunk_size = resolve_da3_chunk_size(cfg.get("CHUNK_SIZE", -1), self.device)
        self.chunk_overlap = int(cfg.get("CHUNK_OVERLAP", 8) or 8)
        self.model_name = cfg.get("MODEL_NAME", "da3-large")
        self.model = _load_da3(cfg["PRETRAINED_MODEL"], self.device, model_name=self.model_name)

    @torch.inference_mode()
    def forward(self, frames):
        """Return one uint8 3-channel disparity image per input frame.

        Disparity is ``1 / max(depth, 1e-6)``. It is min-max normalised over the
        whole clip to 0..255 and repeated across three channels.

        Raises:
            RuntimeError: if the annotator has been closed.
        """
        if self.model is None:
            raise RuntimeError("DepthV3VideoAnnotator has been closed")
        video = np.stack([np.asarray(frame) for frame in frames], axis=0)
        process_res = self.process_res or video.shape[2]
        if self.model_name == "da3metric-large":
            depth = _run_da3_depth_prediction(self.model, video, process_res, chunk_size=self.chunk_size)
        else:
            depth, _, _, _ = _run_da3_prediction(self.model, video, process_res, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap)
        disp = 1.0 / np.maximum(depth, 1e-6)
        disp -= disp.min()
        # Guard against division by zero when the clip has constant disparity.
        disp /= max(float(disp.max()), 1e-6)
        depth_video = (disp * 255.0).clip(0, 255).astype(np.uint8)
        return [np.repeat(frame[..., None], 3, axis=2) for frame in depth_video]

    def close(self):
        """Move the model to CPU, drop the reference and empty the CUDA cache. Safe to call repeatedly."""
        model = getattr(self, "model", None)
        if model is not None:
            model.to("cpu")
        self.model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def __del__(self):
        # Best-effort cleanup during garbage collection; never raise from a finalizer.
        try:
            self.close()
        except Exception:
            pass

Xet Storage Details

Size:
12.9 kB
·
Xet hash:
6eaabad210d2c7bf0e73f1377948b8fdfb12417b017fbb897443a7e333d00b20

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.