| | """Contains modules to initialize Gaussians from RGBD. |
| | |
| | For licensing see accompanying LICENSE file. |
| | Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | from typing import NamedTuple |
| |
|
| | import torch |
| | from torch import nn |
| |
|
| | from .params import ColorInitOption, DepthInitOption, InitializerParams |
| |
|
| |
|
def create_initializer(params: InitializerParams) -> nn.Module:
    """Build the Gaussian initializer described by ``params``.

    Args:
        params: Configuration whose fields are forwarded one-to-one as keyword
            arguments to ``MultiLayerInitializer``.

    Returns:
        The configured ``MultiLayerInitializer`` module.
    """
    initializer = MultiLayerInitializer(
        num_layers=params.num_layers,
        stride=params.stride,
        base_depth=params.base_depth,
        scale_factor=params.scale_factor,
        disparity_factor=params.disparity_factor,
        color_option=params.color_option,
        first_layer_depth_option=params.first_layer_depth_option,
        rest_layer_depth_option=params.rest_layer_depth_option,
        normalize_depth=params.normalize_depth,
        feature_input_stop_grad=params.feature_input_stop_grad,
    )
    return initializer
| |
|
| |
|
class GaussianBaseValues(NamedTuple):
    """Per-property starting values for the Gaussian predictor.

    The x and y means are expressed in normalized device coordinates (NDC), where
    (-1, -1) is the top left corner and (1, 1) the bottom right corner. The third
    mean component, ``mean_inverse_z_ndc``, is inverse depth.
    """

    # Gaussian centers: x and y in NDC, z stored as inverse depth.
    mean_x_ndc: torch.Tensor
    mean_y_ndc: torch.Tensor
    mean_inverse_z_ndc: torch.Tensor

    # Remaining per-Gaussian properties.
    scales: torch.Tensor
    quaternions: torch.Tensor
    colors: torch.Tensor
    opacities: torch.Tensor
| |
|
| |
|
class InitializerOutput(NamedTuple):
    """Bundle of everything the initializer hands to downstream modules."""

    # Base values for the Gaussian predictor.
    gaussian_base_values: GaussianBaseValues

    # Feature tensor built from the image and the depth.
    feature_input: torch.Tensor

    # Inverse of the depth rescale factor; stays None when no rescaling happened.
    global_scale: torch.Tensor | None = None
| |
|
| |
|
class MultiLayerInitializer(nn.Module):
    """Initialize Gaussians with a multi-layer representation.

    The returned tensors have the shape

        batch_size x dim x num_layers x height x width

    where dim indicates the dimensionality of the property.
    Some of the dimensions might be set to 1 for efficiency reasons.
    """

    def __init__(
        self,
        num_layers: int,
        stride: int,
        base_depth: float,
        scale_factor: float,
        disparity_factor: float,
        color_option: ColorInitOption = "first_layer",
        first_layer_depth_option: DepthInitOption = "surface_min",
        rest_layer_depth_option: DepthInitOption = "surface_min",
        normalize_depth: bool = True,
        feature_input_stop_grad: bool = True,
    ) -> None:
        """Initialize MultiLayerInitializer.

        Args:
            num_layers: How many layers of Gaussians to predict.
            stride: The downsample rate of the output feature map.
            base_depth: Depth whose inverse is the disparity of the nearest layer
                created by the "base_depth" and "linear_disparity" options.
            scale_factor: Multiply scale of Gaussians by this factor.
            disparity_factor: Factor to convert inverse depth to disparity.
            color_option: Which color option to initialize the multi-layer gaussians.
            first_layer_depth_option: Which depth option to initialize the first layer
                of gaussians.
            rest_layer_depth_option: Which depth option to initialize the remaining
                layers of gaussians.
            normalize_depth: Whether to rescale the depth with ``_rescale_depth``
                before building the Gaussians.
            feature_input_stop_grad: Whether to not propagate gradients through
                feature inputs.
        """
        super().__init__()
        self.num_layers = num_layers
        self.stride = stride
        self.base_depth = base_depth
        self.scale_factor = scale_factor
        self.disparity_factor = disparity_factor
        self.color_option = color_option
        self.first_layer_depth_option = first_layer_depth_option
        self.rest_layer_depth_option = rest_layer_depth_option
        self.normalize_depth = normalize_depth
        self.feature_input_stop_grad = feature_input_stop_grad

    def prepare_feature_input(self, image: torch.Tensor, depth: torch.Tensor) -> torch.Tensor:
        """Prepare the feature input to the Gaussian predictor.

        Args:
            image: The image; concatenated with the disparity along dim 1.
            depth: The depth map; converted to ``disparity_factor / depth``.

        Returns:
            The concatenated image and disparity, mapped through ``2 * x - 1``.
        """
        if self.feature_input_stop_grad:
            image = image.detach()
            depth = depth.detach()

        normalized_disparity = self.disparity_factor / depth
        features_in = torch.cat([image, normalized_disparity], dim=1)
        return 2.0 * features_in - 1.0

    def forward(self, image: torch.Tensor, depth: torch.Tensor) -> InitializerOutput:
        """Construct Gaussian base values and prepare feature input.

        Args:
            image: The image to process.
            depth: The corresponding depth map from the monodepth network.

        Returns:
            The base values for the Gaussians, the feature input, and the global
            scale (``None`` unless ``normalize_depth`` is set).

        Raises:
            ValueError: If a depth or color init option is not recognized.
        """
        image = image.contiguous()
        depth = depth.contiguous()
        device = depth.device
        batch_size, _, image_height, image_width = depth.shape
        grid_shape = (batch_size, image_height // self.stride, image_width // self.stride)

        # Everything below (including the feature input) sees the rescaled depth;
        # the inverse factor is returned so callers can undo the rescaling.
        global_scale: torch.Tensor | None = None
        if self.normalize_depth:
            depth, depth_factor = _rescale_depth(depth)
            global_scale = 1.0 / depth_factor

        disparity = self._create_layer_disparity(depth, grid_shape)

        base_x_ndc, base_y_ndc = _create_base_xy(depth, self.stride, self.num_layers)
        disparity_scale_factor = 2 * self.scale_factor * self.stride / float(image_width)
        base_scales = _create_base_scale(disparity, disparity_scale_factor)

        # NOTE(review): presumably the identity rotation in (w, x, y, z) order -
        # confirm the quaternion convention with the consumer.
        base_quaternions = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device)
        base_quaternions = base_quaternions[None, :, None, None, None]

        # Opacity is split evenly across layers but capped at 0.5.
        base_opacities = torch.tensor([min(1.0 / self.num_layers, 0.5)], device=device)
        base_colors = self._create_base_colors(image, grid_shape, device)

        features_in = self.prepare_feature_input(image, depth)
        base_gaussians = GaussianBaseValues(
            mean_x_ndc=base_x_ndc,
            mean_y_ndc=base_y_ndc,
            mean_inverse_z_ndc=disparity,
            scales=base_scales,
            quaternions=base_quaternions,
            colors=base_colors,
            opacities=base_opacities,
        )

        return InitializerOutput(
            gaussian_base_values=base_gaussians,
            feature_input=features_in,
            global_scale=global_scale,
        )

    def _create_disparity_layers(
        self,
        num_layers: int,
        grid_shape: tuple[int, int, int],
        device: torch.device,
    ) -> torch.Tensor:
        """Create constant layers linearly spaced in disparity from 1 / base_depth towards 0."""
        batch_size, base_height, base_width = grid_shape
        # One extra sample is generated and dropped so that no layer has zero disparity.
        disparity = torch.linspace(1.0 / self.base_depth, 0.0, num_layers + 1, device=device)
        return disparity[None, None, :-1, None, None].repeat(
            batch_size, 1, 1, base_height, base_width
        )

    def _create_surface_layer(self, depth: torch.Tensor, depth_pooling_mode: str) -> torch.Tensor:
        """Create a disparity layer by pooling the inverse depth over each stride window."""
        disparity = 1.0 / depth
        if depth_pooling_mode == "min":
            # Minimum depth is maximum disparity.
            disparity = torch.max_pool2d(disparity, self.stride, self.stride)
        elif depth_pooling_mode == "max":
            # Min-pooling of disparity expressed via max-pooling of its negation.
            disparity = -torch.max_pool2d(-disparity, self.stride, self.stride)
        else:
            raise ValueError(f"Invalid depth pooling mode {depth_pooling_mode}.")

        return disparity[:, :, None, :, :]

    def _create_layer_disparity(
        self, depth: torch.Tensor, grid_shape: tuple[int, int, int]
    ) -> torch.Tensor:
        """Stack the first-layer and following-layer disparities along dim 2."""
        device = depth.device

        first_option = self.first_layer_depth_option
        if first_option == "surface_min":
            first_disparity = self._create_surface_layer(depth[:, 0:1], "min")
        elif first_option == "surface_max":
            first_disparity = self._create_surface_layer(depth[:, 0:1], "max")
        elif first_option in ("base_depth", "linear_disparity"):
            first_disparity = self._create_disparity_layers(1, grid_shape, device)
        else:
            raise ValueError(f"Unknown depth init option: {first_option}.")

        if self.num_layers == 1:
            return first_disparity

        # With a single depth channel the same map is reused for the following
        # layers; otherwise the channels after the first are used.
        # NOTE(review): a surface option on single-channel input yields exactly one
        # following layer, so the stack only reaches `num_layers` layers when
        # num_layers == 2 - confirm callers never combine these with more layers.
        following_depth = depth if depth.shape[1] == 1 else depth[:, 1:]
        num_following = self.num_layers - 1

        rest_option = self.rest_layer_depth_option
        if rest_option == "surface_min":
            following_disparity = self._create_surface_layer(following_depth, "min")
        elif rest_option == "surface_max":
            following_disparity = self._create_surface_layer(following_depth, "max")
        elif rest_option == "base_depth":
            # Every following layer sits at the same base disparity, so build the
            # constant layer once and stack copies of it.
            single_layer = self._create_disparity_layers(1, grid_shape, device)
            following_disparity = torch.cat([single_layer] * num_following, dim=2)
        elif rest_option == "linear_disparity":
            following_disparity = self._create_disparity_layers(num_following, grid_shape, device)
        else:
            raise ValueError(f"Unknown depth init option: {rest_option}.")

        return torch.cat([first_disparity, following_disparity], dim=2)

    def _create_base_colors(
        self,
        image: torch.Tensor,
        grid_shape: tuple[int, int, int],
        device: torch.device,
    ) -> torch.Tensor:
        """Create base colors: mid-gray, optionally overwritten by the pooled image."""
        if self.color_option == "all_layers":
            pooled = torch.nn.functional.avg_pool2d(image, self.stride, self.stride)
            return pooled[:, :, None, :, :].repeat(1, 1, self.num_layers, 1, 1)
        if self.color_option not in ("none", "first_layer"):
            raise ValueError(f"Unknown color init option: {self.color_option}.")

        batch_size, base_height, base_width = grid_shape
        base_colors = torch.full(
            (batch_size, 3, self.num_layers, base_height, base_width), 0.5, device=device
        )
        if self.color_option == "first_layer":
            base_colors[:, :, 0] = torch.nn.functional.avg_pool2d(image, self.stride, self.stride)
        return base_colors
| |
|
| |
|
| | def _create_base_xy( |
| | depth: torch.Tensor, stride: int, num_layers: int |
| | ) -> tuple[torch.Tensor, torch.Tensor]: |
| | """Create base x and y coordinates for the gaussians in NDC space.""" |
| | device = depth.device |
| | batch_size, _, image_height, image_width = depth.shape |
| | xx = torch.arange(0.5 * stride, image_width, stride, device=device) |
| | yy = torch.arange(0.5 * stride, image_height, stride, device=device) |
| | xx = 2 * xx / image_width - 1.0 |
| | yy = 2 * yy / image_height - 1.0 |
| |
|
| | xx, yy = torch.meshgrid(xx, yy, indexing="xy") |
| | base_x_ndc = xx[None, None, None].repeat(batch_size, 1, num_layers, 1, 1) |
| | base_y_ndc = yy[None, None, None].repeat(batch_size, 1, num_layers, 1, 1) |
| |
|
| | return base_x_ndc, base_y_ndc |
| |
|
| |
|
| | def _create_base_scale(disparity: torch.Tensor, disparity_scale_factor: float) -> torch.Tensor: |
| | """Create base scale for the gaussians.""" |
| | inverse_disparity = torch.ones_like(disparity) / disparity |
| | base_scales = inverse_disparity * disparity_scale_factor |
| | return base_scales |
| |
|
| |
|
| | def _rescale_depth( |
| | depth: torch.Tensor, depth_min: float = 1.0, depth_max: float = 1e2 |
| | ) -> tuple[torch.Tensor, torch.Tensor]: |
| | """Rescale a depth image tensor. |
| | |
| | Args: |
| | depth: The depth tensor to transform. |
| | depth_min: The min depth to scale depth to. |
| | depth_max: The max clamp depth after scaling. |
| | |
| | Returns: |
| | The rescaled depth and rescale factor. |
| | """ |
| | current_depth_min = depth.flatten(depth.ndim - 3).min(dim=-1).values |
| | depth_factor = depth_min / (current_depth_min + 1e-6) |
| | depth = (depth * depth_factor[..., None, None, None]).clamp(max=depth_max) |
| | return depth, depth_factor |
| |
|