Spaces:
Running
Running
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
def flatten_label(target):
    """Concatenate zero-terminated label rows into one flat label tensor.

    Each row of ``target`` is cut immediately after its first ``0`` (the ``0``
    itself is kept) and the truncated rows are joined end to end.

    Args:
        target: Tensor of labels with one sequence per row (indexed along
            dim 0). Every row must contain at least one ``0``.

    Returns:
        A ``(label_flatten, label_length)`` tuple: a ``torch.LongTensor``
        holding the concatenated truncated rows, and a ``torch.IntTensor``
        holding the kept length of each row (first-zero index + 1).

    Raises:
        ValueError: If a row contains no ``0``.
    """
    label_flatten = []
    label_length = []
    # A single tolist() converts the whole batch at once instead of issuing
    # one tensor-to-list conversion per row.
    for row_idx, cur_label in enumerate(target.tolist()):
        try:
            # Locate the terminator once and reuse it for slice and length.
            end = cur_label.index(0) + 1
        except ValueError:
            raise ValueError(
                f'label row {row_idx} contains no 0 terminator') from None
        label_flatten.extend(cur_label[:end])
        label_length.append(end)
    return torch.LongTensor(label_flatten), torch.IntTensor(label_length)
| def _flatten(sources, lengths): | |
| return torch.cat([t[:l] for t, l in zip(sources, lengths)]) | |
class VisionLANLoss(nn.Module):
    """Cross-entropy loss over zero-terminated label sequences.

    With ``training_step='LF_1'`` only the loss of the first prediction
    against ``batch[1]`` is returned. For ``'LF_2'`` and ``'LA'`` two more
    cross-entropy terms are added: the second prediction against
    ``batch[2]`` weighted by ``ratio_res`` and the third prediction against
    ``batch[3]`` weighted by ``ratio_sub``.

    Args:
        training_step: One of ``'LF_1'``, ``'LF_2'`` or ``'LA'``.
        ratio_res: Weight of the loss computed against ``batch[2]``.
        ratio_sub: Weight of the loss computed against ``batch[3]``.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``training_step`` is not one of the supported values.
    """

    _TRAINING_STEPS = ('LF_1', 'LF_2', 'LA')

    def __init__(self,
                 training_step='LA',
                 ratio_res=0.5,
                 ratio_sub=0.5,
                 **kwargs):
        super().__init__()
        # Validate explicitly: an ``assert`` is stripped under ``python -O``
        # and an unknown step would silently fall into the three-term branch.
        if training_step not in self._TRAINING_STEPS:
            raise ValueError(
                f'training_step must be one of {self._TRAINING_STEPS}, '
                f'got {training_step!r}')
        self.loss_func = nn.CrossEntropyLoss(reduction='mean')
        self.ratio_res = ratio_res
        self.ratio_sub = ratio_sub
        self.training_step = training_step

    def _sequence_loss(self, logits, labels):
        """Cross-entropy of ``logits`` against zero-terminated ``labels``."""
        labels = labels.to(dtype=torch.int64)
        flat_labels, lengths = flatten_label(labels)
        # Keep only the steps up to and including each label's terminator.
        flat_logits = _flatten(logits, lengths)
        # flatten_label builds its tensors on the CPU; move the labels to
        # wherever the logits live before computing the loss.
        return self.loss_func(flat_logits, flat_labels.to(flat_logits.device))

    def forward(self, pred, batch):
        """Compute the training loss.

        Args:
            pred: Sequence of exactly four items; the first three are the
                predictions scored against ``batch[1]``, ``batch[2]`` and
                ``batch[3]`` respectively, the fourth is ignored.
                NOTE(review): each prediction is presumably shaped
                (batch, steps, num_classes) -- confirm against the model.
            batch: Indexable whose items 1, 2 and 3 are label tensors with
                0-terminated rows. Items 2 and 3 are only read when
                ``training_step`` is not ``'LF_1'``.

        Returns:
            ``{'loss': loss}`` where ``loss`` is the (weighted) sum described
            on the class.
        """
        text_pre, text_rem, text_mas, _ = pred
        loss = self._sequence_loss(text_pre, batch[1])
        if self.training_step != 'LF_1':
            loss_res = self._sequence_loss(text_rem, batch[2])
            loss_mas = self._sequence_loss(text_mas, batch[3])
            loss = loss + loss_res * self.ratio_res + loss_mas * self.ratio_sub
        return {'loss': loss}