Spaces:
Running
Running
| from torch import nn | |
class MGPLoss(nn.Module):
    """Cross-entropy loss over one or three prediction heads.

    With ``only_char=True`` a single head (``pred``) is scored against
    ``batch[1]`` and only ``{'loss': ...}`` is returned. Otherwise ``pred`` is
    unpacked into three heads (char, dpe, wp), scored against ``batch[1]``,
    ``batch[2]`` and ``batch[3]`` respectively. Their mean-reduced losses are
    summed into ``'loss'`` and also reported individually.

    Target index 0 never contributes to any loss (``ignore_index=0``).
    NOTE(review): index 0 is presumably the padding label; confirm against
    the label converter.
    """

    def __init__(self, only_char=False, **kwargs):
        # Extra keyword arguments are accepted and discarded.
        # NOTE(review): presumably so a shared loss config can be splatted in.
        super().__init__()
        self.ce = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
        self.only_char = only_char

    def _head_loss(self, logits, target):
        """Mean CE for one head after merging dims 0 and 1 of both tensors.

        Assumes ``logits`` is (B, T, num_classes) and ``target`` is (B, T),
        so both become a flat sequence of B*T positions. TODO confirm shapes.
        """
        return self.ce(logits.flatten(0, 1), target.flatten(0, 1))

    def forward(self, pred, batch):
        """Return the loss dict; delegates to ``forward_all`` unless char-only."""
        if not self.only_char:
            return self.forward_all(pred, batch)
        return {'loss': self._head_loss(pred, batch[1])}

    def forward_all(self, pred, batch):
        """Score the three heads in ``pred`` against ``batch[1:4]`` and sum them."""
        char_logits, dpe_logits, wp_logits = pred
        char_loss = self._head_loss(char_logits, batch[1])
        dpe_loss = self._head_loss(dpe_logits, batch[2])
        wp_loss = self._head_loss(wp_logits, batch[3])
        # Keep the left-to-right summation order so float results match exactly.
        return {
            'loss': char_loss + dpe_loss + wp_loss,
            'char_loss': char_loss,
            'dpe_loss': dpe_loss,
            'wp_loss': wp_loss,
        }