import lightning.pytorch as pl
from lightning.pytorch.callbacks import (
    Callback, LearningRateMonitor, DeviceStatsMonitor,
    EarlyStopping, ModelCheckpoint
)
from lightning.pytorch import Trainer, seed_everything, loggers as pl_loggers
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.utilities import grad_norm

import torch
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
import torchmetrics
import numpy as np
from einops import rearrange

from omegaconf import OmegaConf
import os
from shutil import copyfile
import inspect

from models.vae import (
    AutoencoderKL,
    LPIPSWithDiscriminator,
)
from utils.path import default_exps_dir
from utils.optim import warmup_lambda
from datamodule import SEVIRLightningDataModule, vis_sevir_seq


class VAESEVIRPLModule(pl.LightningModule):

    def __init__(self,
                 total_num_steps: int,
                 accumulate_grad_batches: int = 1,
                 oc_file: str = None,
                 save_dir: str = None):
        super().__init__()
        oc_from_file = OmegaConf.load(oc_file) if oc_file is not None else None
        oc = self.get_base_config(oc_from_file=oc_from_file)

        model_cfg = OmegaConf.to_object(oc.model)
        self.torch_nn_module = AutoencoderKL(
            down_block_types=model_cfg["down_block_types"],
            in_channels=model_cfg["in_channels"],
            sample_size=model_cfg["sample_size"],
            block_out_channels=model_cfg["block_out_channels"],
            act_fn=model_cfg["act_fn"],
            latent_channels=model_cfg["latent_channels"],
            up_block_types=model_cfg["up_block_types"],
            norm_num_groups=model_cfg["norm_num_groups"],
            layers_per_block=model_cfg["layers_per_block"],
            out_channels=model_cfg["out_channels"],
        )

        loss_cfg = model_cfg["loss"]
        self.loss = LPIPSWithDiscriminator(
            disc_start=loss_cfg["disc_start"],
            kl_weight=loss_cfg["kl_weight"],
            disc_weight=loss_cfg["disc_weight"],
            perceptual_weight=loss_cfg["perceptual_weight"],
            disc_in_channels=loss_cfg["disc_in_channels"],
        )

        self.total_num_steps = total_num_steps
        self.save_hyperparameters(oc)
        self.oc = oc

        # data layout
        self.layout = oc.layout.layout
        self.channel_axis = self.layout.find("C")
        self.batch_axis = self.layout.find("N")
        self.t_axis = self.layout.find("T")
        self.h_axis = self.layout.find("H")
        self.w_axis = self.layout.find("W")
        self.channels = model_cfg["data_channels"]

        # optimization: manual optimization with two optimizers (autoencoder and discriminator)
        self.automatic_optimization = False
        self.accumulate_grad_batches = accumulate_grad_batches
        self.max_epochs = oc.optim.max_epochs
        self.optim_method = oc.optim.method
        self.lr = oc.optim.lr
        self.wd = oc.optim.wd

        # logging
        self.save_dir = save_dir
        self.logging_prefix = oc.logging.logging_prefix

        # evaluation
        self.train_example_data_idx_list = list(oc.eval.train_example_data_idx_list)
        self.val_example_data_idx_list = list(oc.eval.val_example_data_idx_list)
        self.test_example_data_idx_list = list(oc.eval.test_example_data_idx_list)
        self.eval_example_only = oc.eval.eval_example_only

        self.valid_mse = torchmetrics.MeanSquaredError()
        self.valid_mae = torchmetrics.MeanAbsoluteError()
        self.test_mse = torchmetrics.MeanSquaredError()
        self.test_mae = torchmetrics.MeanAbsoluteError()

        self.configure_save(cfg_file_path=oc_file)

    def configure_save(self, cfg_file_path=None):
        self.save_dir = os.path.join(default_exps_dir, self.save_dir)
        os.makedirs(self.save_dir, exist_ok=True)
        self.scores_dir = os.path.join(self.save_dir, 'scores')
        os.makedirs(self.scores_dir, exist_ok=True)
        if cfg_file_path is not None:
            cfg_file_target_path = os.path.join(self.save_dir, "cfg.yaml")
            if (not os.path.exists(cfg_file_target_path)) or \
                    (not os.path.samefile(cfg_file_path, cfg_file_target_path)):
                copyfile(cfg_file_path, cfg_file_target_path)
        self.example_save_dir = os.path.join(self.save_dir, "examples")
        os.makedirs(self.example_save_dir, exist_ok=True)

    def get_base_config(self, oc_from_file=None):
        oc = OmegaConf.create()
        oc.layout = VAESEVIRPLModule.get_layout_config()
        oc.optim = VAESEVIRPLModule.get_optim_config()
        oc.logging = VAESEVIRPLModule.get_logging_config()
        oc.trainer = VAESEVIRPLModule.get_trainer_config()
        oc.eval = VAESEVIRPLModule.get_eval_config()
        oc.model = VAESEVIRPLModule.get_model_config()
        oc.dataset = VAESEVIRPLModule.get_dataset_config()
        if oc_from_file is not None:
            # values loaded from the cfg file override the defaults above
            oc = OmegaConf.merge(oc, oc_from_file)
        return oc

    @staticmethod
    def get_layout_config():
        cfg = OmegaConf.create()
        cfg.img_height = 128
        cfg.img_width = 128
        cfg.layout = "NHWC"
        return cfg

    @staticmethod
    def get_model_config():
        cfg = OmegaConf.create()
        cfg.data_channels = 4
        # AutoencoderKL
        cfg.down_block_types = ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D']
        cfg.in_channels = cfg.data_channels
        cfg.sample_size = 512
        cfg.block_out_channels = [128, 256, 512, 512]
        cfg.act_fn = 'silu'
        cfg.latent_channels = 4
        cfg.up_block_types = ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D']
        cfg.norm_num_groups = 32
        cfg.layers_per_block = 2
        cfg.out_channels = cfg.data_channels
        # LPIPSWithDiscriminator
        cfg.loss = OmegaConf.create()
        cfg.loss.disc_start = 50001
        cfg.loss.kl_weight = 1e-6
        cfg.loss.disc_weight = 0.5
        cfg.loss.perceptual_weight = 1.0
        cfg.loss.disc_in_channels = cfg.data_channels
        return cfg

    @staticmethod
    def get_dataset_config():
        cfg = OmegaConf.create()
        cfg.dataset_name = "sevirlr"
        cfg.img_height = 128
        cfg.img_width = 128
        cfg.in_len = 0
        cfg.out_len = 1
        cfg.seq_len = 1
        cfg.plot_stride = 1
        cfg.interval_real_time = 10
        cfg.sample_mode = "sequent"
        cfg.stride = cfg.out_len
        cfg.layout = "NTHWC"
        cfg.start_date = None
        cfg.train_val_split_date = (2019, 1, 1)
        cfg.train_test_split_date = (2019, 6, 1)
        cfg.end_date = None
        cfg.metrics_mode = "0"
        cfg.metrics_list = ('csi', 'pod', 'sucr', 'bias')
        cfg.threshold_list = (16, 74, 133, 160, 181, 219)
        cfg.aug_mode = "1"
        return cfg

    @staticmethod
    def get_optim_config():
        cfg = OmegaConf.create()
        cfg.seed = None
        cfg.total_batch_size = 32
        cfg.micro_batch_size = 8
        cfg.float32_matmul_precision = "high"

        cfg.method = "adam"
        cfg.lr = 1E-3
        cfg.wd = 1E-5
        cfg.betas = (0.5, 0.9)
        cfg.gradient_clip_val = 1.0
        cfg.max_epochs = 50
        # lr scheduler
        cfg.warmup_percentage = 0.2
        cfg.lr_scheduler_mode = "cosine"
        cfg.min_lr_ratio = 1.0E-3
        cfg.warmup_min_lr_ratio = 0.0
        # early stopping and checkpointing
        cfg.monitor = "val/total_loss"
        cfg.early_stop = False
        cfg.early_stop_mode = "min"
        cfg.early_stop_patience = 5
        cfg.save_top_k = 1
        return cfg

    @staticmethod
    def get_logging_config():
        cfg = OmegaConf.create()
        cfg.logging_prefix = "SEVIRLR"
        cfg.monitor_lr = True
        cfg.monitor_device = False
        cfg.track_grad_norm = -1
        cfg.use_wandb = False
        return cfg

    @staticmethod
    def get_trainer_config():
        cfg = OmegaConf.create()
        cfg.check_val_every_n_epoch = 1
        cfg.log_step_ratio = 0.001
        cfg.precision = 32
        cfg.find_unused_parameters = True
        cfg.num_sanity_val_steps = 2
        return cfg

    @staticmethod
    def get_eval_config():
        cfg = OmegaConf.create()
        cfg.train_example_data_idx_list = [0, ]
        cfg.val_example_data_idx_list = [0, ]
        cfg.test_example_data_idx_list = [0, ]
        cfg.eval_example_only = False
        cfg.num_vis = 10
        return cfg

    def configure_optimizers(self):
        optim_cfg = self.oc.optim
        lr = optim_cfg.lr
        betas = optim_cfg.betas
        opt_ae = torch.optim.Adam(
            list(self.torch_nn_module.encoder.parameters()) +
            list(self.torch_nn_module.decoder.parameters()) +
            list(self.torch_nn_module.quant_conv.parameters()) +
            list(self.torch_nn_module.post_quant_conv.parameters()),
            lr=lr, betas=betas)
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(),
            lr=lr, betas=betas)

        warmup_iter = int(np.round(optim_cfg.warmup_percentage * self.total_num_steps))
        if optim_cfg.lr_scheduler_mode == 'none':
            return [{"optimizer": opt_ae}, {"optimizer": opt_disc}]
        elif optim_cfg.lr_scheduler_mode == 'cosine':
            # autoencoder: linear warmup followed by cosine annealing
            warmup_scheduler_ae = LambdaLR(
                opt_ae,
                lr_lambda=warmup_lambda(warmup_steps=warmup_iter,
                                        min_lr_ratio=optim_cfg.warmup_min_lr_ratio))
            cosine_scheduler_ae = CosineAnnealingLR(
                opt_ae,
                T_max=(self.total_num_steps - warmup_iter),
                eta_min=optim_cfg.min_lr_ratio * optim_cfg.lr)
            lr_scheduler_ae = SequentialLR(
                opt_ae,
                schedulers=[warmup_scheduler_ae, cosine_scheduler_ae],
                milestones=[warmup_iter])
            lr_scheduler_config_ae = {
                'scheduler': lr_scheduler_ae,
                'interval': 'step',
                'frequency': 1, }
            # discriminator: same warmup + cosine schedule
            warmup_scheduler_disc = LambdaLR(
                opt_disc,
                lr_lambda=warmup_lambda(warmup_steps=warmup_iter,
                                        min_lr_ratio=optim_cfg.warmup_min_lr_ratio))
            cosine_scheduler_disc = CosineAnnealingLR(
                opt_disc,
                T_max=(self.total_num_steps - warmup_iter),
                eta_min=optim_cfg.min_lr_ratio * optim_cfg.lr)
            lr_scheduler_disc = SequentialLR(
                opt_disc,
                schedulers=[warmup_scheduler_disc, cosine_scheduler_disc],
                milestones=[warmup_iter])
            lr_scheduler_config_disc = {
                'scheduler': lr_scheduler_disc,
                'interval': 'step',
                'frequency': 1, }
            return [
                {"optimizer": opt_ae, "lr_scheduler": lr_scheduler_config_ae},
                {"optimizer": opt_disc, "lr_scheduler": lr_scheduler_config_disc},
            ]
        else:
            raise NotImplementedError(f"lr_scheduler_mode {optim_cfg.lr_scheduler_mode} is not supported.")
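
    # `warmup_lambda` is imported from `utils.optim` and not shown in this file. It is
    # assumed to return a multiplicative LR factor for LambdaLR that ramps linearly
    # from `min_lr_ratio` to 1.0 over `warmup_steps`, roughly:
    #
    #     def warmup_lambda(warmup_steps, min_lr_ratio=0.1):
    #         def ret_lambda(step):
    #             if step <= warmup_steps:
    #                 return min_lr_ratio + (1.0 - min_lr_ratio) * step / warmup_steps
    #             return 1.0
    #         return ret_lambda
    #
    # This is a sketch of the assumed behavior, not the helper's actual implementation.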

    def set_trainer_kwargs(self, **kwargs):
        r"""
        Default kwargs used when initializing pl.Trainer
        """
        checkpoint_callback = ModelCheckpoint(
            monitor=self.oc.optim.monitor,
            dirpath=os.path.join(self.save_dir, "checkpoints"),
            filename="{epoch:03d}",
            auto_insert_metric_name=False,
            save_top_k=self.oc.optim.save_top_k,
            save_last=True,
            mode="min",
        )
        callbacks = kwargs.pop("callbacks", [])
        assert isinstance(callbacks, list)
        for ele in callbacks:
            assert isinstance(ele, Callback)
        callbacks += [checkpoint_callback, ]
        if self.oc.logging.monitor_lr:
            callbacks += [LearningRateMonitor(logging_interval='step'), ]
        if self.oc.logging.monitor_device:
            callbacks += [DeviceStatsMonitor(), ]
        if self.oc.optim.early_stop:
            callbacks += [EarlyStopping(monitor=self.oc.optim.monitor,
                                        min_delta=0.0,
                                        patience=self.oc.optim.early_stop_patience,
                                        verbose=False,
                                        mode=self.oc.optim.early_stop_mode), ]

        logger = kwargs.pop("logger", [])
        tb_logger = pl_loggers.TensorBoardLogger(save_dir=self.save_dir)
        csv_logger = pl_loggers.CSVLogger(save_dir=self.save_dir)
        logger += [tb_logger, csv_logger]
        if self.oc.logging.use_wandb:
            wandb_logger = pl_loggers.WandbLogger(
                # `logging_name` is not among the defaults in `get_logging_config()`;
                # it is expected to be supplied via the cfg file when wandb is enabled.
                name=self.oc.logging.logging_name,
                project=self.oc.logging.logging_prefix,
                save_dir=self.save_dir)
            logger += [wandb_logger, ]

        log_every_n_steps = max(1, int(self.oc.trainer.log_step_ratio * self.total_num_steps))
        trainer_init_keys = inspect.signature(Trainer).parameters.keys()
        ret = dict(
            num_sanity_val_steps=self.oc.trainer.num_sanity_val_steps,
            callbacks=callbacks,
            # logging
            logger=logger,
            log_every_n_steps=log_every_n_steps,
            # saving
            default_root_dir=self.save_dir,
            # ddp
            accelerator="gpu",
            strategy=DDPStrategy(find_unused_parameters=self.oc.trainer.find_unused_parameters),
            # optimization
            max_epochs=self.oc.optim.max_epochs,
            check_val_every_n_epoch=self.oc.trainer.check_val_every_n_epoch,
            precision=self.oc.trainer.precision,
        )
        oc_trainer_kwargs = OmegaConf.to_object(self.oc.trainer)
        oc_trainer_kwargs = {key: val for key, val in oc_trainer_kwargs.items() if key in trainer_init_keys}
        ret.update(oc_trainer_kwargs)
        ret.update(kwargs)
        return ret
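
    # Typical use (sketch): construct the Trainer directly from these kwargs, e.g.
    #     trainer = Trainer(**pl_module.set_trainer_kwargs())
    #     trainer.fit(pl_module, datamodule=dm)
    # Any kwargs passed to `set_trainer_kwargs` override the defaults assembled above.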

    @classmethod
    def get_total_num_steps(
            cls,
            num_samples: int,
            total_batch_size: int,
            epoch: int = None):
        r"""
        Parameters
        ----------
        num_samples: int
            The number of samples in the dataset. `num_samples / total_batch_size` is the number of optimizer steps per epoch.
        total_batch_size: int
            `total_batch_size == micro_batch_size * world_size * grad_accum`
        """
        if epoch is None:
            epoch = cls.get_optim_config().max_epochs
        return int(epoch * num_samples / total_batch_size)
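
    # Worked example (illustrative numbers, not SEVIR-LR facts): with the default
    # optim config (total_batch_size=32, max_epochs=50) and 40_000 training samples,
    # get_total_num_steps(num_samples=40_000, total_batch_size=32) returns
    # int(50 * 40_000 / 32) = 62_500, the schedule horizon used in configure_optimizers.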

    @staticmethod
    def get_sevir_datamodule(dataset_cfg,
                             micro_batch_size: int = 1,
                             num_workers: int = 8):
        dm = SEVIRLightningDataModule(
            seq_len=dataset_cfg["seq_len"],
            sample_mode=dataset_cfg["sample_mode"],
            stride=dataset_cfg["stride"],
            batch_size=micro_batch_size,
            layout=dataset_cfg["layout"],
            output_type=np.float32,
            preprocess=True,
            rescale_method="01",
            verbose=False,
            aug_mode=dataset_cfg["aug_mode"],
            ret_contiguous=False,
            dataset_name=dataset_cfg["dataset_name"],
            start_date=dataset_cfg["start_date"],
            train_test_split_date=dataset_cfg["train_test_split_date"],
            end_date=dataset_cfg["end_date"],
            # `val_ratio` is not among the defaults in `get_dataset_config()`;
            # it is expected to be supplied via the cfg file.
            val_ratio=dataset_cfg["val_ratio"],
            num_workers=num_workers, )
        return dm

    def get_last_layer(self):
        return self.torch_nn_module.decoder.conv_out.weight

    def get_input(self, batch):
        # batch layout is "NTHWC" with T == 1; drop T and move channels first
        target_bchw = rearrange(batch, "b 1 h w c -> b c h w").contiguous()
        mask = None
        return target_bchw, mask

    def forward(self, target_bchw, sample_posterior=True):
        pred_bchw, posterior = self.torch_nn_module(
            sample=target_bchw,
            sample_posterior=sample_posterior,
            return_posterior=True)
        return pred_bchw, posterior

    def training_step(self, batch, batch_idx):
        g_opt, d_opt = self.optimizers()
        g_sch, d_sch = self.lr_schedulers()

        target_bchw, _ = self.get_input(batch=batch)
        pred_bchw, posterior = self(target_bchw)
        micro_batch_size = batch.shape[self.batch_axis]
        data_idx = int(batch_idx * micro_batch_size)
        if self.current_epoch % self.oc.trainer.check_val_every_n_epoch == 0 \
                and self.local_rank == 0:
            self.save_vis_step_end(
                data_idx=data_idx,
                target=target_bchw.detach().float().cpu().numpy(),
                pred=pred_bchw.detach().float().cpu().numpy(),
                mode="train")

        # optimize the autoencoder (generator)
        aeloss, log_dict_ae = self.loss(
            target_bchw, pred_bchw, posterior, optimizer_idx=0, global_step=self.global_step,
            mask=None, last_layer=self.get_last_layer(), split="train")
        aeloss /= self.accumulate_grad_batches
        self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=False)
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False, sync_dist=False)

        self.manual_backward(aeloss)
        if (batch_idx + 1) % self.accumulate_grad_batches == 0:
            self.clip_gradients(g_opt, gradient_clip_val=self.oc.optim.gradient_clip_val, gradient_clip_algorithm="norm")
            g_opt.step()
            g_sch.step()
            g_opt.zero_grad()

        # optimize the discriminator
        discloss, log_dict_disc = self.loss(
            target_bchw, pred_bchw, posterior, optimizer_idx=1, global_step=self.global_step,
            mask=None, last_layer=self.get_last_layer(), split="train")
        discloss /= self.accumulate_grad_batches
        self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=False)
        self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False, sync_dist=False)

        self.manual_backward(discloss)
        if (batch_idx + 1) % self.accumulate_grad_batches == 0:
            self.clip_gradients(d_opt, gradient_clip_val=self.oc.optim.gradient_clip_val, gradient_clip_algorithm="norm")
            d_opt.step()
            d_sch.step()
            d_opt.zero_grad()

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        micro_batch_size = batch.shape[self.batch_axis]
        data_idx = int(batch_idx * micro_batch_size)
        if not self.eval_example_only or data_idx in self.val_example_data_idx_list:
            target_bchw, _ = self.get_input(batch=batch)
            pred_bchw, posterior = self(target_bchw)
            target_bchw = target_bchw.contiguous()
            pred_bchw = pred_bchw.contiguous()
            if self.local_rank == 0:
                self.save_vis_step_end(
                    data_idx=data_idx,
                    target=target_bchw.detach().float().cpu().numpy(),
                    pred=pred_bchw.detach().float().cpu().numpy(),
                    mode="val", )
            aeloss, log_dict_ae = self.loss(target_bchw, pred_bchw, posterior, 0, self.global_step,
                                            mask=None, last_layer=self.get_last_layer(), split="val")
            discloss, log_dict_disc = self.loss(target_bchw, pred_bchw, posterior, 1, self.global_step,
                                                mask=None, last_layer=self.get_last_layer(), split="val")

            self.log("val/rec_loss", log_dict_ae["val/rec_loss"], prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
            self.valid_mse(pred_bchw, target_bchw)
            self.valid_mae(pred_bchw, target_bchw)

    def on_validation_epoch_end(self):
        valid_mse = self.valid_mse.compute()
        valid_mae = self.valid_mae.compute()
        self.log('valid_mse_epoch', valid_mse, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log('valid_mae_epoch', valid_mae, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.valid_mse.reset()
        self.valid_mae.reset()

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        micro_batch_size = batch.shape[self.batch_axis]
        data_idx = int(batch_idx * micro_batch_size)
        if not self.eval_example_only or data_idx in self.test_example_data_idx_list:
            target_bchw, _ = self.get_input(batch=batch)
            pred_bchw, posterior = self(target_bchw)
            target_bchw = target_bchw.contiguous()
            pred_bchw = pred_bchw.contiguous()
            if self.local_rank == 0:
                self.save_vis_step_end(
                    data_idx=data_idx,
                    target=target_bchw.detach().float().cpu().numpy(),
                    pred=pred_bchw.detach().float().cpu().numpy(),
                    mode="test", )
            aeloss, log_dict_ae = self.loss(target_bchw, pred_bchw, posterior, 0, self.global_step,
                                            mask=None, last_layer=self.get_last_layer(), split="test")
            discloss, log_dict_disc = self.loss(target_bchw, pred_bchw, posterior, 1, self.global_step,
                                                mask=None, last_layer=self.get_last_layer(), split="test")
            self.log("test/rec_loss", log_dict_ae["test/rec_loss"],
                     prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
            self.test_mse(pred_bchw, target_bchw)
            self.test_mae(pred_bchw, target_bchw)

    def on_test_epoch_end(self):
        test_mse = self.test_mse.compute()
        test_mae = self.test_mae.compute()
        self.log('test_mse_epoch', test_mse, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log('test_mae_epoch', test_mae, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
        self.test_mse.reset()
        self.test_mae.reset()

    def save_vis_step_end(
            self,
            data_idx: int,
            target: np.ndarray,
            pred: np.ndarray,
            mode: str = "train",
            prefix: str = ""):
        r"""
        Parameters
        ----------
        data_idx
        target, pred: np.ndarray
            Shape = (N, C, H, W), actually (T, 1, H, W)
        mode: str
        prefix: str
        """
        if self.local_rank == 0:
            if mode == "train":
                example_data_idx_list = self.train_example_data_idx_list
            elif mode == "val":
                example_data_idx_list = self.val_example_data_idx_list
            elif mode == "test":
                example_data_idx_list = self.test_example_data_idx_list
            else:
                raise ValueError(f"Wrong mode {mode}! Must be in ['train', 'val', 'test'].")
            if data_idx in example_data_idx_list:
                save_name = f"{prefix}{mode}_epoch_{self.current_epoch}_data_{data_idx}.png"
                num_vis = min(target.shape[0], self.oc.eval.num_vis)
                seq_list = [
                    target[:num_vis].squeeze(1),
                    pred[:num_vis].squeeze(1),
                ]
                label_list = [
                    "Target",
                    f"{self.oc.logging.logging_prefix}",
                ]
                vis_sevir_seq(
                    save_path=os.path.join(self.example_save_dir, save_name),
                    seq=seq_list,
                    label=label_list,
                    plot_stride=1, fs=20, label_rotation=90)

    def on_before_optimizer_step(self, optimizer):
        # optionally track gradient norms (disabled when track_grad_norm == -1)
        if self.oc.logging.track_grad_norm != -1:
            norms = grad_norm(self.torch_nn_module, norm_type=self.oc.logging.track_grad_norm)
            self.log_dict(norms)
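

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module): one plausible way to
# wire the pieces above into a training entry point. The cfg-file path, the
# `dm.setup()` call signature, the `num_train_samples` attribute on the
# datamodule, and the single-process batch-size arithmetic are assumptions
# for illustration; adapt them to the actual SEVIRLightningDataModule API
# and your experiment layout.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    oc_file = "cfg_vae_sevirlr.yaml"  # hypothetical config path
    oc = OmegaConf.load(oc_file)
    seed_everything(oc.optim.seed, workers=True)
    torch.set_float32_matmul_precision(oc.optim.float32_matmul_precision)

    dm = VAESEVIRPLModule.get_sevir_datamodule(
        dataset_cfg=OmegaConf.to_object(oc.dataset),
        micro_batch_size=oc.optim.micro_batch_size,
        num_workers=8)
    dm.prepare_data()
    dm.setup()

    total_num_steps = VAESEVIRPLModule.get_total_num_steps(
        num_samples=dm.num_train_samples,  # assumed attribute; use the real dataset-size query
        total_batch_size=oc.optim.total_batch_size,
        epoch=oc.optim.max_epochs)
    # Assuming a single process here; with DDP, divide by the world size as well.
    accumulate_grad_batches = oc.optim.total_batch_size // oc.optim.micro_batch_size

    pl_module = VAESEVIRPLModule(
        total_num_steps=total_num_steps,
        accumulate_grad_batches=accumulate_grad_batches,
        oc_file=oc_file,
        save_dir="vae_sevirlr")
    trainer = Trainer(**pl_module.set_trainer_kwargs())
    trainer.fit(model=pl_module, datamodule=dm)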