Spaces:
Running
Running
| import os | |
| from models.module import DiffGeolocalizer | |
| import hydra | |
| import wandb | |
| from os.path import isfile, join | |
| from shutil import copyfile | |
| import torch | |
| from omegaconf import OmegaConf | |
| from omegaconf import open_dict | |
| from hydra.core.hydra_config import HydraConfig | |
| from hydra.utils import instantiate | |
| from pytorch_lightning.callbacks import LearningRateMonitor | |
| from lightning_fabric.utilities.rank_zero import _get_rank | |
| from models.module import DiffGeolocalizer | |
# Allow faster, lower-precision (e.g. TF32) kernels for float32 matmuls.
# TODO: confirm this is still wanted for evaluation runs.
torch.set_float32_matmul_precision("high")

# Registering the "eval" resolver allows for advanced config
# interpolation with arithmetic operations in hydra:
# https://omegaconf.readthedocs.io/en/2.3_branch/how_to_guides.html
# SECURITY NOTE: this resolver runs Python's eval() on config strings, so the
# config files must come from a trusted source.
# Guarded because registering an already-registered resolver name raises
# ValueError (e.g. when another imported module registered "eval" first).
if not OmegaConf.has_resolver("eval"):
    OmegaConf.register_new_resolver("eval", eval)
def load_model(cfg, dict_config, wandb_id):
    """Build the trainer and restore the model from a checkpoint.

    Args:
        cfg: Hydra/OmegaConf config. Reads ``cfg.logger``, ``cfg.checkpoint``,
            ``cfg.model`` and ``cfg.trainer``.
        dict_config: Resolved plain-container copy of ``cfg``. Currently unused;
            kept for interface compatibility.
        wandb_id: Path to a text file holding the W&B run id to resume.

    Returns:
        A ``(trainer, model)`` tuple.
    """
    # Close the file promptly, and drop surrounding whitespace so a trailing
    # newline in the file does not become part of the run id.
    with open(wandb_id, "r") as f:
        run_id = f.read().strip()
    logger = instantiate(cfg.logger, id=run_id, resume="allow")
    model = DiffGeolocalizer.load_from_checkpoint(cfg.checkpoint, cfg=cfg.model)
    trainer = instantiate(cfg.trainer, strategy=cfg.trainer.strategy, logger=logger)
    return trainer, model
def hydra_boilerplate(cfg):
    """Resolve ``cfg`` and load the trainer/model pair it describes.

    Returns:
        ``(trainer, model)`` as produced by ``load_model``.
    """
    resolved = OmegaConf.to_container(cfg, resolve=True)
    return load_model(cfg, resolved, cfg.wandb_id)
| import copy | |
def generate_datamodules(cfg_):
    """Yield one test datamodule per config file found in ``cfg_.test_dir``.

    For each file, a deep copy of ``cfg_`` has its ``datamodule`` and
    ``dataset`` sections replaced by those from the file, while keeping the
    base config's ``dataset.test_transform``. Files are visited in sorted name
    order so evaluation order is reproducible (``os.listdir`` order is
    arbitrary).

    Args:
        cfg_: Base config; must define ``test_dir`` and
            ``dataset.test_transform``. It is not modified.

    Yields:
        The datamodule instantiated from each per-file config.
    """
    for fname in sorted(os.listdir(cfg_.test_dir)):
        cfg = copy.deepcopy(cfg_)
        cfg_new = OmegaConf.load(join(cfg.test_dir, fname))
        # open_dict lifts struct mode so these sections can be overwritten.
        with open_dict(cfg):
            cfg.datamodule = cfg_new.datamodule
            cfg.dataset = cfg_new.dataset
            cfg.dataset.test_transform = cfg_.dataset.test_transform
            datamodule = instantiate(cfg.datamodule)
        yield datamodule
if __name__ == "__main__":
    import sys

    # Inject Hydra's runtime config sources as a config key so ``main`` can
    # locate the directory holding the checkpoint and W&B id.
    sys.argv = (
        [sys.argv[0]]
        + ["+pt_model_path=${hydra:runtime.config_sources}"]
        + sys.argv[1:]
    )

    # NOTE(review): the decorator was missing even though ``main()`` is called
    # with no arguments and ``hydra`` is otherwise unused. config_path and
    # config_name are assumptions; confirm against the repository layout
    # (both can also be overridden via --config-path / --config-name).
    @hydra.main(config_path="configs", config_name="config", version_base=None)
    def main(cfg):
        """Evaluate the checkpointed model on every test config in ``cfg.test_dir``."""
        with open_dict(cfg):
            # NOTE(review): index 1 presumably is the user-supplied config
            # directory (index 0 being Hydra's own); confirm.
            path = cfg.pt_model_path[1]["path"]
            cfg.wandb_id = join(path, "wandb_id.txt")
            cfg.checkpoint = join(path, "last.ckpt")
            cfg.computer.devices = 1
        trainer, model = hydra_boilerplate(cfg)
        for datamodule in generate_datamodules(cfg):
            model.datamodule = datamodule
            model.datamodule.setup()
            print("Testing on", datamodule.test_dataset.class_name)
            trainer.test(model, datamodule=datamodule)

    main()