Spaces:
Running
Running
| from torch import nn | |
| class SMTRLoss(nn.Module): | |
| def __init__(self, **kwargs): | |
| super(SMTRLoss, self).__init__() | |
| def forward(self, predicts, batch): | |
| if isinstance(predicts, list): | |
| predicts = predicts[0] | |
| return predicts | |