| import torch |
| import torch.nn.functional as F |
|
|
| from transformers import PreTrainedModel |
| from .original import TransformerModel, LMHead |
| ''' |
| Code for HuggingFace Hub Compatability |
| ''' |
|
|
class HF_LMModel(PreTrainedModel):
    """Language-model-only Transformer for HuggingFace Hub loading.

    Wraps the original ``TransformerModel`` plus its ``LMHead`` behind the
    ``PreTrainedModel`` interface so the checkpoint can be pushed to /
    pulled from the Hub. Depending on the config flags, ``forward`` returns
    raw logits, masked activations, or softmax probabilities.
    """

    def __init__(self, config):
        super().__init__(config)
        self.transformer = TransformerModel(config, vocab=config.n_vocab, n_ctx=config.n_ctx)
        self.lm_head = LMHead(self.transformer, config, trunc_and_reshape=False)
        self.return_probs = config.return_probs
        self.return_acts = config.return_acts
        # Only build the mask when a branch of forward() will actually use it.
        # NOTE(review): presumably the trailing n_ctx vocab slots hold the
        # position embeddings (as in the original GPT code), so they are
        # pushed to -1e12 to keep them from ever being predicted — confirm
        # against TransformerModel in .original.
        if self.return_probs or self.return_acts:
            mask = torch.zeros(1, 1, config.n_vocab)
            mask[:, :, -config.n_ctx:] = -1e12
            self.register_buffer('pos_emb_mask', mask)

    def forward(self, x, sequence_mask=None):
        """Run the transformer and LM head.

        Returns a dict with key ``"logits"`` holding, per the config flags:
        softmax probabilities (``return_probs``), mask-adjusted logits
        (``return_acts``), or the raw LM-head output otherwise.
        """
        hidden_states = self.transformer(x, sequence_mask)
        lm_logits = self.lm_head(hidden_states)
        if self.return_probs:
            lm_logits = F.softmax(lm_logits + self.pos_emb_mask, dim=-1)
        elif self.return_acts:
            lm_logits = lm_logits + self.pos_emb_mask
        return {"logits": lm_logits}