| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| from .configuration_tinymodel import TinyModelConfig |
|
|
class TinyCore(nn.Module):
    """Minimal two-layer classifier head: Linear -> ReLU -> Linear -> Softmax.

    Kept in this modeling file so the ``PreTrainedModel`` wrapper below can
    own it as a single submodule.

    NOTE(review): the forward pass returns softmax *probabilities*, not
    logits. If this is ever trained with ``nn.CrossEntropyLoss`` (which
    expects raw logits), the softmax here would be applied twice — confirm
    against the training code.
    """

    def __init__(self, cfg: TinyModelConfig):
        super().__init__()
        # Submodule names are part of the state_dict contract — keep stable.
        self.linear1 = nn.Linear(cfg.input_size, cfg.hidden_size)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(cfg.hidden_size, cfg.num_labels)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, input_size)`` features to ``(batch, num_labels)``
        class probabilities (rows sum to 1 via the trailing softmax)."""
        hidden = self.activation(self.linear1(x))
        return self.softmax(self.linear2(hidden))
|
|
class TinyModel(PreTrainedModel):
    """Hugging Face–compatible wrapper around :class:`TinyCore`.

    Registering ``config_class`` and calling ``post_init()`` wires the core
    network into the ``save_pretrained`` / ``from_pretrained`` machinery.
    """

    config_class = TinyModelConfig

    def __init__(self, config: TinyModelConfig):
        super().__init__(config)
        self.core = TinyCore(config)
        # Standard HF finalization hook (weight init, etc.).
        self.post_init()

    def forward(self, inputs: torch.Tensor, **kwargs):
        """Run the core classifier.

        Args:
            inputs: Float tensor of shape ``(batch, config.input_size)``.

        Returns:
            Tensor of shape ``(batch, config.num_labels)`` — already softmaxed
            probabilities, since ``TinyCore`` applies the softmax itself.
        """
        # Any extra HF-style kwargs (attention_mask, labels, ...) are ignored.
        probs = self.core(inputs)
        return probs

    def predict_proba(self, inputs: torch.Tensor):
        """sklearn-style alias for :meth:`forward`; the output is already a
        probability distribution over the labels."""
        return self.forward(inputs)
|
|