| | import os |
| | import torch |
| | import numpy as np |
| | import lpips as lp |
| | import pandas as pd |
| | import torchmetrics |
| | import matplotlib.pyplot as plt |
| | from bisect import bisect_right |
| | import torchvision.transforms as T |
| | from torch import nn |
| |
|
| | from matplotlib.colors import ListedColormap, BoundaryNorm |
| | from matplotlib.lines import Line2D |
| |
|
| | from data import dutils |
| |
|
| | |
| | |
| | |
| |
|
class SequentialLR(torch.optim.lr_scheduler._LRScheduler):
    """Receives the list of schedulers that is expected to be called sequentially during
    optimization process and milestone points that provides exact intervals to reflect
    which scheduler is supposed to be called at a given epoch.

    Scheduler ``i`` is active while ``milestones[i-1] <= epoch < milestones[i]``. On the
    step that reaches a milestone, the newly active scheduler is stepped with ``step(0)``
    (restarting it at epoch 0). ``ReduceLROnPlateau`` schedulers are the exception: they
    always receive the metric passed to :meth:`step` as ``ref``.

    ``super().__init__`` is not called, so inherited helpers that rely on state it would
    set up (e.g. ``get_last_lr``) may not work on this class.

    Args:
        optimizer: wrapped optimizer (stored as ``self.optimizer``).
        schedulers (list): List of chained schedulers; all must share one optimizer.
        milestones (list): sorted epochs at which to switch to the next scheduler;
            must contain exactly ``len(schedulers) - 1`` entries.
        last_epoch (int): index of the last epoch; the internal counter starts at
            ``last_epoch + 1``.
        verbose (bool): unused; kept for signature compatibility.

    Raises:
        ValueError: if the schedulers do not share an optimizer, or the number of
            milestones is not ``len(schedulers) - 1``.

    Example:
        >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
        >>> # scheduler1 drives epochs 0-1; at epoch 2 scheduler2 takes over and is
        >>> # first stepped with step(0), then with step() on later epochs.
        >>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
        for scheduler_idx in range(1, len(schedulers)):
            if schedulers[scheduler_idx].optimizer != schedulers[0].optimizer:
                raise ValueError(
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                )
        if len(milestones) != len(schedulers) - 1:
            raise ValueError(
                "Sequential Schedulers expects number of schedulers provided to be one more "
                "than the number of milestone points, but got number of schedulers {} and the "
                "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
            )
        self.optimizer = optimizer
        self._schedulers = schedulers
        self._milestones = milestones
        # With the default last_epoch=-1 the counter starts at 0.
        self.last_epoch = last_epoch + 1

    def step(self, ref=None):
        """Advance one epoch and step the scheduler that is active for it.

        Args:
            ref: metric forwarded to a ``ReduceLROnPlateau`` scheduler. It must be given
                whenever such a scheduler is active, including on its milestone step.
                It is ignored by other scheduler types.
        """
        self.last_epoch += 1
        idx = bisect_right(self._milestones, self.last_epoch)
        scheduler = self._schedulers[idx]
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # ReduceLROnPlateau.step takes a metric, not an epoch. Calling step(0) at a
            # milestone would record 0 as a genuine observation and corrupt its `best`.
            scheduler.step(ref)
        elif idx > 0 and self._milestones[idx - 1] == self.last_epoch:
            # Milestone reached: restart the newly active scheduler at epoch 0.
            scheduler.step(0)
        else:
            scheduler.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer. The wrapped schedulers' states are stored, in order,
        under ``'_schedulers'``.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        state_dict['_schedulers'] = [s.state_dict() for s in self._schedulers]
        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`. It is temporarily modified but
                restored before returning.
        """
        # Pop the nested states so they do not overwrite self._schedulers via __dict__.
        _schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Put the entry back so the caller's dict is left unchanged.
        state_dict['_schedulers'] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)
| |
|
def warmup_lambda(warmup_steps, min_lr_ratio=0.1):
    '''
    Build a multiplicative LR factor (e.g. for ``LambdaLR``) with linear warmup.

    The returned function maps ``epoch`` to a factor that rises linearly from
    ``min_lr_ratio`` at epoch 0 to 1.0 at ``epoch == warmup_steps``, and stays 1.0
    afterwards. A non-positive ``warmup_steps`` means no warmup (always 1.0), which
    avoids a division by zero.
    '''
    def ret_lambda(epoch):
        if warmup_steps <= 0 or epoch > warmup_steps:
            return 1.0
        return min_lr_ratio + (1.0 - min_lr_ratio) * epoch / warmup_steps
    return ret_lambda
| |
|
| | |
| | |
| | |
def to_cpu_tensor(*args):
    '''
    Convert each argument to a CPU ``torch.Tensor`` where possible.

    numpy arrays (including subclasses) are converted with ``torch.Tensor(...)``, which
    yields the default float dtype. Tensors (including subclasses such as
    ``nn.Parameter``) are moved with ``.cpu()``. Anything else is passed through
    unchanged.

    Returns the single converted value when called with exactly one argument,
    otherwise a list of converted values.
    '''
    out = []
    for item in args:
        if isinstance(item, np.ndarray):
            item = torch.Tensor(item)
        if isinstance(item, torch.Tensor):
            item = item.cpu()
        out.append(item)
    return out[0] if len(out) == 1 else out
| |
|
def merge_leading_dims(tensor, n=2):
    '''
    Merge the first ``n`` dimensions of a tensor/array into one.

    E.g. shape (B, T, C, H, W) with n=2 becomes (B*T, C, H, W). Works for any object
    exposing ``.shape`` and ``.reshape`` (torch tensors, numpy arrays).
    '''
    # Use the explicit leading size rather than -1: reshape cannot infer -1 when a
    # trailing dimension is 0, e.g. shape (2, 3, 0).
    leading = int(np.prod(tensor.shape[:n]))
    return tensor.reshape((leading, *tensor.shape[n:]))
| |
|
| | |
| | |
| | |
def build_model_name(model_type, model_config):
    '''
    Compose a model name (no file extension) from a type and a config mapping.

    Each config entry adds one ``_``-separated segment:
    ``key-value`` for scalars, ``key-v1-v2-...`` for lists/tuples, and just ``key``
    for booleans. Boolean items inside a list/tuple render as empty strings
    between their dashes.
    '''
    def render(value):
        return '' if type(value) is bool else str(value)

    segments = [model_type]
    for key, value in model_config.items():
        if type(value) in (list, tuple):
            segments.append(key + '-' + '-'.join(render(item) for item in value))
        elif type(value) is bool:
            segments.append(key)
        else:
            segments.append(key + '-' + str(value))
    return '_'.join(segments)
| |
|
def build_model_path(base_dir, dataset_type, model_type, timestamp=None):
    '''
    Return ``base_dir/dataset_type/model_type[/timestamp]``.

    ``timestamp=None`` omits the last component. A value equal to ``True`` inserts the
    current local time formatted as ``%Y%m%d%H%M%S``. Any other value is joined as-is.
    '''
    parts = [base_dir, dataset_type, model_type]
    if timestamp is not None:
        # `== True` (not `is True`), so values such as 1 also select the timestamp.
        if timestamp == True:  # noqa: E712
            timestamp = pd.Timestamp.now().strftime('%Y%m%d%H%M%S')
        parts.append(timestamp)
    return os.path.join(*parts)
| |
|
| | |
| | |
| | |
| |
|
def hko7_preprocess(x_seq, x_mask, dt_clip, args):
    '''
    Prepare an HKO-7 batch: rescale, convert intensity scale, resize and split in time.

    x_seq: array whose first two axes are swapped and which is divided by 255; after
        the swap it must be (B, T, C, H, W) with C == 1. Presumably uint8 laid out
        (T, B, C, H, W) — TODO confirm.
    x_mask: unused here.
    dt_clip: forwarded to the ``dutils`` scale-conversion helpers.
    args: namespace-like object supporting ``in``. Optional attributes:
        ``resize`` (default: last axis of ``x_seq``), ``seq_len`` (default 5) and
        ``scale`` (``'non-linear'`` selects ``linear_to_nonlinear_batched``; otherwise
        ``nonlinear_to_linear_batched`` is applied).

    Returns ``(x, y)``: the first ``seq_len`` frames and the remaining frames, as float
    tensors of shape (B, *, C, resize, resize). Resizing goes through PIL, and the final
    reshape assumes square output frames — NOTE(review).
    '''
    target_size = args.resize if 'resize' in args else x_seq.shape[-1]
    n_input = args.seq_len if 'seq_len' in args else 5

    frames = x_seq.transpose((1, 0, 2, 3, 4)) / 255.
    use_nonlinear = 'scale' in args and args.scale == 'non-linear'
    convert = dutils.linear_to_nonlinear_batched if use_nonlinear else dutils.nonlinear_to_linear_batched
    frames = convert(frames, dt_clip)

    b, t, c, h, w = frames.shape
    assert c == 1, f'# channels ({c}) != 1'

    # Flatten batch and time so each frame can be transformed on its own.
    flat = torch.Tensor(frames).float().reshape((b * t, c, h, w))
    steps = [T.ToPILImage(), T.Resize(target_size), T.ToTensor()] if target_size != h else []
    tform = T.Compose(steps)

    resized = torch.stack([tform(frame) for frame in flat], dim=0)
    resized = resized.reshape((b, t, c, target_size, target_size))
    return resized[:, :n_input], resized[:, n_input:]
| |
|
| | |
| | |
| | |
| |
|
def mae(*args):
    """Mean absolute error via ``F.l1_loss``, returned as a detached CPU numpy value."""
    loss = torch.nn.functional.l1_loss(*args)
    return loss.cpu().detach().numpy()


def mse(*args):
    """Mean squared error via ``F.mse_loss``, returned as a detached CPU numpy value."""
    loss = torch.nn.functional.mse_loss(*args)
    return loss.cpu().detach().numpy()
| |
|
def ssim(y_pred, y):
    '''
    Structural similarity of ``y_pred`` against ``y`` over all frames.

    ``y`` must be 5-D (B, T, C, H, W). Both inputs are moved to CPU tensors, flattened
    to (B*T, C, H, W) using ``y``'s dimensions, and clamped to [0, 1]. They are then
    scored with torchmetrics' SSIM at ``data_range=1.0``.
    '''
    y, y_pred = to_cpu_tensor(y, y_pred)
    b, t, c, h, w = y.shape
    frame_shape = (b * t, c, h, w)
    target = y.reshape(frame_shape).clamp(0, 1)
    prediction = y_pred.reshape(frame_shape).clamp(0, 1)
    metric = torchmetrics.image.StructuralSimilarityIndexMeasure(data_range=1.0)
    return metric(prediction, target)
| |
|
def psnr(y_pred, y):
    '''
    Mean per-frame PSNR of ``y_pred`` against ``y`` (``data_range=1.0``).

    ``y`` must be 5-D (B, T, C, H, W). Both inputs are moved to CPU tensors and
    flattened to (B*T, C, H, W) using ``y``'s dimensions. PSNR is computed per frame
    and averaged. Returns 0 when there are no frames.
    '''
    y, y_pred = to_cpu_tensor(y, y_pred)
    b, t, c, h, w = y.shape
    n_frames = b * t
    y = y.reshape((n_frames, c, h, w))
    y_pred = y_pred.reshape((n_frames, c, h, w))
    # Build the metric once instead of once per frame: calling a torchmetrics Metric
    # returns the value for the inputs of that call, so one instance can be reused.
    metric = torchmetrics.image.PeakSignalNoiseRatio(data_range=1.0)
    acc_score = 0
    for frame_pred, frame in zip(y_pred, y):
        acc_score += metric(frame_pred, frame) / n_frames
    return acc_score
| |
|
GLOBAL_LPIPS_OBJ = None
# Backbone name GLOBAL_LPIPS_OBJ was built with (None until the first call).
GLOBAL_LPIPS_NET = None


def lpips64(y_pred, y, net='vgg'):
    '''
    Mean LPIPS distance between ``y_pred`` and ``y`` after resizing frames to 64x64.

    The first two dimensions of each input are merged (presumably (B, T, C, H, W) ->
    (B*T, C, H, W) — TODO confirm). Frames are then bicubically resized to 64x64,
    clamped to [0, 1] and mapped to [-1, 1] before scoring with ``lp.LPIPS``.

    The LPIPS model is cached in GLOBAL_LPIPS_OBJ. It is rebuilt when ``net`` differs
    from the cached backbone, and moved to ``y``'s device on every call.
    '''
    y = merge_leading_dims(y)
    y_pred = merge_leading_dims(y_pred)

    y = torch.nn.functional.interpolate(y, (64, 64), mode='bicubic').clamp(0, 1)
    y_pred = torch.nn.functional.interpolate(y_pred, (64, 64), mode='bicubic').clamp(0, 1)
    # Map [0, 1] to [-1, 1], the input range LPIPS expects.
    y = 2 * y - 1
    y_pred = 2 * y_pred - 1

    global GLOBAL_LPIPS_OBJ, GLOBAL_LPIPS_NET
    if GLOBAL_LPIPS_OBJ is None or GLOBAL_LPIPS_NET != net:
        GLOBAL_LPIPS_OBJ = lp.LPIPS(net=net)
        GLOBAL_LPIPS_NET = net
    # Cheap when already on the right device; keeps the cached model usable when a
    # later call passes tensors on a different device.
    GLOBAL_LPIPS_OBJ = GLOBAL_LPIPS_OBJ.to(y.device)
    return GLOBAL_LPIPS_OBJ(y_pred, y).mean()
| |
|
def tfpn(y_pred, y, threshold, radius=1):
    '''
    Binary confusion counts ``(tp, tn, fp, fn)`` of ``y_pred`` against ``y``.

    The first two dimensions of both inputs are merged. If ``radius > 1``, both maps
    are first max-pooled with ``nn.MaxPool2d(radius)``, so a pooled cell is positive
    when any pixel in it reaches ``threshold``. Values ``>= threshold`` count as
    positive. Only the resulting counts are moved to the CPU; the inputs are not.
    '''
    y = merge_leading_dims(y)
    y_pred = merge_leading_dims(y_pred)
    with torch.no_grad():
        if radius > 1:
            pool = nn.MaxPool2d(radius)
            y = pool(y)
            y_pred = pool(y_pred)
        y = torch.where(y >= threshold, 1, 0)
        y_pred = torch.where(y_pred >= threshold, 1, 0)
        # Both maps are already integer 0/1 labels, so torchmetrics' probability
        # threshold has no effect. Pass a fixed 0.5 instead of the pixel-value
        # `threshold`, which torchmetrics may reject when it is not a float in [0, 1].
        mat = torchmetrics.functional.confusion_matrix(y_pred, y, task='binary', threshold=0.5)
    (tn, fp), (fn, tp) = to_cpu_tensor(mat)
    return tp, tn, fp, fn
| |
|
def tfpn_pool(y_pred, y, threshold, radius):
    '''
    Binary confusion counts ``(tp, tn, fp, fn)`` after thresholding, then max-pooling.

    The first two dimensions of both inputs are merged. Each map is binarised at
    ``threshold`` (``>=`` is positive) and then max-pooled with kernel ``radius`` and
    stride ``radius // 4`` (or ``radius`` when that is 0), so pooling windows may
    overlap. Only the resulting counts are moved to the CPU.
    '''
    y_pred = merge_leading_dims(y_pred)
    y = merge_leading_dims(y)
    stride = radius // 4 if radius // 4 > 0 else radius
    pool = nn.MaxPool2d(radius, stride=stride)
    with torch.no_grad():
        y = pool(torch.where(y >= threshold, 1, 0).float())
        y_pred = pool(torch.where(y_pred >= threshold, 1, 0).float())
        # The pooled maps are already 0/1 floats, so decide positives at 0.5. Reusing the
        # pixel-value `threshold` here would mark every cell negative when it is >= 1,
        # every cell positive when it is < 0, and may be rejected by torchmetrics when
        # it is not a float in [0, 1].
        mat = torchmetrics.functional.confusion_matrix(y_pred, y, task='binary', threshold=0.5)
    (tn, fp), (fn, tp) = to_cpu_tensor(mat)
    return tp, tn, fp, fn
| |
|
def csi(tp, tn, fp, fn):
    '''Critical Success Index, tp / (tp + fn + fp). The larger the better.

    Returns 0.0 when the denominator is below 1e-7. ``tn`` is accepted for signature
    parity with the other scores but unused.
    '''
    denominator = tp + fn + fp
    return 0. if denominator < 1e-7 else tp / denominator
| |
|
def hss(tp, tn, fp, fn):
    '''Heidke Skill Score. (-inf, 1]. Larger better.

    HSS = 2 (tp*tn - fp*fn) / ((tp+fn)(fn+tn) + (tp+fp)(fp+tn)); returns 0.0 when the
    denominator is 0.
    '''
    denominator = (tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)
    if denominator == 0:
        return 0.
    return 2 * (tp * tn - fp * fn) / denominator
| |
|
| | |
| | |
| | |
| |
|
def torch_visualize(sequences, savedir=None, horizontal=10, vmin=0, vmax=1):
    '''
    Plot frame sequences as a grid of grayscale images.

    input: sequences, a list/dict of numpy/torch arrays with shape (B, T, C, H, W)
    C is assumed to be 1 and squeezed
    If batch > 1, only the first sequence will be printed

    Each sequence fills ceil(T / horizontal) grid rows, and its first row is labelled
    with its dict key (or ``Item i`` for list input). Images use ``cmap='gray'`` with
    the given ``vmin``/``vmax``. If ``savedir`` is set, the figure is saved there
    (``.png`` is appended when missing) and closed; otherwise it is shown.
    '''
    if isinstance(sequences, dict):
        display_texts = list(sequences.keys())
        sequences = list(sequences.values())
    else:
        sequences = list(sequences)
        display_texts = [f'Item {i+1}' for i in range(len(sequences))]
    # Convert one sequence at a time: to_cpu_tensor(*seqs) returns a bare tensor (not a
    # list) for a single sequence, which would then be iterated over its batch axis.
    sequences = [to_cpu_tensor(seq) for seq in sequences]

    vertical = sum(int(np.ceil(seq.shape[1] / horizontal)) for seq in sequences)
    # squeeze=False keeps `axes` 2-D, so axes[row, col] also works for one row/column.
    _, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical),
                           tight_layout=True, squeeze=False)
    plt.setp(axes, xticks=[], yticks=[])

    row = 0
    for label, sequence in zip(display_texts, sequences):
        frames = sequence[0]  # first batch element only
        if isinstance(frames, torch.Tensor):
            frames = frames.detach()  # tensors requiring grad cannot be converted by imshow
        # Drop singleton axes (e.g. C == 1) but keep the time axis even when T == 1.
        frames = frames.reshape([frames.shape[0]] + [d for d in frames.shape[1:] if d != 1])
        axes[row, 0].set_ylabel(label)
        for i, frame in enumerate(frames):
            axes[row + i // horizontal, i % horizontal].imshow(frame, vmin=vmin, vmax=vmax, cmap='gray')
        row += int(np.ceil(frames.shape[0] / horizontal))

    if savedir:
        # Append '.png' only when it is missing.
        plt.savefig(savedir if savedir.endswith('.png') else savedir + '.png')
        plt.close()
    else:
        plt.show()
| |
|
def color_visualize(sequences, savedir='', horizontal=5, skip=1, ypos=0):
    '''
    Visualize sequences with the VIL colour map, a colorbar, and a dashed line
    separating the input (first) sequence from the outputs.

    input: sequences, a list/dict of numpy/torch arrays with shape (B, T, C, H, W)
    C is assumed to be 1 and squeezed
    If batch > 1, only the first sequence will be printed

    Frames are multiplied by 255 and coloured with VIL_COLORS over the VIL_LEVELS
    boundaries, so inputs are presumably scaled to [0, 1] (TODO confirm). The first
    sequence is labelled ``t-k`` along the top and the last one ``t+(skip*i+1)`` along
    the bottom. ``ypos`` is the separator height in figure coordinates (0 = automatic
    guess). If ``savedir`` is non-empty, the figure is saved to that path and closed;
    otherwise it is shown.
    '''
    # The 'science' / 'no-latex' styles presumably come from the SciencePlots package;
    # plt.style.use changes the global matplotlib style for the rest of the process.
    plt.style.use(['science', 'no-latex'])
    VIL_COLORS = [[0, 0, 0],
                  [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
                  [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
                  [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
                  [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
                  [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
                  [0.9607843137254902, 0.9607843137254902, 0.0],
                  [0.9294117647058824, 0.6745098039215687, 0.0],
                  [0.9411764705882353, 0.43137254901960786, 0.0],
                  [0.6274509803921569, 0.0, 0.0],
                  [0.9058823529411765, 0.0, 1.0]]

    VIL_LEVELS = [0.0, 16.0, 31.0, 59.0, 74.0, 100.0, 133.0, 160.0, 181.0, 219.0, 255.0]
    # Built once and shared by every frame, so all images use the same colour scale.
    cmap = ListedColormap(VIL_COLORS)
    norm = BoundaryNorm(VIL_LEVELS, cmap.N)

    if isinstance(sequences, dict):
        display_texts = list(sequences.keys())
        sequences = list(sequences.values())
    else:
        sequences = list(sequences)
        display_texts = [f'Item {i+1}' for i in range(len(sequences))]
    # Convert one sequence at a time: to_cpu_tensor(*seqs) returns a bare tensor (not a
    # list) for a single sequence, which would then be iterated over its batch axis.
    sequences = [to_cpu_tensor(seq) for seq in sequences]

    vertical = sum(int(np.ceil(seq.shape[1] / horizontal)) for seq in sequences)
    # squeeze=False keeps `axes` 2-D, so axes[row, col] also works for one row/column.
    fig, axes = plt.subplots(vertical, horizontal, figsize=(2*horizontal, 2*vertical),
                             tight_layout=True, squeeze=False)
    plt.subplots_adjust(hspace=0.0, wspace=0.0)
    plt.setp(axes, xticks=[], yticks=[])

    im = None
    row = 0
    last = len(sequences) - 1
    for k, (label, sequence) in enumerate(zip(display_texts, sequences)):
        frames = sequence[0]  # first batch element only
        if isinstance(frames, torch.Tensor):
            frames = frames.detach()  # tensors requiring grad cannot be converted by imshow
        # Drop singleton axes (e.g. C == 1) but keep the time axis even when T == 1.
        frames = frames.reshape([frames.shape[0]] + [d for d in frames.shape[1:] if d != 1])
        n_frames = frames.shape[0]
        axes[row, 0].set_ylabel(label, fontsize=16)
        for i in range(n_frames):
            # Labels use the same wrap-around placement as the images, so sequences
            # longer than `horizontal` are labelled on the correct row.
            ax = axes[row + i // horizontal, i % horizontal]
            if k == 0:
                ax.set_xlabel(f'$t-{(n_frames-i)-1}$', fontsize=16)
                ax.xaxis.set_label_position('top')
            elif k == last:
                ax.set_xlabel(f'$t+{skip*i+1}$', fontsize=16)
                ax.xaxis.set_label_position('bottom')
            im = ax.imshow(frames[i] * 255, cmap=cmap, norm=norm)
        row += int(np.ceil(n_frames / horizontal))

    # Dashed horizontal separator spanning the full figure width.
    if ypos == 0:
        ypos = 1 - 1 / len(sequences) - 0.017
    fig.add_artist(Line2D((0, 1), (ypos, ypos), transform=fig.transFigure,
                          ls='--', linewidth=2, color='#444'))

    if im is not None:
        cax = fig.add_axes([1, 0.05, 0.02, 0.5])
        fig.colorbar(im, cax=cax)

    if savedir:
        plt.savefig(savedir)
        plt.close()
    else:
        plt.show()
| |
|