# Learn2Splat/optgs/scripts/dev/debug_dataset.py
# Commit 78d2329 (verified) by SteEsp: "Add Docker-based Learn2Splat demo (viser GUI)"
# File size: 2.97 kB
import os
import sys
import warnings
from pathlib import Path
import hydra
import torch
from jaxtyping import install_import_hook
from omegaconf import DictConfig
import matplotlib.pyplot as plt
from optgs.misc.io import cyan
# Configure beartype and jaxtyping.
with install_import_hook(
("optgs",),
("beartype", "beartype"),
):
from optgs.config import setup_cfg
from optgs.dataset.data_module import DataModule
from optgs.misc.step_tracker import StepTracker
# Report the PyTorch build and every visible CUDA device when the module loads.
print(cyan(f"Torch version: {torch.__version__}"))
if not torch.cuda.is_available():
    # Deliberately non-fatal: nothing is raised here when CUDA is missing.
    print(cyan("CUDA is not available."))
else:
    num_devices = torch.cuda.device_count()
    print(cyan(f"CUDA is available. Number of devices: {num_devices}"))
    for device_id in range(num_devices):
        device_name = torch.cuda.get_device_name(device_id)
        print(cyan(f"Device {device_id}: {device_name}"))
@hydra.main(
    version_base=None,
    config_path="config",
    config_name="main",
)
def train(cfg_dict: DictConfig):
    """Smoke-test the dataset by iterating once over the test data loader.

    Despite its name, this entry point does not train anything. It builds the
    ``DataModule`` from the configuration, prints the number of batches of the
    relevant loaders, then walks the test loader once. Every sample whose
    flattened context extrinsics have an L2 norm above ``pose_norm_limit`` (or
    a non-finite norm) is reported and its first context image is displayed.
    The process then terminates with exit status 0.

    Args:
        cfg_dict: Raw Hydra configuration; it is handed to ``setup_cfg``.

    Raises:
        SystemExit: Always, with code 0, after the full pass completes.
    """
    # Set up configuration. setup_cfg returns three values; only the first
    # one is needed by this script.
    cfg, _, _ = setup_cfg(cfg_dict)

    # This allows the current step to be shared with the data loader processes.
    step_tracker = StepTracker()
    data_module = DataModule(
        cfg.dataset,
        cfg.data_loader,
        step_tracker,
    )

    if cfg.mode == "train":
        print("train:", len(data_module.train_dataloader()))
        print("val:", len(data_module.val_dataloader()))
        print("test:", len(data_module.test_dataloader()))
    else:
        print("test:", len(data_module.test_dataloader()))

    # Poses whose flattened norm exceeds this value are treated as suspicious.
    pose_norm_limit = 1e3

    # DEBUGGING: loop over all data once to catch errors early
    for batch_idx, batch in enumerate(data_module.test_dataloader()):
        extrinsics = batch["context"]["extrinsics"]
        # One norm per batch element. reshape() is used instead of view() so
        # that non-contiguous tensors are accepted as well.
        pose_norm = extrinsics.reshape(extrinsics.shape[0], -1).norm(dim=1)

        # A NaN norm never compares greater than the limit, so non-finite
        # values are flagged explicitly. Working on indices (rather than
        # `if pose_norm > limit`) keeps this valid for any batch size.
        suspicious = (pose_norm > pose_norm_limit) | ~torch.isfinite(pose_norm)
        for sample_idx in torch.nonzero(suspicious).flatten().tolist():
            # The batch-level fields are printed unchanged, so the output for
            # a batch size of 1 matches the previous format exactly.
            print(
                f"Batch {batch_idx}: pose norm {pose_norm[sample_idx].item():.4f} "
                f"{extrinsics} {batch['scene']} {batch['context']['index']}"
            )
            # First context view of the offending sample, with its first axis
            # moved last for imshow (presumably channels-first -> channels-last;
            # TODO confirm the image layout).
            image = batch["context"]["image"][sample_idx, 0].permute(1, 2, 0).cpu().numpy()
            plt.figure()
            plt.imshow(image)
            plt.title(f"Batch {batch_idx}\n{batch['scene'][sample_idx]}")
            # NOTE: plt.show() blocks until the window is closed when an
            # interactive backend is active.
            plt.show()

    print(cyan("DEBUG: Completed one full pass through the data loaders without errors. Exiting now."))
    sys.exit(0)
if __name__ == "__main__":
    # Local import: only the start-up banner below needs the host name.
    # platform.node() is used instead of os.uname(), which exists only on
    # Unix and raises AttributeError on Windows.
    import platform

    # Silence all warnings for the duration of the run.
    warnings.filterwarnings("ignore")
    torch.set_float32_matmul_precision('high')

    if not torch.cuda.is_available():
        print("")
        print(cyan("=" * 80))
        print(cyan("CUDA is not available, running on CPU."))
        print(cyan("=" * 80))
        print("")

    # Print the hostname and current working directory.
    hostname = platform.node()
    slurm_job_id = os.environ.get('SLURM_JOB_ID', 'N/A')
    print(cyan("=" * 80))
    print(cyan(f"Starting training on {hostname}, slurm job id: {slurm_job_id}"))
    print(cyan(f"Current working directory: {Path.cwd()}"))
    print(cyan("=" * 80))

    # Hydra supplies the configuration argument through the decorator.
    train()