import torch.nn as nn
from transformers import XLMRobertaModel
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput


class Smish(nn.Module):
    """Smish activation: x * tanh(log(1 + sigmoid(x)))."""

    def forward(self, x):
        return x * (x.sigmoid() + 1).log().tanh()


class NoRefER(XLMRobertaPreTrainedModel):
    """Referenceless quality estimator on XLM-RoBERTa, trained with a pairwise objective."""

    def __init__(self, config):
        super().__init__(config)
        hidden_size = 32
        self.config = config
        self.roberta = XLMRobertaModel(config)
        # Scoring head: maps the pooled representation to a single scalar score.
        self.dense = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, hidden_size, bias=False),
            nn.Dropout(config.hidden_dropout_prob),
            Smish(),
            nn.Linear(hidden_size, 1, bias=False),
        )

        self.post_init()

    def forward(self, positive_input_ids, positive_attention_mask,
                negative_input_ids, negative_attention_mask, labels, weight=None):
        # Score the preferred ("positive") hypothesis; the attention mask is passed
        # along so padding tokens are ignored.
        positive_inputs = {
            "input_ids": positive_input_ids,
            "attention_mask": positive_attention_mask,
        }
        positive = self.dense(self.roberta(**positive_inputs).pooler_output).squeeze(-1)

        # Score the less-preferred ("negative") hypothesis.
        negative_inputs = {
            "input_ids": negative_input_ids,
            "attention_mask": negative_attention_mask,
        }
        negative = self.dense(self.roberta(**negative_inputs).pooler_output).squeeze(-1)

        # Pairwise ranking loss: BCE on the score difference, optionally with
        # per-example weights normalised to mean 1 over the batch.
        if weight is None:
            bce = nn.BCEWithLogitsLoss()
        else:
            bs = len(positive)
            weights = (weight.float() * bs) / weight.sum()
            bce = nn.BCEWithLogitsLoss(weight=weights)
        loss = bce(positive - negative, labels.float())
        return SequenceClassifierOutput(
            loss=loss,
            logits=positive.sigmoid() - negative.sigmoid(),
        )

    def score(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        # Referenceless quality score in (0, 1) for a batch of hypotheses.
        h = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        ).pooler_output

        return self.dense(h).sigmoid().squeeze(-1)
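

# Minimal usage sketch. Assumptions: the public "xlm-roberta-base" checkpoint is
# used to initialise the encoder (the scoring head is then randomly initialised),
# and the sentences and label below are dummy placeholders for a real
# positive/negative hypothesis pair.
if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
    model = NoRefER.from_pretrained("xlm-roberta-base")
    model.eval()

    positive = tokenizer("the better hypothesis", return_tensors="pt")
    negative = tokenizer("the worse hypothesis", return_tensors="pt")
    labels = torch.ones(1)  # label 1: the positive side should score higher

    with torch.no_grad():
        out = model(
            positive_input_ids=positive["input_ids"],
            positive_attention_mask=positive["attention_mask"],
            negative_input_ids=negative["input_ids"],
            negative_attention_mask=negative["attention_mask"],
            labels=labels,
        )
        print("pairwise loss:", out.loss.item())

        # Standalone referenceless score for a single hypothesis.
        print("score:", model.score(positive["input_ids"],
                                    attention_mask=positive["attention_mask"]))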