import torch
import torch.nn.functional as F
import re
import os

def compute_decay(T, params, clamp_lims=(0, 15)):
"""
This function computes exponential decays for learnable synchronisation
interactions between pairs of neurons.
"""
    assert isinstance(clamp_lims, tuple), 'Clamp lims should be a tuple'
    assert len(clamp_lims) == 2, 'Clamp lims should be length 2'
indices = torch.arange(T-1, -1, -1, device=params.device).reshape(T, 1).expand(T, params.shape[0])
out = torch.exp(-indices * torch.clamp(params, clamp_lims[0], clamp_lims[1]).unsqueeze(0))
return out
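
# Illustrative example (shapes only): for a history of length T and N learnable
# decay parameters, compute_decay(T, params) returns a (T, N) tensor whose last
# row is all ones (exponent zero) and whose earlier rows decay exponentially, e.g.
#   decays = compute_decay(5, torch.rand(16))  # -> shape (5, 16)
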
def add_coord_dim(x, scaled=True):
"""
Adds a final dimension to the tensor representing 2D coordinates.
Args:
tensor: A PyTorch tensor of shape (B, D, H, W).
Returns:
A PyTorch tensor of shape (B, D, H, W, 2) with the last dimension
representing the 2D coordinates within the HW dimensions.
"""
B, H, W = x.shape
# Create coordinate grids
x_coords = torch.arange(W, device=x.device, dtype=x.dtype).repeat(H, 1) # Shape (H, W)
y_coords = torch.arange(H, device=x.device, dtype=x.dtype).unsqueeze(-1).repeat(1, W) # Shape (H, W)
if scaled:
x_coords /= (W-1)
y_coords /= (H-1)
# Stack coordinates and expand dimensions
coords = torch.stack((x_coords, y_coords), dim=-1) # Shape (H, W, 2)
    coords = coords.unsqueeze(0)  # Shape (1, H, W, 2)
    coords = coords.repeat(B, 1, 1, 1)  # Shape (B, H, W, 2)
return coords
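
# Illustrative example: given a feature map of shape (B, H, W), the returned
# coordinate tensor has shape (B, H, W, 2), with x/y coordinates scaled to
# [0, 1] when scaled=True, e.g.
#   coords = add_coord_dim(torch.zeros(4, 32, 32))  # -> shape (4, 32, 32, 2)
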
def compute_normalized_entropy(logits, reduction='mean'):
"""
Calculates the normalized entropy of a PyTorch tensor of logits along the
final dimension.
Args:
logits: A PyTorch tensor of logits.
Returns:
A PyTorch tensor containing the normalized entropy values.
"""
# Apply softmax to get probabilities
preds = F.softmax(logits, dim=-1)
# Calculate the log probabilities
log_preds = torch.log_softmax(logits, dim=-1)
# Calculate the entropy
entropy = -torch.sum(preds * log_preds, dim=-1)
# Calculate the maximum possible entropy
num_classes = preds.shape[-1]
max_entropy = torch.log(torch.tensor(num_classes, dtype=torch.float32))
# Normalize the entropy
normalized_entropy = entropy / max_entropy
if len(logits.shape)>2 and reduction == 'mean':
normalized_entropy = normalized_entropy.flatten(1).mean(-1)
return normalized_entropy
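
# Illustrative example: entropy is computed over the final (class) dimension and
# divided by log(num_classes), so values lie in [0, 1], e.g.
#   h = compute_normalized_entropy(torch.randn(8, 10))  # -> shape (8,)
# With higher-dimensional logits and reduction='mean', the result is averaged
# over all non-batch dimensions, giving one value per batch element.
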
def reshape_predictions(predictions, prediction_reshaper):
B, T = predictions.size(0), predictions.size(-1)
new_shape = [B] + prediction_reshaper + [T]
    reshaped_predictions = predictions.reshape(new_shape)
    return reshaped_predictions
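
# Illustrative example: predictions of shape (B, H*W*C, T) with
# prediction_reshaper=[H, W, C] are reshaped to (B, H, W, C, T).
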
def get_all_log_dirs(root_dir):
folders = []
for dirpath, dirnames, filenames in os.walk(root_dir):
if any(f.endswith(".pt") for f in filenames):
folders.append(dirpath)
return folders

def get_latest_checkpoint(log_dir):
files = [f for f in os.listdir(log_dir) if re.match(r'checkpoint_\d+\.pt', f)]
return os.path.join(log_dir, max(files, key=lambda f: int(re.search(r'\d+', f).group()))) if files else None
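
# Illustrative example (hypothetical paths): with checkpoint_100.pt and
# checkpoint_200.pt in log_dir, get_latest_checkpoint(log_dir) returns the path
# to checkpoint_200.pt; it returns None if no checkpoint files are present.
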
def get_latest_checkpoint_file(filepath, limit=300000):
checkpoint_files = get_checkpoint_files(filepath)
checkpoint_files = [
f for f in checkpoint_files if int(re.search(r'checkpoint_(\d+)\.pt', f).group(1)) <= limit
]
if not checkpoint_files:
return None
return checkpoint_files[-1]
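
# Illustrative example (hypothetical paths): with checkpoints at iterations
# 100000, 250000 and 400000 and the default limit of 300000,
# get_latest_checkpoint_file(filepath) returns the checkpoint_250000.pt path,
# i.e. the latest checkpoint whose iteration does not exceed the limit.
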
def get_checkpoint_files(filepath):
regex = r'checkpoint_(\d+)\.pt'
files = [f for f in os.listdir(filepath) if re.match(regex, f)]
files = sorted(files, key=lambda f: int(re.search(regex, f).group(1)))
return [os.path.join(filepath, f) for f in files]

def load_checkpoint(checkpoint_path, device):
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
return checkpoint

def get_model_args_from_checkpoint(checkpoint):
if "args" in checkpoint:
        return checkpoint["args"]
else:
raise ValueError("Checkpoint does not contain saved args.")

def get_accuracy_and_loss_from_checkpoint(checkpoint, device="cpu"):
training_iteration = checkpoint.get('training_iteration', 0)
train_losses = checkpoint.get('train_losses', [])
test_losses = checkpoint.get('test_losses', [])
train_accuracies = checkpoint.get('train_accuracies_most_certain', [])
test_accuracies = checkpoint.get('test_accuracies_most_certain', [])
return training_iteration, train_losses, test_losses, train_accuracies, test_accuracies
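
# The block below is a minimal usage sketch tying the checkpoint helpers together:
# it walks a (hypothetical) "logs" root directory, loads the most recent
# checkpoint_<n>.pt in each run directory, and prints the recorded training
# iteration and latest test accuracy. The "logs" path is an assumption; point it
# at your own output directory.
if __name__ == "__main__":
    root_dir = "logs"  # assumed location of training runs; adjust as needed
    for log_dir in get_all_log_dirs(root_dir):
        checkpoint_path = get_latest_checkpoint(log_dir)
        if checkpoint_path is None:
            continue
        checkpoint = load_checkpoint(checkpoint_path, device="cpu")
        iteration, _, _, _, test_accuracies = get_accuracy_and_loss_from_checkpoint(checkpoint)
        latest_test_accuracy = test_accuracies[-1] if len(test_accuracies) else "n/a"
        print(f"{checkpoint_path}: iteration={iteration}, test accuracy={latest_test_accuracy}")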