# File size: 1,205 Bytes (revision d04a061)
import torch
import torch.nn as nn
import torch.nn.functional as F
# Ignore file
class RDMCrossEntropyLoss(nn.CrossEntropyLoss):
    """Weighted, masked cross-entropy used as the RDM-derived training loss.

    The per-token negative log-likelihood of ``target`` under ``scores`` is
    multiplied by per-token ``weights`` and averaged over the positions
    selected by ``label_mask``.
    """

    def __init__(self, ignore_index: int = -100):
        # nn.Module.__init__ must run first: it creates the parameter/hook
        # registries that Module.__call__ relies on. CrossEntropyLoss.__init__
        # also stores ``self.ignore_index``.
        super().__init__(ignore_index=ignore_index)

    def forward(
        self,
        scores: torch.Tensor,
        target: torch.Tensor,
        label_mask: "torch.Tensor | None" = None,
        weights: "torch.Tensor | None" = None,
    ) -> dict:
        """Compute the RDM-derived loss (weighted cross-entropy).

        Args:
            scores: Unnormalised logits with the class dimension last
                (e.g. ``(N, L, C)``).
            target: Integer class indices with shape ``scores.shape[:-1]``.
                Positions equal to ``ignore_index`` contribute zero loss.
            label_mask: Positions to include in the training loss (e.g. to
                exclude positions with missing coordinates). Bool or 0/1
                tensor broadcastable to ``target``. Defaults to every
                position whose target is not ``ignore_index``.
            weights: Per-token loss weights broadcastable to ``target``
                (NOTE(review): assumed per-token, e.g. ``(N, L)``; confirm
                against callers). Defaults to 1 for every token.

        Returns:
            A dict with:

            - ``'loss'``: differentiable weighted, masked loss (use for backprop).
            - ``'ppl'``: ``exp`` of the unweighted masked NLL (detached).
            - ``'fullseq_loss'``: weighted loss averaged over all non-ignored
              tokens (detached).
            - ``'weight_diff_loss'``: detached copy of ``'loss'``.

            Averages over an empty selection are 0 rather than NaN.
        """
        num_classes = scores.size(-1)
        # Per-token NLL of the target class. With reduction="none",
        # F.cross_entropy returns 0 where target == ignore_index (so those
        # indices are never gathered).
        nll = F.cross_entropy(
            scores.reshape(-1, num_classes),
            target.reshape(-1),
            ignore_index=self.ignore_index,
            reduction="none",
        ).reshape(target.shape)

        not_ignored = target.ne(self.ignore_index).float()
        if weights is None:
            weights = torch.ones_like(nll)
        weighted_nll = nll * weights

        # Weighted loss averaged over every non-ignored token (reported only).
        fullseq_loss = weighted_nll.sum() / not_ignored.sum().clamp(min=1.0)

        # Use the coord-masked loss for model training, ignoring positions
        # with missing coords (marked as nan upstream).
        mask = not_ignored if label_mask is None else label_mask.float()
        # The sample size counts only the valid (masked-in) positions; the
        # clamp avoids 0/0 = NaN when the mask is empty.
        sample_size = mask.sum().clamp(min=1.0)
        loss = (weighted_nll * mask).sum() / sample_size
        # Perplexity is defined on the unweighted NLL.
        nll_loss = (nll * mask).sum() / sample_size
        ppl = torch.exp(nll_loss)

        logging_output = {
            'loss': loss,
            'ppl': ppl.detach(),
            'fullseq_loss': fullseq_loss.detach(),
            'weight_diff_loss': loss.detach(),
        }
        return logging_output