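"""Generic Pytorch Lightning module that wraps the models from .models.

Handles the training and validation steps, metric and loss logging, and
checkpoint loading with config overrides.
"""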
from pathlib import Path

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from torchmetrics import MeanMetric, MetricCollection

from . import logger
from .models import get_model


class AverageKeyMeter(MeanMetric):
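    """Running mean of a single key of a loss dict, ignoring non-finite values.

    A minimal usage sketch (the key name "total" is only an example):

        meter = AverageKeyMeter("total")
        meter.update({"total": torch.tensor([1.0, float("nan"), 3.0])})
        meter.compute()  # tensor(2.) -- the NaN entry is dropped
    """
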
    def __init__(self, key, *args, **kwargs):
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, values):
        # Drop NaN/Inf entries so that they do not poison the running mean.
        value = values[self.key]
        value = value[torch.isfinite(value)]
        return super().update(value)


class GenericModule(pl.LightningModule):
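    """Lightning wrapper that delegates the forward pass, losses, and metrics
    to the model selected by ``cfg.model.name`` via ``get_model``."""
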
    def __init__(self, cfg):
        super().__init__()
        name = cfg.model.get("name")
        name = "map_perception_net" if name is None else name
        self.model = get_model(name)(cfg.model)
        self.cfg = cfg
        self.save_hyperparameters(cfg)
        self.metrics_val = MetricCollection(self.model.metrics(), prefix="val/")
        self.losses_val = None

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        self.log_dict(
            {f"train/loss/{k}": v.mean() for k, v in losses.items()},
            prog_bar=True,
            rank_zero_only=True,
            on_epoch=True,
            sync_dist=True,
        )
        return losses["total"].mean()

    def validation_step(self, batch, batch_idx):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        if self.losses_val is None:
            # Lazily build the loss meters once the loss keys are known.
            self.losses_val = MetricCollection(
                {k: AverageKeyMeter(k).to(self.device) for k in losses},
                prefix="val/",
                postfix="/loss",
            )
        self.metrics_val(pred, batch)
        self.log_dict(self.metrics_val, on_epoch=True)
        self.losses_val.update(losses)
        self.log_dict(self.losses_val, on_epoch=True)
        return pred

    def test_step(self, batch, batch_idx):
        pred = self(batch)
        return pred

    def on_validation_epoch_start(self):
        # Reset the loss meters; they are rebuilt on the first validation step.
        self.losses_val = None

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.training.lr)
        ret = {"optimizer": optimizer}
        cfg_scheduler = self.cfg.training.get("lr_scheduler")
        if cfg_scheduler is not None:
            scheduler_args = cfg_scheduler.get("args", {})
            # Resolve the "$total_epochs" placeholder against the trainer.
            for key in scheduler_args:
                if scheduler_args[key] == "$total_epochs":
                    scheduler_args[key] = int(self.trainer.max_epochs)
            scheduler = getattr(torch.optim.lr_scheduler, cfg_scheduler.name)(
                optimizer=optimizer, **scheduler_args
            )
            ret["lr_scheduler"] = {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                # Name logged by losses_val: prefix "val/" + key + postfix "/loss".
                "monitor": "val/total/loss",
                "strict": True,
                "name": "learning_rate",
            }
        return ret

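    # A minimal sketch of the expected training config (the scheduler name and
    # its arguments are only an example):
    #
    #   training:
    #     lr: 1.0e-4
    #     lr_scheduler:
    #       name: CosineAnnealingLR
    #       args:
    #         T_max: "$total_epochs"
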
    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path,
        map_location=None,
        hparams_file=None,
        strict=True,
        cfg=None,
        find_best=False,
    ):
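        """Load a checkpoint, optionally resolving the best one in its directory.

        If `find_best` is set, the ModelCheckpoint states stored in the
        checkpoint are scanned and the best-scoring checkpoint found in the
        same directory is loaded instead. A `cfg` mapping, if given, is merged
        on top of the hyperparameters stored in the checkpoint.
        """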
        assert hparams_file is None, "hparams are not supported."

        checkpoint = torch.load(
            checkpoint_path,
            map_location=map_location or (lambda storage, loc: storage),
        )
        if find_best:
            best_score, best_name = None, None
            modes = {"min": torch.lt, "max": torch.gt}
            for key, state in checkpoint["callbacks"].items():
                if not key.startswith("ModelCheckpoint"):
                    continue
                # The callback key embeds a dict repr of its settings, e.g.
                # "ModelCheckpoint{'monitor': ..., 'mode': 'min', ...}".
                mode = eval(key.replace("ModelCheckpoint", ""))["mode"]
                if best_score is None or modes[mode](
                    state["best_model_score"], best_score
                ):
                    best_score = state["best_model_score"]
                    best_name = Path(state["best_model_path"]).name
            logger.info("Loading best checkpoint %s", best_name)
            if best_name != checkpoint_path:
                return cls.load_from_checkpoint(
                    Path(checkpoint_path).parent / best_name,
                    map_location,
                    hparams_file,
                    strict,
                    cfg,
                    find_best=False,
                )

        logger.info(
            "Using checkpoint %s from epoch %d and step %d.",
            checkpoint_path,
            checkpoint["epoch"],
            checkpoint["global_step"],
        )
        cfg_ckpt = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
        if list(cfg_ckpt.keys()) == ["cfg"]:
            cfg_ckpt = cfg_ckpt["cfg"]
        cfg_ckpt = OmegaConf.create(cfg_ckpt)

        if cfg is None:
            cfg = {}
        if not isinstance(cfg, DictConfig):
            cfg = OmegaConf.create(cfg)
        # Overrides passed at load time take precedence over the stored config.
        with open_dict(cfg_ckpt):
            cfg = OmegaConf.merge(cfg_ckpt, cfg)

        return pl.core.saving._load_state(cls, checkpoint, strict=strict, cfg=cfg)
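
    # A minimal usage sketch (the path and the override key are assumptions):
    #
    #   model = GenericModule.load_from_checkpoint(
    #       "outputs/last.ckpt", find_best=True, cfg={"training": {"lr": 1.0e-4}}
    #   )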