File size: 1,205 Bytes
d04a061
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn
import torch.nn.functional as F

# Ignore file

class RDMCrossEntropyLoss(nn.CrossEntropyLoss):
    """Weighted, masked cross-entropy for RDM-style training.

    Unlike :class:`nn.CrossEntropyLoss`, :meth:`forward` returns a dict of
    detached scalar metrics rather than a differentiable loss tensor.
    """

    def __init__(self, ignore_index):
        """
        Args:
            ignore_index: target value whose positions contribute zero loss and
                are excluded from the full-sequence normalizer.
        """
        # Must run nn.Module.__init__ (via CrossEntropyLoss.__init__); without
        # it the module's internal state is missing and calling it fails.
        # CrossEntropyLoss.__init__ also stores self.ignore_index.
        super().__init__(ignore_index=ignore_index)

    def forward(self,
                scores: torch.Tensor,
                target: torch.Tensor,
                label_mask,
                weights,
                ) -> dict:
        """
        Computes the RDM-derived loss (weighted cross-entropy).

        Args:
            scores: unnormalized logits with the class dimension last,
                i.e. shape ``(*target.shape, C)``.
            target: integer class indices; entries equal to ``ignore_index``
                contribute zero loss.
            label_mask: mask with ``target``'s shape marking positions that
                count toward ``weight_diff_loss`` (presumably boolean).
            weights: per-position weights, broadcastable to ``target``'s shape
                (assumed to be the RDM reweighting factors -- confirm with caller).

        Returns:
            dict of detached tensors:
              ``'fullseq_loss'``: weighted NLL summed over all positions and
                  divided by the number of non-ignored targets.
              ``'weight_diff_loss'``: weighted NLL summed over ``label_mask``
                  positions and divided by the mask total.
              ``'ppl'``: ``exp(weight_diff_loss)``.
        """
        num_classes = scores.size(-1)
        lprobs = F.log_softmax(scores, dim=-1)

        # Per-position negative log-likelihood of the target class. nll_loss
        # yields 0 at ignore_index positions instead of indexing out of range.
        nll = F.nll_loss(
            lprobs.reshape(-1, num_classes),
            target.reshape(-1).long(),
            ignore_index=self.ignore_index,
            reduction='none',
        ).reshape(target.shape)
        loss = nll * weights

        # Normalize by the count of non-ignored targets; clamp so a batch with
        # no valid targets gives 0 instead of NaN.
        sample_size = target.ne(self.ignore_index).float().sum().clamp(min=1.0)
        fullseq_loss = loss.sum() / sample_size

        # Use the coord-masked loss for model training, ignoring positions
        # with missing coords (marked as nan upstream) via label_mask.
        label_mask = label_mask.float()
        # The sample size should count only valid coordinates; clamp avoids 0/0.
        mask_size = label_mask.sum().clamp(min=1.0)
        loss = (loss * label_mask).sum() / mask_size

        ppl = torch.exp(loss)

        # All values are detached: this dict is for logging only.
        logging_output = {
            'ppl': ppl.detach(),
            'fullseq_loss': fullseq_loss.detach(),
            'weight_diff_loss': loss.detach()
        }

        return logging_output