import argparse
import gc
import os
import random
import sys
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torchvision
from tqdm.auto import tqdm

sns.set_style('darkgrid')
torchvision.disable_beta_transforms_warning()

if torch.cuda.is_available():
    torch.set_float32_matmul_precision('high')

project_root = str(Path(__file__).resolve().parents[2])
if project_root not in sys.path:
    sys.path.append(project_root)

from torchvision import datasets, transforms

from data.custom_datasets import ImageNet
from tasks.image_classification.imagenet_classes import IMAGENET2012_CLASSES
from tasks.image_classification.plotting import plot_neural_dynamics, make_classification_gif
from models.ctm import ContinuousThoughtMachine
from models.lstm import LSTMBaseline
from models.ff import FFBaseline
from utils.housekeeping import set_seed, zip_python_code
from utils.losses import image_classification_loss, EnergyContrastiveLoss
from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup

from accelerate import Accelerator
from huggingface_hub import upload_folder
from autoclip.torch import QuantileClip

warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable") |
|
|
warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning) |
|
|
warnings.filterwarnings( |
|
|
"ignore", |
|
|
"Corrupt EXIF data", |
|
|
UserWarning, |
|
|
r"^PIL\.TiffImagePlugin$" |
|
|
) |
|
|
warnings.filterwarnings( |
|
|
"ignore", |
|
|
"UserWarning: Metadata Warning", |
|
|
UserWarning, |
|
|
r"^PIL\.TiffImagePlugin$" |
|
|
) |
|
|
warnings.filterwarnings( |
|
|
"ignore", |
|
|
"UserWarning: Truncated File Read", |
|
|
UserWarning, |
|
|
r"^PIL\.TiffImagePlugin$" |
|
|
) |
|
|
|
|
|
|


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--model', type=str, default='ctm', choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
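
    # Architecture arguments shared across model types.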
    parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
    parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
    parser.add_argument('--backbone_type', type=str, default='resnet18-4', help='Type of backbone feature extractor.')
    parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
    parser.add_argument('--heads', type=int, default=4, help='Number of attention heads (CTM, LSTM).')
    parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
    parser.add_argument('--positional_embedding_type', type=str, default='none', help='Type of positional embedding (CTM, LSTM).',
                        choices=['none',
                                 'learnable-fourier',
                                 'multi-learnable-fourier',
                                 'custom-rotational'])
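
    # CTM-specific arguments.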
    parser.add_argument('--synapse_depth', type=int, default=4, help='Depth of the U-NET synapse model. 1 = linear, no U-NET (CTM only).')
    parser.add_argument('--n_synch_out', type=int, default=512, help='Number of neurons to use for output synch (CTM only).')
    parser.add_argument('--n_synch_action', type=int, default=512, help='Number of neurons to use for observation/action synch (CTM only).')
    parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting the neuron subset (CTM only).')
    parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
    parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for the NLMs (CTM only).')
    parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep memory (CTM only).')
    parser.add_argument('--memory_hidden_dims', type=int, default=4, help='Hidden dimensions of the memory when using deep memory (CTM only).')
    parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for the NLMs specifically. Leave unset to match the dropout used in the rest of the model (CTM only).')
    parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalisation in the NLMs (CTM only).')
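
    # Energy head and energy-based loss arguments.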
    parser.add_argument('--energy_head_enabled', action=argparse.BooleanOptionalAction, default=False, help='Enable energy head.')
    parser.add_argument('--energy_hidden_dim', type=int, default=64, help='Hidden dim for energy head.')
    parser.add_argument('--loss_type', type=str, default='standard', choices=['standard', 'energy_contrastive'], help='Loss type.')
    parser.add_argument('--energy_margin', type=float, default=10.0, help='Margin for energy loss.')
    parser.add_argument('--energy_scale', type=float, default=0.1, help='Scale for energy loss.')
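
    # LSTM baseline arguments.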
    parser.add_argument('--num_layers', type=int, default=2, help='Number of stacked LSTM layers (LSTM only).')
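
    # Training arguments.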
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training.')
    parser.add_argument('--batch_size_test', type=int, default=32, help='Batch size for testing.')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the model.')
    parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
    parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
    parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
    parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
    parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
    parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
    parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='Parameter name substrings to exclude from weight decay (typically: bn, ln, bias, start).')
    parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient norm clipping value (-1 to disable).')
    parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components (backbone, and synapses if CTM).')
    parser.add_argument('--num_workers_train', type=int, default=1, help='Number of dataloader workers for training.')
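
    # Housekeeping, logging, and evaluation arguments.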
    parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
    parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset to use.')
    parser.add_argument('--data_root', type=str, default='data/', help='Where to save the dataset.')
    parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
    parser.add_argument('--seed', type=int, default=412, help='Random seed.')
    parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model weights from disk.')
    parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Use strict state_dict loading when reloading model weights.')
    parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
    parser.add_argument('--n_test_batches', type=int, default=20, help='Number of minibatches used to approximate metrics (-1 for a full pass).')
    parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
    parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
    parser.add_argument('--reload', type=str, default=None, help='Reload the checkpoint saved in log_dir (set to any non-empty value to enable).')
    parser.add_argument('--wandb', action=argparse.BooleanOptionalAction, default=False, help='Log to WandB.')
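
    # Hugging Face Hub arguments.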
    parser.add_argument('--push_to_hub', action=argparse.BooleanOptionalAction, default=False, help='Push model to HF Hub.')
    parser.add_argument('--hub_model_id', type=str, default=None, help='HF Hub model ID (e.g., username/repo).')
    parser.add_argument('--hub_token', type=str, default=None, help='HF Hub token.')
    parser.add_argument('--hub_private', action=argparse.BooleanOptionalAction, default=False, help='Make HF Hub repo private.')

    args = parser.parse_args()
    return args
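

# Dataset construction: returns the train/test datasets, class labels, and the
# normalisation statistics later used to un-normalise images for visualisation.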
def get_dataset(dataset, root):
    if dataset == 'imagenet':
        dataset_mean = [0.485, 0.456, 0.406]
        dataset_std = [0.229, 0.224, 0.225]

        normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize])

        class_labels = list(IMAGENET2012_CLASSES.values())

        train_data = ImageNet(which_split='train', transform=train_transform)
        test_data = ImageNet(which_split='validation', transform=test_transform)
    elif dataset == 'cifar10':
        dataset_mean = [0.49139968, 0.48215827, 0.44653124]
        dataset_std = [0.24703233, 0.24348505, 0.26158768]
        normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
        train_transform = transforms.Compose([
            transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
            transforms.ToTensor(),
            normalize,
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        train_data = datasets.CIFAR10(root, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR10(root, train=False, transform=test_transform, download=True)
        class_labels = ['air', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    elif dataset == 'cifar100':
        dataset_mean = [0.5070751592371341, 0.48654887331495067, 0.4409178433670344]
        dataset_std = [0.2673342858792403, 0.2564384629170882, 0.27615047132568393]
        normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)

        train_transform = transforms.Compose([
            transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
            transforms.ToTensor(),
            normalize,
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        train_data = datasets.CIFAR100(root, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR100(root, train=False, transform=test_transform, download=True)
        idx_order = np.argsort(np.array(list(train_data.class_to_idx.values())))
        class_labels = list(np.array(list(train_data.class_to_idx.keys()))[idx_order])
    else:
        raise NotImplementedError

    return train_data, test_data, class_labels, dataset_mean, dataset_std


if __name__ == '__main__':

    args = parse_args()

    set_seed(args.seed)

    accelerator = Accelerator(log_with="wandb" if args.wandb else None)
    device = accelerator.device
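
    # Create the log directory and (optionally) initialise WandB tracking on the main process only.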
    if accelerator.is_main_process:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        print(f"Logging to {args.log_dir}")
        if args.wandb:
            accelerator.init_trackers(
                project_name="continuous-thought-machines",
                config=vars(args),
                init_kwargs={"wandb": {"name": args.log_dir.split('/')[-1]}}
            )

    assert args.dataset in ['cifar10', 'cifar100', 'imagenet']

    train_data, test_data, class_labels, dataset_mean, dataset_std = get_dataset(args.dataset, args.data_root)

    num_workers_test = 1
    trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers_train)
    testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test, drop_last=False)

    prediction_reshaper = [-1]
    args.out_dims = len(class_labels)
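
    # Snapshot the code and arguments alongside the logs for reproducibility.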
    zip_python_code(f'{args.log_dir}/repo_state.zip')
    with open(f'{args.log_dir}/args.txt', 'w') as f:
        print(args, file=f)

    print(f'Running model {args.model} on {device}')

    model = None
    if args.model == 'ctm':
        model = ContinuousThoughtMachine(
            iterations=args.iterations,
            d_model=args.d_model,
            d_input=args.d_input,
            heads=args.heads,
            n_synch_out=args.n_synch_out,
            n_synch_action=args.n_synch_action,
            synapse_depth=args.synapse_depth,
            memory_length=args.memory_length,
            deep_nlms=args.deep_memory,
            memory_hidden_dims=args.memory_hidden_dims,
            do_layernorm_nlm=args.do_normalisation,
            backbone_type=args.backbone_type,
            positional_embedding_type=args.positional_embedding_type,
            out_dims=args.out_dims,
            prediction_reshaper=prediction_reshaper,
            dropout=args.dropout,
            dropout_nlm=args.dropout_nlm,
            neuron_select_type=args.neuron_select_type,
            n_random_pairing_self=args.n_random_pairing_self,
            energy_head_enabled=args.energy_head_enabled,
            energy_hidden_dim=args.energy_hidden_dim,
        ).to(device)
    elif args.model == 'lstm':
        model = LSTMBaseline(
            d_model=args.d_model,
            d_input=args.d_input,
            num_layers=args.num_layers,
            out_dims=args.out_dims,
            dropout=args.dropout,
        ).to(device)
    elif args.model == 'ff':
        model = FFBaseline(
            d_model=args.d_model,
            d_input=args.d_input,
            out_dims=args.out_dims,
            dropout=args.dropout,
        ).to(device)
    else:
        raise NotImplementedError

    model.train()
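
    # One dry-run forward pass so that any lazily-built modules are instantiated
    # before parameters are counted and the optimizer is constructed.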
    pseudo_inputs = train_data[0][0].unsqueeze(0).to(device)
    model(pseudo_inputs)
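
    # Split parameters into weight-decay and no-weight-decay groups based on
    # substring matches against --weight_decay_exclusion_list.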
    print(f'Total params: {sum(p.numel() for p in model.parameters())}')
    decay_params = []
    no_decay_params = []
    no_decay_names = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
            no_decay_params.append(param)
            no_decay_names.append(name)
        else:
            decay_params.append(param)
    if len(no_decay_names):
        print(f'WARNING, excluding: {no_decay_names}')

    if len(no_decay_names) and args.weight_decay != 0:
        optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay': args.weight_decay},
                                       {'params': no_decay_params, 'weight_decay': 0}],
                                      lr=args.lr,
                                      eps=1e-8 if not args.use_amp else 1e-6)
    else:
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=args.lr,
                                      eps=1e-8 if not args.use_amp else 1e-6,
                                      weight_decay=args.weight_decay)
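
    # Default to a pure warmup schedule; replace it with warmup + multistep or
    # warmup + cosine annealing when --use_scheduler is set.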
    warmup_schedule = warmup(args.warmup_steps)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
    if args.use_scheduler:
        if args.scheduler_type == 'multistep':
            scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
        elif args.scheduler_type == 'cosine':
            scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
        else:
            raise NotImplementedError

    model, optimizer, trainloader, testloader, scheduler = accelerator.prepare(
        model, optimizer, trainloader, testloader, scheduler
    )

    start_iter = 0
    train_losses = []
    test_losses = []
    train_accuracies = []
    test_accuracies = []
    iters = []

    train_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
    test_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
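
    # Optionally resume from the checkpoint stored in log_dir, restoring model weights
    # and (unless --reload_model_only) the optimizer, scheduler, metric history, and RNG states.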
    if args.reload:
        checkpoint_path = f'{args.log_dir}/checkpoint.pt'
        if os.path.isfile(checkpoint_path):
            print(f'Reloading from: {checkpoint_path}')
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            if not args.strict_reload:
                print('WARNING: not using strict reload for model weights!')
            load_result = accelerator.unwrap_model(model).load_state_dict(checkpoint['model_state_dict'], strict=args.strict_reload)
            print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")

            if not args.reload_model_only:
                print('Reloading optimizer etc.')
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

                start_iter = checkpoint['iteration']

                train_losses = checkpoint['train_losses']
                test_losses = checkpoint['test_losses']
                train_accuracies = checkpoint['train_accuracies']
                test_accuracies = checkpoint['test_accuracies']
                iters = checkpoint['iters']

                if args.model in ['ctm', 'lstm']:
                    train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
                    test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
            else:
                print('Only reloading model!')

            if 'torch_rng_state' in checkpoint:
                torch.set_rng_state(checkpoint['torch_rng_state'].cpu().byte())
                np.random.set_state(checkpoint['numpy_rng_state'])
                random.setstate(checkpoint['random_rng_state'])

            del checkpoint
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
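
    # Optionally compile the heaviest submodules (the backbone, plus the synapses for the CTM).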
    if args.do_compile:
        print('Compiling...')
        if hasattr(model, 'backbone'):
            model.backbone = torch.compile(model.backbone, mode='reduce-overhead', fullgraph=True)
        if args.model == 'ctm':
            model.synapses = torch.compile(model.synapses, mode='reduce-overhead', fullgraph=True)
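
    # Training is iteration-based rather than epoch-based: cycle through the dataloader
    # indefinitely and count optimizer steps.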
    iterator = iter(trainloader)

    with tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True) as pbar:
        for bi in range(start_iter, args.training_iterations):
            current_lr = optimizer.param_groups[-1]['lr']

            try:
                inputs, targets = next(iterator)
            except StopIteration:
                iterator = iter(trainloader)
                inputs, targets = next(iterator)

            loss = None
            accuracy = None

            if args.do_compile:
                torch.compiler.cudagraph_mark_step_begin()
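
            # Forward pass and loss. The CTM and LSTM variants produce per-tick predictions
            # and certainties; accuracy is reported at the most certain internal tick.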
            if args.model == 'ctm':
                if args.energy_head_enabled:
                    predictions, certainties, energies = model(inputs)
                    if args.loss_type == 'energy_contrastive':
                        criterion = EnergyContrastiveLoss(margin=args.energy_margin, energy_scale=args.energy_scale)
                        loss, stats = criterion(predictions, energies, targets)

                        where_most_certain = certainties[:, 1].argmax(-1)
                        accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device), where_most_certain] == targets).float().mean().item()
                        pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Avg Energy={stats["avg_energy"]:0.3f}'
                    else:
                        loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
                        accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device), where_most_certain] == targets).float().mean().item()
                        pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}'
                else:
                    predictions, certainties, synchronisation = model(inputs)
                    loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
                    accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device), where_most_certain] == targets).float().mean().item()
                    pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'

            elif args.model == 'lstm':
                predictions, certainties, synchronisation = model(inputs)
                loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
                accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device), where_most_certain] == targets).float().mean().item()
                pbar_desc = f'LSTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'

            elif args.model == 'ff':
                predictions = model(inputs)
                loss = nn.CrossEntropyLoss()(predictions, targets)
                accuracy = (predictions.argmax(1) == targets).float().mean().item()
                pbar_desc = f'FF Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}'

            accelerator.backward(loss)

            if args.gradient_clipping > 0:
                accelerator.clip_grad_norm_(model.parameters(), args.gradient_clipping)

            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            pbar.set_description(f'Dataset={args.dataset}. Model={args.model}. {pbar_desc}')
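
            # Periodically evaluate on (a subset of) the train and test sets, refresh the
            # plots, and render CTM/LSTM visualisations.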
            if (bi % args.track_every == 0 or bi == args.warmup_steps) and (bi != 0 or args.reload_model_only):

                iters.append(bi)
                current_train_losses = []
                current_test_losses = []
                current_train_accuracies = []
                current_test_accuracies = []
                current_train_accuracies_most_certain = []
                current_test_accuracies_most_certain = []
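
                # Reset BatchNorm running statistics so that the train-mode metric passes
                # below re-estimate them before the test-set evaluation.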
                pbar.set_description('Resetting BN')
                model.train()
                for module in model.modules():
                    if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
                        module.reset_running_stats()

                pbar.set_description('Tracking: Computing TRAIN metrics')
                with torch.no_grad():

                    all_targets_list = []
                    all_predictions_list = []
                    all_predictions_most_certain_list = []
                    all_losses = []

                    with tqdm(total=len(trainloader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
                        for inferi, (inputs, targets) in enumerate(trainloader):

                            all_targets_list.append(targets.detach().cpu().numpy())

                            if args.model == 'ctm':
                                these_predictions, certainties, _ = model(inputs)
                                loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
                                all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
                                all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())

                            elif args.model == 'lstm':
                                these_predictions, certainties, _ = model(inputs)
                                loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
                                all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
                                all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())

                            elif args.model == 'ff':
                                these_predictions = model(inputs)
                                loss = nn.CrossEntropyLoss()(these_predictions, targets)
                                all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())

                            all_losses.append(loss.item())

                            if args.n_test_batches != -1 and inferi >= args.n_test_batches - 1: break
                            pbar_inner.set_description(f'Computing metrics for train (Batch {inferi+1})')
                            pbar_inner.update(1)

                    all_targets = np.concatenate(all_targets_list)
                    all_predictions = np.concatenate(all_predictions_list)
                    train_losses.append(np.mean(all_losses))

                    if args.model in ['ctm', 'lstm']:
                        current_train_accuracies = np.mean(all_predictions == all_targets[..., np.newaxis], axis=0)
                        train_accuracies.append(current_train_accuracies)

                        all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
                        current_train_accuracies_most_certain = (all_targets == all_predictions_most_certain).mean()
                        train_accuracies_most_certain.append(current_train_accuracies_most_certain)
                    else:
                        current_train_accuracies = (all_targets == all_predictions).mean()
                        train_accuracies.append(current_train_accuracies)

                    del these_predictions
                model.eval()
                pbar.set_description('Tracking: Computing TEST metrics')
                with torch.inference_mode():

                    all_targets_list = []
                    all_predictions_list = []
                    all_predictions_most_certain_list = []
                    all_losses = []

                    with tqdm(total=len(testloader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
                        for inferi, (inputs, targets) in enumerate(testloader):

                            all_targets_list.append(targets.detach().cpu().numpy())

                            if args.model == 'ctm':
                                these_predictions, certainties, _ = model(inputs)
                                loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
                                all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
                                all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())

                            elif args.model == 'lstm':
                                these_predictions, certainties, _ = model(inputs)
                                loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
                                all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
                                all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())

                            elif args.model == 'ff':
                                these_predictions = model(inputs)
                                loss = nn.CrossEntropyLoss()(these_predictions, targets)
                                all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())

                            all_losses.append(loss.item())

                            if args.n_test_batches != -1 and inferi >= args.n_test_batches - 1: break
                            pbar_inner.set_description(f'Computing metrics for test (Batch {inferi+1})')
                            pbar_inner.update(1)

                    all_targets = np.concatenate(all_targets_list)
                    all_predictions = np.concatenate(all_predictions_list)
                    test_losses.append(np.mean(all_losses))

                    if args.model in ['ctm', 'lstm']:
                        current_test_accuracies = np.mean(all_predictions == all_targets[..., np.newaxis], axis=0)
                        test_accuracies.append(current_test_accuracies)
                        all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
                        current_test_accuracies_most_certain = (all_targets == all_predictions_most_certain).mean()
                        test_accuracies_most_certain.append(current_test_accuracies_most_certain)
                    else:
                        current_test_accuracies = (all_targets == all_predictions).mean()
                        test_accuracies.append(current_test_accuracies)
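
                # Refresh the accuracy and loss plots saved in the log directory.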
                figacc = plt.figure(figsize=(10, 10))
                axacc_train = figacc.add_subplot(211)
                axacc_test = figacc.add_subplot(212)
                cm = sns.color_palette("viridis", as_cmap=True)

                if args.model in ['ctm', 'lstm']:
                    train_acc_arr = np.array(train_accuracies)
                    test_acc_arr = np.array(test_accuracies)
                    num_ticks = train_acc_arr.shape[1]
                    for ti in range(num_ticks):
                        axacc_train.plot(iters, train_acc_arr[:, ti], color=cm(ti / num_ticks), alpha=0.3)
                        axacc_test.plot(iters, test_acc_arr[:, ti], color=cm(ti / num_ticks), alpha=0.3)

                    axacc_train.plot(iters, train_accuracies_most_certain, 'k--', alpha=0.7, label='Most certain')
                    axacc_test.plot(iters, test_accuracies_most_certain, 'k--', alpha=0.7, label='Most certain')
                else:
                    axacc_train.plot(iters, train_accuracies, 'k-', alpha=0.7, label='Accuracy')
                    axacc_test.plot(iters, test_accuracies, 'k-', alpha=0.7, label='Accuracy')

                axacc_train.set_title('Train Accuracy')
                axacc_test.set_title('Test Accuracy')
                axacc_train.legend(loc='lower right')
                axacc_test.legend(loc='lower right')
                axacc_train.set_xlim([0, args.training_iterations])
                axacc_test.set_xlim([0, args.training_iterations])
                if args.dataset == 'cifar10':
                    axacc_train.set_ylim([0.75, 1])
                    axacc_test.set_ylim([0.75, 1])

                figacc.tight_layout()
                figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
                plt.close(figacc)

                figloss = plt.figure(figsize=(10, 5))
                axloss = figloss.add_subplot(111)
                axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train: {train_losses[-1]:.4f}')
                axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test: {test_losses[-1]:.4f}')
                axloss.legend(loc='upper right')
                axloss.set_xlim([0, args.training_iterations])
                axloss.set_ylim(bottom=0)

                figloss.tight_layout()
                figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
                plt.close(figloss)
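
                # For CTM/LSTM, also render a neural-dynamics plot and an attention GIF for
                # one test image; visualisation failures should not interrupt training.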
                if args.model in ['ctm', 'lstm']:
                    try:
                        inputs_viz, targets_viz = next(iter(testloader))

                        pbar.set_description('Tracking: Processing test data for viz')
                        predictions_viz, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model(inputs_viz, track=True)

                        att_shape = (accelerator.unwrap_model(model).kv_features.shape[2], accelerator.unwrap_model(model).kv_features.shape[3])
                        attention_tracking_viz = attention_tracking_viz.reshape(
                            attention_tracking_viz.shape[0],
                            attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])

                        pbar.set_description('Tracking: Neural dynamics plot')
                        plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)

                        imgi = 0
                        img_to_gif = np.moveaxis(np.clip(inputs_viz[imgi].detach().cpu().numpy()*np.array(dataset_std).reshape(len(dataset_std), 1, 1) + np.array(dataset_mean).reshape(len(dataset_mean), 1, 1), 0, 1), 0, -1)

                        pbar.set_description('Tracking: Producing attention gif')
                        make_classification_gif(img_to_gif,
                                                targets_viz[imgi].item(),
                                                predictions_viz[imgi].detach().cpu().numpy(),
                                                certainties_viz[imgi].detach().cpu().numpy(),
                                                post_activations_viz[:, imgi],
                                                attention_tracking_viz[:, imgi],
                                                class_labels,
                                                f'{args.log_dir}/{imgi}_attention.gif',
                                                )
                        del predictions_viz, certainties_viz, pre_activations_viz, post_activations_viz, attention_tracking_viz
                    except Exception as e:
                        print(f"Visualization failed for model {args.model}: {e}")

                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                model.train()
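
            # Save a full training checkpoint (and optionally push the log directory,
            # excluding *.pt files, to the Hugging Face Hub).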
            if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
                if accelerator.is_main_process:
                    pbar.set_description('Saving model checkpoint...')
                    checkpoint_data = {
                        'model_state_dict': accelerator.unwrap_model(model).state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'iteration': bi,
                        'train_losses': train_losses,
                        'test_losses': test_losses,
                        'train_accuracies': train_accuracies,
                        'test_accuracies': test_accuracies,
                        'iters': iters,
                        'args': args,
                        'torch_rng_state': torch.get_rng_state(),
                        'numpy_rng_state': np.random.get_state(),
                        'random_rng_state': random.getstate(),
                    }

                    if args.model in ['ctm', 'lstm']:
                        checkpoint_data['train_accuracies_most_certain'] = train_accuracies_most_certain
                        checkpoint_data['test_accuracies_most_certain'] = test_accuracies_most_certain

                    accelerator.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')

                    if args.push_to_hub and args.hub_model_id:
                        if bi % (args.save_every * 5) == 0:
                            try:
                                upload_folder(
                                    folder_path=args.log_dir,
                                    repo_id=args.hub_model_id,
                                    token=args.hub_token,
                                    commit_message=f"Training checkpoint {bi}",
                                    ignore_patterns=["*.pt"],
                                )
                            except Exception as e:
                                print(f"Failed to upload to hub: {e}")

            pbar.update(1)