| import torch |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import torch.nn.functional as F |
|
|
| from matplotlib.colors import ListedColormap, BoundaryNorm |
| from matplotlib.lines import Line2D |
| import matplotlib.animation as animation |
| import scienceplots |
|
|
def resize(seq, size):
    """Bilinearly resample ``seq`` to spatial ``size`` and clip values to [0, 1].

    Axis 2 is squeezed before interpolation and re-inserted afterwards, so the
    result has a singleton axis at index 2.

    Args:
        seq: Tensor whose axis 2 has length 1 — presumably laid out as
            (B, T, 1, H, W); TODO confirm against callers.
        size: Target spatial size, forwarded to ``F.interpolate``.

    Returns:
        The resampled tensor, clamped to [0, 1], with axis 2 restored.
    """
    # Bilinear interpolation works on 4-D input, hence the temporary squeeze.
    frames = seq.squeeze(dim=2)
    resampled = F.interpolate(
        frames,
        size=size,
        mode='bilinear',
        align_corners=False,
    )
    clipped = resampled.clamp(0, 1)
    return clipped.unsqueeze(2)
|
|
| |
| |
| |
def to_cpu_tensor(*args):
    """Convert each argument to a CPU ``torch.Tensor`` where possible.

    NumPy arrays (including subclasses) are wrapped with ``torch.Tensor(...)``;
    tensors (including subclasses such as ``nn.Parameter``) are moved to the
    CPU with ``.cpu()``. Any other object is passed through unchanged.

    Args:
        *args: Arbitrary number of arrays / tensors.

    Returns:
        The single converted item when exactly one argument is given,
        otherwise a list of converted items in argument order (an empty list
        when called without arguments).
    """
    out = []
    for item in args:
        # isinstance (rather than an exact type check) so that subclasses of
        # ndarray / Tensor are converted as well instead of silently skipped.
        if isinstance(item, np.ndarray):
            item = torch.Tensor(item)
        if isinstance(item, torch.Tensor):
            item = item.cpu()
        out.append(item)

    if len(out) == 1:
        return out[0]
    return out
|
|
| from tempfile import NamedTemporaryFile |
|
|
# Apply the 'science' and 'no-latex' matplotlib styles globally at import time
# (presumably registered by the `scienceplots` import above — confirm).
plt.style.use(['science', 'no-latex'])

# RGB triples (components in [0, 1]) used to build the ListedColormap for the
# VIL plots in this module; one entry per line, ordered from lowest to highest
# level.
VIL_COLORS = [
    [0, 0, 0],
    [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
    [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
    [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
    [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
    [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
    [0.9607843137254902, 0.9607843137254902, 0.0],
    [0.9294117647058824, 0.6745098039215687, 0.0],
    [0.9411764705882353, 0.43137254901960786, 0.0],
    [0.6274509803921569, 0.0, 0.0],
    [0.9058823529411765, 0.0, 1.0],
]

# Bin boundaries handed to BoundaryNorm. Frames are multiplied by 255 before
# plotting, so these are on a 0-255 scale.
VIL_LEVELS = [
    0.0, 16.0, 31.0, 59.0, 74.0, 100.0,
    133.0, 160.0, 181.0, 219.0, 255.0,
]
|
|
# Visualization function with a colorbar.
# NOTE(review): the original note also mentioned a line separating input and
# output frames; the code below does not draw one.
def gradio_visualize(sequence):
    """Plot a sequence of frames side by side and save the figure as a PNG.

    Each frame is squeezed, multiplied by 255 and drawn with the VIL colormap
    (``VIL_COLORS`` / ``VIL_LEVELS``). Panel ``i`` is labelled ``t-k`` with
    ``k = len(sequence) - i - 1``, so the last panel is ``t-0``.

    Args:
        sequence: Sized iterable of frames (numpy arrays or tensors that
            matplotlib can plot). Each frame presumably has shape (C, H, W)
            with C == 1 and values in [0, 1] — TODO confirm against callers.

    Returns:
        Path of the temporary ``.png`` file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        ValueError: If ``sequence`` is empty.
    """
    n_frames = len(sequence)
    if n_frames == 0:
        raise ValueError("gradio_visualize requires at least one frame")

    # Build the colormap / norm once instead of once per panel.
    cmap = ListedColormap(VIL_COLORS)
    norm = BoundaryNorm(VIL_LEVELS, cmap.N)

    fig_size = 3
    # squeeze=False always yields a 2-D array of Axes, so a single-frame input
    # no longer fails when the axes are indexed.
    fig, axes = plt.subplots(
        1,
        n_frames,
        figsize=(fig_size * n_frames, fig_size),
        tight_layout=True,
        squeeze=False,
    )
    axes = axes.ravel()

    try:
        plt.subplots_adjust(hspace=0.0, wspace=0.0)
        plt.setp(axes, xticks=[], yticks=[])

        im = None
        for i, frame in enumerate(sequence):
            axes[i].set_xlabel(f'$t-{(n_frames - i) - 1}$', fontsize=12)
            frame = frame.squeeze()
            im = axes[i].imshow(frame * 255, cmap=cmap, norm=norm)

        # The colorbar is driven by the last image; every panel shares the
        # same cmap/norm, so it is valid for all of them.
        # NOTE(review): the colorbar axes start at x=1.0 in figure coordinates,
        # i.e. at the right edge of the canvas — check it is not clipped in the
        # saved PNG.
        cax = fig.add_axes([1, 0.05, 0.02, 0.5])
        fig.colorbar(im, cax=cax)

        # Reserve a unique path, close the handle, then write by name so the
        # file is not opened twice at once.
        with NamedTemporaryFile(suffix=".png", delete=False) as ff:
            file_path = ff.name
        fig.savefig(file_path)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

    return file_path
|
|
|
|
def gradio_gif(sequences, T):
    """Animate one or more frame sequences side by side and save them as a GIF.

    One panel is drawn per entry of ``sequences``, labelled with its key. For
    every time step ``t`` in ``range(T)`` the panel shows
    ``sequence[t].squeeze() * 255`` with the VIL colormap, and the figure title
    is set to ``t + <t>``.

    Args:
        sequences: Mapping ``label -> sequence``; ``.items()`` / ``.values()``
            are called on it, so a plain list is not supported. Each sequence
            is indexed by time step along its first axis and must hold at
            least ``T`` frames. Frames presumably have shape (C, H, W) with
            C == 1 and values in [0, 1] — TODO confirm against callers.
            NOTE(review): earlier documentation mentioned a leading batch
            axis; the code indexes the first axis as time.
        T: Number of animation frames (passed to ``FuncAnimation`` as
            ``frames``).

    Returns:
        Path of the temporary ``.gif`` file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        ValueError: If ``sequences`` is empty.
    """
    horizontal = len(sequences)
    if horizontal == 0:
        raise ValueError("gradio_gif requires at least one sequence")

    # Build the colormap / norm once; they are identical for every panel.
    cmap = ListedColormap(VIL_COLORS)
    norm = BoundaryNorm(VIL_LEVELS, cmap.N)

    fig_size = 3
    # squeeze=False always yields an array of Axes, so the single-panel and
    # multi-panel cases share one code path.
    fig, axes = plt.subplots(
        nrows=1,
        ncols=horizontal,
        figsize=(fig_size * horizontal, fig_size),
        tight_layout=True,
        squeeze=False,
    )
    axes = axes.ravel()

    try:
        plt.subplots_adjust(hspace=0.0, wspace=0.0)
        plt.setp(axes, xticks=[], yticks=[])

        # Draw the first frame of every sequence and keep the image artists so
        # the animation can update them in place.
        images = []
        for ax, (key, sequence) in zip(axes, sequences.items()):
            ax.set_xlabel(f'{key}', fontsize=12)
            frame = sequence[0].squeeze()
            images.append(
                ax.imshow(frame * 255, cmap=cmap, norm=norm, animated=True)
            )

        title = fig.suptitle('', y=0.9, x=0.505, fontsize=16)

        def animate(t):
            # Update the existing images instead of calling imshow again, which
            # would stack a new AxesImage on each axis for every frame.
            # Assumes all frames of a sequence share the same shape.
            for image, sequence in zip(images, sequences.values()):
                image.set_data(sequence[t].squeeze() * 255)

            title.set_text(f'$t + {t}$')

            # blit=True needs a sequence of artists; returning the figure keeps
            # the title and all panels part of every rendered frame.
            return fig,

        ani = animation.FuncAnimation(
            fig, animate, frames=T, interval=750, blit=True, repeat_delay=50,
        )

        # Reserve a unique path, close the handle, then let the writer open the
        # file by name.
        with NamedTemporaryFile(suffix=".gif", delete=False) as ff:
            file_path = ff.name
        ani.save(file_path, writer='pillow', fps=5)
    finally:
        # Close this specific figure (not just the current one), even on error.
        plt.close(fig)

    return file_path