"""Thought Depth via Energy Minimization: halting with a learned Energy scalar."""
import torch
import torch.nn as nn
import torch.nn.functional as F
def compute_ctc_loss(predictions, targets, blank_label=0):
"""
Computes the Connectionist Temporal Classification (CTC) loss.
Args:
predictions: A tensor of shape [B, C, L] representing the logits of the
predicted sequences. B is the batch size, C is the number
of classes (including the blank label), and L is the sequence
length of the predictions.
targets: A tensor of shape [B, T] representing the target sequences.
B is the batch size and T is the target sequence length.
Note that T may vary from batch to batch, but is fixed within a batch
since targets are stacked into a single dense tensor.
blank_label: The index of the blank label. Defaults to 0.
Returns:
The CTC loss (a scalar tensor).
"""
batch_size, num_classes, prediction_length = predictions.shape
_, target_length = targets.shape
# 1. Log softmax on predictions: Crucially, CTC loss requires log probabilities.
log_probs = F.log_softmax(predictions, dim=1) # Shape: [B, C, L]
# 2. Prepare inputs for torch.nn.CTCLoss:
# a. Convert log_probs to shape (L, B, C): CTCLoss expects time first.
log_probs = log_probs.permute(2, 0, 1) # Shape: [L, B, C]
# b. Get lengths of the predicted sequences (all L in this case).
input_lengths = torch.full(size=(batch_size,), fill_value=prediction_length, dtype=torch.long)
# c. Get lengths of the target sequences. Since targets arrive as a dense [B, T] tensor,
#    every length equals T; padded targets would need their true per-sample lengths here.
target_lengths = torch.tensor([t.shape[0] for t in targets], dtype=torch.long)
# 3. Create the CTCLoss criterion. `blank=blank_label` is essential!
ctc_loss = torch.nn.CTCLoss(blank=blank_label, reduction='mean') # 'mean' for averaging over the batch
# 4. Calculate the loss. `targets` must be concatenated into a single 1D tensor;
#    CTCLoss uses target_lengths to know where each sample's targets end.
concatenated_targets = torch.cat(list(targets)) # Concatenate targets
loss = ctc_loss(log_probs, concatenated_targets, input_lengths, target_lengths)
return loss
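# Illustrative usage sketch (not part of the training code): how compute_ctc_loss
# might be called. The shapes, vocabulary size, and the helper name are assumptions
# for this example only; blank stays at the default index 0, so targets avoid it.
def _example_compute_ctc_loss():
    batch_size, num_classes, pred_length, target_length = 4, 11, 20, 8
    logits = torch.randn(batch_size, num_classes, pred_length)            # [B, C, L]
    targets = torch.randint(1, num_classes, (batch_size, target_length))  # [B, T], labels 1..C-1
    return compute_ctc_loss(logits, targets, blank_label=0)               # scalar tensor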
def sort_loss(predictions, targets):
"""
The sort task was used partly to show that CTC loss can work.
The final class index is used as the blank label.
"""
loss = compute_ctc_loss(predictions, targets, blank_label=predictions.shape[1]-1)
return loss
def image_classification_loss(predictions, certainties, targets, use_most_certain=True):
"""
Computes the image classification loss.
Predictions are of shape: (B, class, internal_ticks),
Certainties are of shape: (B, 2, internal_ticks),
where the inside dimension (2) is [normalised_entropy, 1-normalised_entropy]
Targets are of shape: [B]
use_most_certain will select either the most certain point or the final point.
"""
targets_expanded = torch.repeat_interleave(targets.unsqueeze(-1), predictions.size(-1), -1)
# Losses are of shape [B, internal_ticks]
losses = nn.CrossEntropyLoss(reduction='none')(predictions, targets_expanded)
loss_index_1 = losses.argmin(dim=1)
loss_index_2 = certainties[:,1].argmax(-1)
if not use_most_certain: # Revert to final loss if set
loss_index_2[:] = -1
batch_indexer = torch.arange(predictions.size(0), device=predictions.device)
loss_minimum_ce = losses[batch_indexer, loss_index_1].mean()
loss_selected = losses[batch_indexer, loss_index_2].mean()
loss = (loss_minimum_ce + loss_selected)/2
return loss, loss_index_2
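# Illustrative usage sketch (random data, assumed shapes and helper name): the loss
# averages the cross-entropy at the best (minimum-CE) internal tick and at the tick
# the model is most certain about, and also returns the chosen tick per sample.
def _example_image_classification_loss():
    B, num_classes, ticks = 8, 10, 16
    predictions = torch.randn(B, num_classes, ticks)         # [B, class, internal_ticks]
    entropy = torch.rand(B, 1, ticks)
    certainties = torch.cat([entropy, 1 - entropy], dim=1)   # [B, 2, internal_ticks]
    targets = torch.randint(0, num_classes, (B,))            # [B]
    loss, chosen_tick = image_classification_loss(predictions, certainties, targets)
    return loss, chosen_tick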
def maze_loss(predictions, certainties, targets, cirriculum_lookahead=5, use_most_certain=True):
"""
Computes the maze loss with auto-extending curriculum.
Predictions are of shape: (B, route_length, class, internal_ticks),
where classes are in [0,1,2,3,4] for [Up, Down, Left, Right, Wait]
Certainties are of shape: (B, 2, internal_ticks),
where the inside dimension (2) is [normalised_entropy, 1-normalised_entropy]
Targets are of shape: [B, route_length]
cirriculum_lookahead: how far to look ahead in the auto-curriculum
use_most_certain will select either the most certain point or the final point. For baselines,
the final point proved the only usable option.
"""
# Predictions reshaped to: [B*route_length, 5, internal_ticks]
predictions_reshaped = predictions.flatten(0,1)
# Targets reshaped to: [B*route_length, internal_ticks]
targets_reshaped = torch.repeat_interleave(targets.unsqueeze(-1),
predictions.size(-1), -1).flatten(0,1).long()
# Losses are computed as [B*route_length, internal_ticks], then reshaped to [B, route_length, internal_ticks]
losses = nn.CrossEntropyLoss(reduction='none')(predictions_reshaped, targets_reshaped)
losses = losses.reshape(predictions[:,:,0].shape)
# Below is the code for the auto-curriculum
# Find how far the route is correct so far, and always push cirriculum_lookahead steps beyond that
iscorrects = (predictions.argmax(2) == targets.unsqueeze(-1)).cumsum(1)
correct_mask = (iscorrects == torch.arange(1, iscorrects.size(1)+1, device=iscorrects.device).reshape(1, -1, 1))
correct_mask[:,0,:] = 1
upto_where = correct_mask.cumsum(1).argmax(1).max(-1)[0]+cirriculum_lookahead
loss_mask = torch.zeros_like(losses)
for bi in range(predictions.size(0)):
loss_mask[bi, :upto_where[bi]] = 1
# Reduce losses along route dimension
# Will now be of shape [B, internal_ticks]
losses = (losses * loss_mask).sum(1)/(loss_mask.sum(1))
loss_index_1 = losses.argmin(dim=1)
loss_index_2 = certainties[:,1].argmax(-1)
if not use_most_certain:
loss_index_2[:] = -1
batch_indexer = torch.arange(predictions.size(0), device=predictions.device)
loss_minimum_ce = losses[batch_indexer, loss_index_1]
loss_selected = losses[batch_indexer, loss_index_2]
loss = ((loss_minimum_ce + loss_selected)/2).mean()
return loss, loss_index_2, upto_where.detach().cpu().numpy()
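# Illustrative usage sketch (random data, assumed shapes and helper name): besides the
# loss and the chosen tick, maze_loss returns `upto_where`, the route position up to
# which the auto-curriculum currently trains (furthest fully-correct step plus lookahead).
def _example_maze_loss():
    B, route_length, num_actions, ticks = 4, 12, 5, 10
    predictions = torch.randn(B, route_length, num_actions, ticks)  # [B, route_length, class, internal_ticks]
    entropy = torch.rand(B, 1, ticks)
    certainties = torch.cat([entropy, 1 - entropy], dim=1)          # [B, 2, internal_ticks]
    targets = torch.randint(0, num_actions, (B, route_length))      # [B, route_length]
    loss, chosen_tick, upto_where = maze_loss(predictions, certainties, targets)
    return loss, chosen_tick, upto_where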
def parity_loss(predictions, certainties, targets, use_most_certain=True):
"""
Computes the parity loss.
Predictions are of shape: (B, parity_sequence_length, class, internal_ticks)
Certainties are of shape: (B, 2, internal_ticks),
where the inside dimension (2) is [normalised_entropy, 1-normalised_entropy]
Targets are of shape: [B, parity_sequence_length]
use_most_certain will select either the most certain point or the final point. For baselines,
the final point proved the only usable option.
"""
# Losses are of shape [B, parity_sequence_length, internal_ticks]
losses = nn.CrossEntropyLoss(reduction='none')(predictions.flatten(0,1),
torch.repeat_interleave(targets.unsqueeze(-1),
predictions.size(-1), -1).flatten(0,1).long()).reshape(predictions[:,:,0].shape)
# Average the loss over the parity sequence dimension
losses = losses.mean(1)
loss_index_1 = losses.argmin(dim=1)
loss_index_2 = certainties[:,1].argmax(-1)
if not use_most_certain:
loss_index_2[:] = -1
batch_indexer = torch.arange(predictions.size(0), device=predictions.device)
loss_minimum_ce = losses[batch_indexer, loss_index_1].mean()
loss_selected = losses[batch_indexer, loss_index_2].mean()
loss = (loss_minimum_ce + loss_selected)/2
return loss, loss_index_2
class EnergyContrastiveLoss(nn.Module):
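    """
    Energy-based contrastive objective for learned halting ("thought depth").
    Combines a standard cross-entropy term over all internal ticks with a
    margin-based energy term: ticks whose argmax matches the target are pushed
    toward zero energy, while incorrect ticks are pushed above `margin`, so the
    learned energy scalar can serve as a halting signal.
    """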
def __init__(self, margin=10.0, energy_scale=0.1):
super().__init__()
self.margin = margin
self.energy_scale = energy_scale
self.ce_loss = nn.CrossEntropyLoss(reduction='none')
def forward(self, logits_history, energy_history, targets):
"""
logits_history: [B, Class, T]
energy_history: [B, 1, T]
targets: [B]
"""
B, C, T = logits_history.shape
# Flatten for easy computation
logits_flat = logits_history.permute(0, 2, 1).reshape(B * T, C)
energy_flat = energy_history.permute(0, 2, 1).reshape(B * T)
targets_expanded = targets.unsqueeze(1).repeat(1, T).reshape(B * T)
# 1. Standard Classification Loss (Cross Entropy)
ce_vals = self.ce_loss(logits_flat, targets_expanded)
# 2. Determine "correctness" for the contrastive energy term
# We treat a step as "positive" if the prediction matches the target
predictions = logits_flat.argmax(dim=1)
is_correct = (predictions == targets_expanded).float() # 1.0 if correct, 0.0 if wrong
# 3. Energy Loss Logic
# If Correct: Minimize Energy (Pull down to 0)
# If Incorrect: Maximize Energy (Push up to margin)
# L_pos = ||E(x)||^2 (Push correct states to 0 energy)
loss_pos = energy_flat ** 2
# L_neg = max(0, m - E(x))^2 (Push incorrect states above margin m)
loss_neg = F.relu(self.margin - energy_flat) ** 2
# Combine: correct samples use loss_pos, incorrect use loss_neg
energy_objective = (is_correct * loss_pos) + ((1 - is_correct) * loss_neg)
# Total Loss
total_loss = ce_vals.mean() + (self.energy_scale * energy_objective.mean())
return total_loss, {
"ce_loss": ce_vals.mean().item(),
"energy_loss": energy_objective.mean().item(),
"avg_energy": energy_flat.mean().item()
}
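# Illustrative usage sketch (random data, assumed shapes and helper name): the criterion
# takes the per-tick logits and the learned energy scalar over the internal thought steps,
# and returns the total loss plus a dict of logging statistics.
def _example_energy_contrastive_loss():
    B, num_classes, ticks = 8, 10, 16
    criterion = EnergyContrastiveLoss(margin=10.0, energy_scale=0.1)
    logits_history = torch.randn(B, num_classes, ticks)    # [B, Class, T]
    energy_history = torch.rand(B, 1, ticks)               # [B, 1, T]
    targets = torch.randint(0, num_classes, (B,))          # [B]
    total_loss, stats = criterion(logits_history, energy_history, targets)
    return total_loss, stats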
def qamnist_loss(predictions, certainties, targets, use_most_certain=True):
"""
Computes the qamnist loss.
Predictions are of shape: (B, class, internal_ticks),
Certainties are of shape: (B, 2, internal_ticks),
where the inside dimension (2) is [normalised_entropy, 1-normalised_entropy]
Targets are of shape: [B]
use_most_certain will select either the most certain point or the final point.
"""
losses = nn.CrossEntropyLoss(reduction='none')(predictions,
torch.repeat_interleave(targets.unsqueeze(-1), predictions.size(-1), -1))
loss_index_1 = losses.argmin(dim=1)
loss_index_2 = certainties[:,1].argmax(-1)
if not use_most_certain:
loss_index_2[:] = -1
batch_indexer = torch.arange(predictions.size(0), device=predictions.device)
loss_minimum_ce = losses[batch_indexer, loss_index_1].mean()
loss_selected = losses[batch_indexer, loss_index_2].mean()
loss = (loss_minimum_ce + loss_selected)/2
return loss, loss_index_2