import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_ctc_loss(predictions, targets, blank_label=0):
    """
    Computes the Connectionist Temporal Classification (CTC) loss.

    Args:
        predictions: A tensor of shape [B, C, L] representing the logits of the
                     predicted sequences. B is the batch size, C is the number
                     of classes (including the blank label), and L is the sequence
                     length of the predictions.
        targets: A tensor of shape [B, T] representing the target sequences.
                 B is the batch size and T is the target sequence length.
                 Note that all targets in a batch are assumed to share the same
                 length T, since they are passed as a single rectangular tensor.
        blank_label: The index of the blank label. Defaults to 0.

    Returns:
        The CTC loss (a scalar tensor).
    """
    batch_size, num_classes, prediction_length = predictions.shape
    _, target_length = targets.shape

    # CTCLoss expects log-probabilities with shape (L, B, C).
    log_probs = F.log_softmax(predictions, dim=1)
    log_probs = log_probs.permute(2, 0, 1)

    # Every prediction in the batch spans the full prediction length.
    input_lengths = torch.full(size=(batch_size,), fill_value=prediction_length, dtype=torch.long)
    target_lengths = torch.tensor([t.shape[0] for t in targets], dtype=torch.long)

    ctc_loss = torch.nn.CTCLoss(blank=blank_label, reduction='mean')

    # CTCLoss takes the targets as a single concatenated 1D tensor alongside
    # the per-sample target_lengths.
    concatenated_targets = torch.cat(list(targets))

    loss = ctc_loss(log_probs, concatenated_targets, input_lengths, target_lengths)

    return loss
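

# A minimal usage sketch for compute_ctc_loss. The shapes and values below are
# illustrative assumptions, not taken from any particular training setup.
def _example_compute_ctc_loss():
    logits = torch.randn(4, 11, 50)          # [B=4, C=11 classes incl. blank at index 0, L=50]
    targets = torch.randint(1, 11, (4, 12))  # [B=4, T=12]; labels avoid the blank index
    return compute_ctc_loss(logits, targets, blank_label=0)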


def sort_loss(predictions, targets):
    """
    Computes the loss for the sort task, which was used in part to show that
    CTC loss can work. The last class index is used as the blank label.
    """
    loss = compute_ctc_loss(predictions, targets, blank_label=predictions.shape[1]-1)
    return loss


def image_classification_loss(predictions, certainties, targets, use_most_certain=True):
    """
    Computes the image classification loss.

    Predictions are of shape: (B, class, internal_ticks),
    Certainties are of shape: (B, 2, internal_ticks),
        where the inner dimension (2) is [normalised_entropy, 1-normalised_entropy]
    Targets are of shape: [B]

    use_most_certain will select either the most certain internal tick or the final one.
    """
    # Repeat the targets across internal ticks so a per-tick cross-entropy can
    # be computed in a single call.
    targets_expanded = torch.repeat_interleave(targets.unsqueeze(-1), predictions.size(-1), -1)

    losses = nn.CrossEntropyLoss(reduction='none')(predictions, targets_expanded)

    # Index 1: the tick with the minimum loss; index 2: the most certain tick
    # (or the final tick when use_most_certain is False).
    loss_index_1 = losses.argmin(dim=1)
    loss_index_2 = certainties[:, 1].argmax(-1)
    if not use_most_certain:
        loss_index_2[:] = -1

    batch_indexer = torch.arange(predictions.size(0), device=predictions.device)
    loss_minimum_ce = losses[batch_indexer, loss_index_1].mean()
    loss_selected = losses[batch_indexer, loss_index_2].mean()

    loss = (loss_minimum_ce + loss_selected) / 2
    return loss, loss_index_2
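

# A minimal usage sketch for image_classification_loss. The batch size, class
# count and number of internal ticks below are illustrative assumptions.
def _example_image_classification_loss():
    B, C, T = 8, 10, 25
    predictions = torch.randn(B, C, T)                       # per-tick class logits
    entropy = torch.rand(B, 1, T)                            # stand-in for normalised entropy
    certainties = torch.cat([entropy, 1 - entropy], dim=1)   # [normalised_entropy, 1-normalised_entropy]
    targets = torch.randint(0, C, (B,))
    return image_classification_loss(predictions, certainties, targets)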


def maze_loss(predictions, certainties, targets, cirriculum_lookahead=5, use_most_certain=True):
    """
    Computes the maze loss with an auto-extending curriculum.

    Predictions are of shape: (B, route_length, class, internal_ticks),
        where classes are in [0,1,2,3,4] for [Up, Down, Left, Right, Wait]
    Certainties are of shape: (B, 2, internal_ticks),
        where the inner dimension (2) is [normalised_entropy, 1-normalised_entropy]
    Targets are of shape: [B, route_length]

    cirriculum_lookahead: how far to look ahead in the auto-curriculum

    use_most_certain will select either the most certain internal tick or the final one.
    For baselines, the final tick proved the only usable option.
    """
    # Flatten the batch and route dimensions so a per-step, per-tick
    # cross-entropy can be computed in a single call.
    predictions_reshaped = predictions.flatten(0, 1)
    targets_reshaped = torch.repeat_interleave(targets.unsqueeze(-1),
                                               predictions.size(-1), -1).flatten(0, 1).long()

    losses = nn.CrossEntropyLoss(reduction='none')(predictions_reshaped, targets_reshaped)
    losses = losses.reshape(predictions[:, :, 0].shape)

    # Auto-extending curriculum: for each sample, find how far along the route
    # the predicted steps form a correct prefix (at the best internal tick),
    # then only apply the loss up to that point plus a lookahead.
    iscorrects = (predictions.argmax(2) == targets.unsqueeze(-1)).cumsum(1)
    correct_mask = (iscorrects == torch.arange(1, iscorrects.size(1)+1, device=iscorrects.device).reshape(1, -1, 1))
    correct_mask[:, 0, :] = 1
    upto_where = correct_mask.cumsum(1).argmax(1).max(-1)[0] + cirriculum_lookahead
    loss_mask = torch.zeros_like(losses)
    for bi in range(predictions.size(0)):
        loss_mask[bi, :upto_where[bi]] = 1

    # Average the masked losses over the route dimension.
    losses = (losses * loss_mask).sum(1) / (loss_mask.sum(1))

    # Index 1: the tick with the minimum loss; index 2: the most certain tick
    # (or the final tick when use_most_certain is False).
    loss_index_1 = losses.argmin(dim=1)
    loss_index_2 = certainties[:, 1].argmax(-1)
    if not use_most_certain:
        loss_index_2[:] = -1

    batch_indexer = torch.arange(predictions.size(0), device=predictions.device)
    loss_minimum_ce = losses[batch_indexer, loss_index_1]
    loss_selected = losses[batch_indexer, loss_index_2]

    loss = ((loss_minimum_ce + loss_selected) / 2).mean()
    return loss, loss_index_2, upto_where.detach().cpu().numpy()
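

# A minimal usage sketch for maze_loss. The route length, tick count and batch
# size below are illustrative assumptions.
def _example_maze_loss():
    B, R, C, T = 4, 20, 5, 30
    predictions = torch.randn(B, R, C, T)                    # per-step, per-tick logits over the 5 maze actions
    entropy = torch.rand(B, 1, T)                            # stand-in for normalised entropy
    certainties = torch.cat([entropy, 1 - entropy], dim=1)
    targets = torch.randint(0, C, (B, R))
    return maze_loss(predictions, certainties, targets, cirriculum_lookahead=5)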


def parity_loss(predictions, certainties, targets, use_most_certain=True):
    """
    Computes the parity loss.

    Predictions are of shape: (B, parity_sequence_length, class, internal_ticks),
        where the class dimension indexes the parity classes
    Certainties are of shape: (B, 2, internal_ticks),
        where the inner dimension (2) is [normalised_entropy, 1-normalised_entropy]
    Targets are of shape: [B, parity_sequence_length]

    use_most_certain will select either the most certain internal tick or the final one.
    For baselines, the final tick proved the only usable option.
    """
    # Per-position, per-tick cross-entropy over the flattened batch and
    # sequence dimensions, reshaped back to (B, parity_sequence_length, internal_ticks).
    targets_expanded = torch.repeat_interleave(targets.unsqueeze(-1), predictions.size(-1), -1)
    losses = nn.CrossEntropyLoss(reduction='none')(predictions.flatten(0, 1),
                                                   targets_expanded.flatten(0, 1).long())
    losses = losses.reshape(predictions[:, :, 0].shape)

    # Average over the sequence dimension.
    losses = losses.mean(1)

    # Index 1: the tick with the minimum loss; index 2: the most certain tick
    # (or the final tick when use_most_certain is False).
    loss_index_1 = losses.argmin(dim=1)
    loss_index_2 = certainties[:, 1].argmax(-1)
    if not use_most_certain:
        loss_index_2[:] = -1

    batch_indexer = torch.arange(predictions.size(0), device=predictions.device)
    loss_minimum_ce = losses[batch_indexer, loss_index_1].mean()
    loss_selected = losses[batch_indexer, loss_index_2].mean()

    loss = (loss_minimum_ce + loss_selected) / 2
    return loss, loss_index_2


class EnergyContrastiveLoss(nn.Module):
    def __init__(self, margin=10.0, energy_scale=0.1):
        super().__init__()
        self.margin = margin
        self.energy_scale = energy_scale
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, logits_history, energy_history, targets):
        """
        logits_history: [B, Class, T]
        energy_history: [B, 1, T]
        targets: [B]
        """
        B, C, T = logits_history.shape

        # Flatten the batch and tick dimensions so each (sample, tick) pair is
        # treated as an independent example.
        logits_flat = logits_history.permute(0, 2, 1).reshape(B * T, C)
        energy_flat = energy_history.permute(0, 2, 1).reshape(B * T)
        targets_expanded = targets.unsqueeze(1).repeat(1, T).reshape(B * T)

        ce_vals = self.ce_loss(logits_flat, targets_expanded)

        predictions = logits_flat.argmax(dim=1)
        is_correct = (predictions == targets_expanded).float()

        # Contrastive energy objective: push energies towards zero when the
        # prediction is correct, and up to at least the margin when it is not.
        loss_pos = energy_flat ** 2
        loss_neg = F.relu(self.margin - energy_flat) ** 2

        energy_objective = (is_correct * loss_pos) + ((1 - is_correct) * loss_neg)

        total_loss = ce_vals.mean() + (self.energy_scale * energy_objective.mean())

        return total_loss, {
            "ce_loss": ce_vals.mean().item(),
            "energy_loss": energy_objective.mean().item(),
            "avg_energy": energy_flat.mean().item()
        }
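

# A minimal usage sketch for EnergyContrastiveLoss; the shapes below are
# illustrative assumptions.
def _example_energy_contrastive_loss():
    B, C, T = 8, 10, 25
    logits_history = torch.randn(B, C, T)
    energy_history = torch.rand(B, 1, T) * 5.0   # non-negative stand-in energies
    targets = torch.randint(0, C, (B,))
    criterion = EnergyContrastiveLoss(margin=10.0, energy_scale=0.1)
    return criterion(logits_history, energy_history, targets)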


def qamnist_loss(predictions, certainties, targets, use_most_certain=True):
    """
    Computes the QAMNIST loss.

    Predictions are of shape: (B, class, internal_ticks),
    Certainties are of shape: (B, 2, internal_ticks),
        where the inner dimension (2) is [normalised_entropy, 1-normalised_entropy]
    Targets are of shape: [B]

    use_most_certain will select either the most certain internal tick or the final one.
    """
    # Per-tick cross-entropy against the targets repeated across internal ticks.
    losses = nn.CrossEntropyLoss(reduction='none')(predictions,
                                                   torch.repeat_interleave(targets.unsqueeze(-1), predictions.size(-1), -1))

    # Index 1: the tick with the minimum loss; index 2: the most certain tick
    # (or the final tick when use_most_certain is False).
    loss_index_1 = losses.argmin(dim=1)
    loss_index_2 = certainties[:, 1].argmax(-1)
    if not use_most_certain:
        loss_index_2[:] = -1

    batch_indexer = torch.arange(predictions.size(0), device=predictions.device)
    loss_minimum_ce = losses[batch_indexer, loss_index_1].mean()
    loss_selected = losses[batch_indexer, loss_index_2].mean()

    loss = (loss_minimum_ce + loss_selected) / 2
    return loss, loss_index_2