| """Contains params for backbone. |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
| import dataclasses |
| from typing import Literal |
|
|
| import sharp.utils.math as math_utils |
| from sharp.models.blocks import NormLayerName, UpsamplingMode |
| from sharp.models.presets import ViTPreset |
| from sharp.utils.color_space import ColorSpace |
|
|
# Five integers; used below as the type of the ``dims_decoder`` fields.
DimsDecoder = tuple[int, int, int, int, int]

# Closed set of names accepted for ``GaussianDecoderParams.image_encoder_type``.
DPTImageEncoderType = Literal["skip_conv", "skip_conv_kernel2"]
|
|
# Closed set of names accepted for ``InitializerParams.color_option``.
ColorInitOption = Literal["none", "first_layer", "all_layers"]

# Closed set of names accepted for the ``*_depth_option`` fields of
# ``InitializerParams``.
DepthInitOption = Literal[
    "surface_min",
    "surface_max",
    "base_depth",
    "linear_disparity",
]
|
|
|
|
@dataclasses.dataclass
class AlignmentParams:
    """Configuration record for depth alignment.

    Every field carries a default, so ``AlignmentParams()`` is a complete
    configuration. The code that consumes these values is not part of this
    module; the per-field notes are inferred from the field names and should
    be confirmed against that consumer.
    """

    # NOTE(review): presumably window size and step of the alignment
    # operator -- confirm against the module that reads them.
    kernel_size: int = 16
    stride: int = 1
    # NOTE(review): looks like a switch to exclude the alignment module from
    # training -- confirm.
    frozen: bool = False

    # NOTE(review): meaning of "steps" is not visible here -- TODO confirm.
    steps: int = 4
    # Name of the activation, typed by ``math_utils.ActivationType``.
    activation_type: math_utils.ActivationType = "exp"
    # NOTE(review): presumably toggles use of depth-decoder features as
    # additional input -- confirm.
    depth_decoder_features: bool = False
    # NOTE(review): presumably the base channel width of the alignment
    # network -- confirm.
    base_width: int = 16
|
|
|
|
@dataclasses.dataclass
class DeltaFactor:
    """Per-attribute multipliers applied to deltas before activation.

    Scaling a delta down before the activation acts like a selectively
    reduced learning rate for that attribute.
    """

    # Position-related deltas use the smallest factors.
    xy: float = 0.001
    z: float = 0.001

    color: float = 0.1

    # A factor of 1.0 leaves the corresponding delta unscaled.
    opacity: float = 1.0
    scale: float = 1.0
    quaternion: float = 1.0
|
|
|
|
@dataclasses.dataclass
class InitializerParams:
    """Configuration record for the initializer.

    Every field carries a default. The initializer implementation is not part
    of this module, so the per-field notes are inferred from the field names
    and should be confirmed against it.
    """

    # NOTE(review): presumably multipliers applied to the initial scale and
    # disparity values -- confirm.
    scale_factor: float = 1.0
    disparity_factor: float = 1.0
    # NOTE(review): presumably the spatial stride of the initializer output
    # relative to its input -- confirm.
    stride: int = 2

    # NOTE(review): presumably the number of layers to initialize -- confirm.
    num_layers: int = 2
    # Depth initialization mode for the first layer and for the other layers;
    # both are restricted to the ``DepthInitOption`` names.
    first_layer_depth_option: DepthInitOption = "surface_min"
    rest_layer_depth_option: DepthInitOption = "surface_min"
    # Color initialization mode, restricted to the ``ColorInitOption`` names.
    color_option: ColorInitOption = "all_layers"
    # NOTE(review): presumably the depth value used by the "base_depth"
    # option -- confirm.
    base_depth: float = 10.0
    # NOTE(review): looks like a switch to block gradients through the
    # feature input -- confirm.
    feature_input_stop_grad: bool = False

    # NOTE(review): exact normalization is not visible here -- TODO confirm.
    normalize_depth: bool = True

    # Inpainting-related switches; all default to off.
    # NOTE(review): semantics must be checked against the initializer.
    output_inpainted_layer_only: bool = False
    set_uninpainted_opacity_to_zero: bool = False
    concat_inpainting_mask: bool = False
|
|
|
|
@dataclasses.dataclass
class MonodepthParams:
    """Configuration record for the monodepth network.

    Every field carries a default, so ``MonodepthParams()`` is a complete
    configuration.
    """

    # Preset names (see ``ViTPreset``) for the patch and image encoders.
    patch_encoder_preset: ViTPreset = "dinov2l16_384"
    image_encoder_preset: ViTPreset = "dinov2l16_384"

    # Optional checkpoint location; ``None`` by default.
    checkpoint_uri: str | None = None

    # One ``unfreeze_*`` switch per sub-module; all default to ``False``.
    # NOTE(review): presumably a sub-module is trained only when its switch
    # is set -- confirm against the consumer.
    unfreeze_patch_encoder: bool = False
    unfreeze_image_encoder: bool = False
    unfreeze_decoder: bool = False
    unfreeze_head: bool = False
    unfreeze_norm_layers: bool = False

    # NOTE(review): presumably enables gradient (activation) checkpointing
    # -- confirm.
    grad_checkpointing: bool = False
    # NOTE(review): presumably makes neighbouring patches overlap -- confirm.
    use_patch_overlap: bool = True
    # Five decoder dimensions (see ``DimsDecoder``).
    dims_decoder: DimsDecoder = (256, 256, 256, 256, 256)
|
|
|
|
@dataclasses.dataclass
class MonodepthAdaptorParams:
    """Configuration record for the monodepth feature adaptor.

    Holds two independent boolean switches, one for encoder features and one
    for decoder features.
    """

    # On by default.
    encoder_features: bool = True
    # Off by default.
    decoder_features: bool = False
|
|
|
|
@dataclasses.dataclass
class GaussianDecoderParams:
    """Configuration record, with defaults, for the Gaussian decoder backbone.

    The decoder implementation is not part of this module, so the per-field
    notes are inferred from the field names and should be confirmed against
    it.
    """

    # NOTE(review): presumably input and output channel counts -- confirm.
    dim_in: int = 5
    dim_out: int = 32
    # Normalization layer name (see ``NormLayerName``).
    norm_type: NormLayerName = "group_norm"
    # NOTE(review): presumably only relevant for group normalization
    # -- confirm.
    norm_num_groups: int = 8
    # NOTE(review): presumably the spatial stride of the decoder output
    # -- confirm.
    stride: int = 2

    # Preset names (see ``ViTPreset``) for the patch and image encoders.
    patch_encoder_preset: ViTPreset = "dinov2l16_384"
    image_encoder_preset: ViTPreset = "dinov2l16_384"

    # Five decoder dimensions (see ``DimsDecoder``).
    dims_decoder: DimsDecoder = (128, 128, 128, 128, 128)

    # NOTE(review): presumably feeds depth as an additional input -- confirm.
    use_depth_input: bool = True

    # NOTE(review): presumably enables gradient (activation) checkpointing
    # -- confirm.
    grad_checkpointing: bool = False

    # Upsampling mode name (see ``UpsamplingMode``).
    upsampling_mode: UpsamplingMode = "transposed_conv"

    # Image encoder variant, one of the ``DPTImageEncoderType`` names.
    image_encoder_type: DPTImageEncoderType = "skip_conv_kernel2"
|
|
|
|
@dataclasses.dataclass
class PredictorParams:
    """Top-level configuration record, with defaults, for predictors.

    Groups one nested parameter record per sub-module together with a set of
    scalar options. Nested records are created through ``default_factory``,
    so each ``PredictorParams()`` gets its own fresh instances.
    """

    # Nested per-module parameter records.
    initializer: InitializerParams = dataclasses.field(
        default_factory=InitializerParams
    )
    monodepth: MonodepthParams = dataclasses.field(
        default_factory=MonodepthParams
    )
    monodepth_adaptor: MonodepthAdaptorParams = dataclasses.field(
        default_factory=MonodepthAdaptorParams
    )
    gaussian_decoder: GaussianDecoderParams = dataclasses.field(
        default_factory=GaussianDecoderParams
    )
    depth_alignment: AlignmentParams = dataclasses.field(
        default_factory=AlignmentParams
    )

    # Multipliers applied to deltas before activation (see ``DeltaFactor``).
    delta_factor: DeltaFactor = dataclasses.field(default_factory=DeltaFactor)
    # NOTE(review): presumably upper and lower bounds for predicted scales
    # -- confirm.
    max_scale: float = 10.0
    min_scale: float = 0.0
    # Normalization layer name (see ``NormLayerName``).
    norm_type: NormLayerName = "group_norm"
    # NOTE(review): presumably only relevant for group normalization
    # -- confirm.
    norm_num_groups: int = 8
    # NOTE(review): semantics not visible in this module -- TODO confirm.
    use_predicted_mean: bool = False
    # Activation names (see ``math_utils.ActivationType``) for color and
    # opacity.
    color_activation_type: math_utils.ActivationType = "sigmoid"
    opacity_activation_type: math_utils.ActivationType = "sigmoid"
    # Color space name (see ``ColorSpace``).
    color_space: ColorSpace = "linearRGB"
    # NOTE(review): presumably a small epsilon used by a low-pass filter
    # -- confirm.
    low_pass_filter_eps: float = 1e-2
    # NOTE(review): presumably the number of layers produced by the monodepth
    # network -- confirm.
    num_monodepth_layers: int = 2
    # NOTE(review): semantics not visible in this module -- TODO confirm.
    sorting_monodepth: bool = False
    # NOTE(review): semantics not visible in this module -- TODO confirm.
    base_scale_on_predicted_mean: bool = True
|
|