| | """Utility functions for visualization. |
| | |
| | For licensing see accompanying LICENSE file. |
| | Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import numpy as np |
| | import torch |
| | from matplotlib import pyplot as plt |
| |
|
# Upper clamp value for metric depth; the name suggests the unit is meters.
# NOTE(review): not referenced in this chunk — presumably used elsewhere
# (e.g. as ``val_max`` for ``colorize_depth``). TODO confirm.
METRIC_DEPTH_MAX_CLAMP_METER = 50.0
| |
|
| |
|
def colorize_depth(depth: torch.Tensor, val_max: float = 10.0) -> torch.Tensor:
    """Colorize a depth map with the ``turbo`` color map.

    Depth values are mapped linearly from ``[0, val_max]`` onto the color map.

    Args:
        depth: Depth tensor whose channel axis is dimension -3.
        val_max: Depth value mapped to the top end of the color map.

    Returns:
        The colorized depth as produced by ``colorize_scalar_map``. When
        ``depth`` has more than one channel, every channel is colorized on its
        own and the results are concatenated along the last dimension, i.e.
        placed side by side.
    """
    num_channels = depth.shape[-3]

    # Single channel: drop the channel axis and colorize directly.
    if num_channels == 1:
        return colorize_scalar_map(
            depth.squeeze(-3), val_min=0.0, val_max=val_max, color_map="turbo"
        )

    # Several channels: colorize each slice along dim -3, then tile them
    # horizontally.
    per_channel = [
        colorize_scalar_map(
            channel, val_min=0.0, val_max=val_max, color_map="turbo"
        )
        for channel in torch.unbind(depth, dim=-3)
    ]
    return torch.cat(per_channel, dim=-1)
| |
|
| |
|
def colorize_alpha(alpha: torch.Tensor) -> torch.Tensor:
    """Colorize an alpha map with the ``coolwarm`` color map.

    Args:
        alpha: Alpha tensor. Dimension -3 is squeezed away, so it is presumably
            a singleton channel axis. Values are displayed over ``[0, 1]``.

    Returns:
        The colorized alpha map as produced by ``colorize_scalar_map``.
    """
    alpha_map = alpha.squeeze(-3)
    return colorize_scalar_map(
        alpha_map, val_min=0.0, val_max=1.0, color_map="coolwarm"
    )
| |
|
| |
|
def colorize_scalar_map(
    scalar_map: torch.Tensor, val_min=0.0, val_max=1.0, color_map: str = "jet"
) -> torch.Tensor:
    """Colorize a scalar map with a matplotlib color map.

    Values are rescaled linearly from ``[val_min, val_max]`` to ``[0, 1]``,
    clipped to that range, and looked up in ``color_map``.

    Args:
        scalar_map: Scalar map with 2, 3, or 4 dimensions, i.e. ``HW`` with up
            to two leading (batch-like) dimensions, e.g. ``BHW``.
        val_min: Value mapped to the low end of the color map.
        val_max: Value mapped to the high end of the color map.
        color_map: Name of the color map; passed to ``matplotlib.pyplot.get_cmap``.

    Returns:
        A ``uint8`` RGB image with the color channel placed right before the two
        trailing spatial dimensions: ``3HW``, ``B3HW`` or ``BT3HW`` for 2, 3
        or 4 input dimensions respectively.

    Raises:
        ValueError: If ``scalar_map`` does not have 2, 3, or 4 dimensions, or if
            ``val_max`` equals ``val_min``.
    """
    if scalar_map.ndim not in (2, 3, 4):
        raise ValueError("Only scalar maps of 2 or 3 or 4 dimensions supported.")
    if val_max == val_min:
        # The normalization below divides by (val_max - val_min); a zero range
        # would silently produce inf/NaN values instead of a meaningful image.
        raise ValueError("val_max must differ from val_min.")

    cmap = plt.get_cmap(color_map)

    scalar_map_np = scalar_map.detach().cpu().float().numpy()
    scalar_map_np = (scalar_map_np - val_min) / (val_max - val_min)
    scalar_map_np = np.clip(scalar_map_np, a_min=0.0, a_max=1.0)

    # The color map appends a trailing RGBA axis; keep only RGB. Casting to
    # uint8 discards the fractional part of the scaled values.
    color_map_np = cmap(scalar_map_np)[..., :3]
    tensor = torch.as_tensor(color_map_np * 255.0, dtype=torch.uint8)

    # Move the trailing RGB axis in front of the two spatial axes
    # (..., H, W, 3) -> (..., 3, H, W); same view as the explicit permutes.
    return tensor.movedim(-1, -3)
| |
|