Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # coding: utf-8 | |
| import torch | |
| from torch import nn | |
| import numpy as np | |
| import plotly.express as px | |
| from plotly.subplots import make_subplots | |
| from trainer import LitTrainer | |
| from models import CNN | |
| from dataset import DatasetMNIST, load_mnist | |
| def load_pl_net(path="../checkpoints/lightning_logs/version_26/checkpoints/epoch=9-step=1000.ckpt"): | |
| pl_net = LitTrainer.load_from_checkpoint(path, model=CNN(1, 10)) | |
| return pl_net | |
| def load_torch_net(path="../checkpoints/pytorch/version_0.pt"): | |
| net = torch.load(path) | |
| net.eval() | |
| return net | |
| def get_sequence(model): | |
| fig = make_subplots(rows=2, cols=5) | |
| i, j = 0, np.random.randint(0, 30000) | |
| while i < 10: | |
| x, y = dataset[j] | |
| y_pred = model(x.to("cuda")).detach().cpu() | |
| p = torch.max(nn.functional.softmax(y_pred, dim=0)) | |
| y_pred = int(np.argmax(y_pred)) | |
| if y_pred == i and p > 0.95: | |
| img = np.flip(np.array(x.reshape(28, 28)), 0) | |
| fig.add_trace(px.imshow(img).data[0], row=int(i/5)+1, col=i%5+1) | |
| i += 1 | |
| j += 1 | |
| return fig | |
| if __name__ == "__main__": | |
| mnist = load_mnist("../downloads/mnist/") | |
| dataset, test_data = DatasetMNIST(*mnist["train"]), DatasetMNIST(*mnist["test"]) | |
| print("PyTorch Lightning Network") | |
| get_sequence(load_pl_net().to("cuda")).write_image("images/pl_net.png") | |
| print("Manual Network") | |
| get_sequence(load_torch_net().to("cuda")).write_image("images/pytorch_net.png") | |