| """
|
| contains functions for training and testing a pytorch model
|
| """
|
| import torch
|
|
|
| from tqdm.auto import tqdm
|
| from typing import Dict, List, Tuple
|
|
|
|
|
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    """Trains a PyTorch model for a single epoch.

    Turns a target model to training mode then runs through all of the
    required training steps (forward pass, loss calculation, optimizer step).

    Args:
        model: PyTorch model to be trained.
        dataloader: DataLoader instance for the model to be trained on.
        loss_fn: PyTorch loss function to calculate loss.
        optimizer: PyTorch optimizer to help minimize the loss function.
        device: target device to compute on (e.g. "cuda" or "cpu").

    Returns:
        A tuple of training loss and training accuracy metrics
        in the form (train_loss, train_accuracy). The loss is a 0-dim
        tensor detached from the autograd graph; accuracy is a float.
    """
    model.train()

    train_loss, train_accuracy = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        logits = model(X)

        loss = loss_fn(logits, y)
        # Detach before accumulating: `train_loss += loss` would keep every
        # batch's computation graph alive for the whole epoch (memory leak).
        train_loss += loss.detach()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Per-batch accuracy: fraction of correct class predictions.
        # (softmax is monotonic, so argmax over raw logits would be equivalent.)
        preds = torch.softmax(logits, dim=-1).argmax(dim=-1)
        train_accuracy += (preds == y).sum().item() / len(preds)

    # Average the accumulated metrics over the number of batches.
    train_loss /= len(dataloader)
    train_accuracy /= len(dataloader)
    return train_loss, train_accuracy
|
|
|
| def test_step(model: torch.nn.Module,
|
| dataloader: torch.utils.data.DataLoader,
|
| loss_fn: torch.nn.Module,
|
| device: torch.device) -> Tuple[float, float]:
|
| """Tests a pytorch model for a single epoch
|
|
|
| Turns a target model to eval mode and then performs a forward pass on a testing
|
| dataset.
|
|
|
| Args:
|
| model: pytorch model
|
| dataloader: dataloader insatnce for the model to be tested on
|
| loss_fn: loss function to calculate loss (errors)
|
| device: target device to compute on
|
|
|
| returns:
|
| A tuple of testing loss and testing accuracy metrics.
|
| In the form (test_loss, test_accuracy)
|
| """
|
|
|
| model.eval()
|
|
|
|
|
| test_loss, test_accuracy = 0, 0
|
|
|
|
|
| with torch.inference_mode():
|
|
|
| for X, y in dataloader:
|
|
|
| X, y = X.to(device), y.to(device)
|
|
|
|
|
| logits = model(X)
|
|
|
|
|
| loss = loss_fn(logits, y)
|
| test_loss += loss.item()
|
|
|
|
|
| test_preds = torch.softmax(logits, dim=-1).argmax(dim=-1)
|
| test_accuracy += ((test_preds == y).sum().item()/len(test_preds))
|
|
|
| test_loss /= len(dataloader)
|
| test_accuracy /= len(dataloader)
|
| return test_loss, test_accuracy
|
|
|
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device,
          # Quoted annotation: evaluating `torch.utils.tensorboard.writer.SummaryWriter`
          # at def time raises AttributeError unless the tensorboard submodule has
          # already been imported somewhere. `= None` keeps the writer optional,
          # matching the `if writer:` guard below, and is backward compatible.
          writer: "torch.utils.tensorboard.writer.SummaryWriter" = None) -> Dict[str, List]:
    """Trains and tests a PyTorch model.

    Passes a target model through train_step() and test_step() functions
    for a number of epochs, training and testing the model in the same
    epoch loop.

    Calculates, prints and stores evaluation metrics throughout, and
    optionally logs them to TensorBoard.

    Args:
        model: PyTorch model to be trained and tested.
        train_dataloader: DataLoader instance for the model to be trained on.
        test_dataloader: DataLoader instance for the model to be tested on.
        optimizer: PyTorch optimizer.
        loss_fn: PyTorch loss function.
        epochs: integer indicating how many epochs to train for.
        device: target device to compute on.
        writer: optional TensorBoard SummaryWriter; when provided, loss and
            accuracy curves are logged each epoch and the model graph once.

    Returns:
        A dictionary of training and testing loss as well as training and
        testing accuracy metrics. Each metric has a value in a list for
        each epoch.

        In the form: {train_loss: [...],
                      train_acc: [...],
                      test_loss: [...],
                      test_acc: [...]}
    """
    results = {
        "train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": []
    }

    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer,
                                           device=device)
        test_loss, test_acc = test_step(model=model,
                                        dataloader=test_dataloader,
                                        loss_fn=loss_fn,
                                        device=device)

        # (Original gated this on `epoch % 1 == 0`, which is always true.)
        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )

        # float() handles both a 0-dim tensor and a plain float coming back
        # from the step functions.
        results["train_loss"].append(float(train_loss))
        results["train_acc"].append(train_acc)
        results["test_loss"].append(float(test_loss))
        results["test_acc"].append(test_acc)

        if writer:
            writer.add_scalars(main_tag="Loss",
                               tag_scalar_dict={"train loss": train_loss,
                                                "test loss": test_loss},
                               global_step=epoch)
            writer.add_scalars(main_tag="Accuracy",
                               tag_scalar_dict={"train acc": train_acc,
                                                "test acc": test_acc},
                               global_step=epoch)
            # The graph is identical every epoch, so log it only once.
            # NOTE(review): the hard-coded (32, 3, 224, 224) dummy batch assumes
            # an ImageNet-style CNN input — confirm it matches the model, or
            # derive the shape from a real batch instead.
            if epoch == 0:
                writer.add_graph(model=model,
                                 input_to_model=torch.randn(size=(32, 3, 224, 224)).to(device))

    # Close the writer only after ALL epochs have been logged. The original
    # closed it inside the loop, i.e. after the first epoch, and then kept
    # writing to a closed writer.
    if writer:
        writer.close()

    return results
|
|
|