| | from typing import Dict, List, Union |
| | from transformers import BertPreTrainedModel, BertModel,PreTrainedTokenizer |
| | import torch.nn as nn |
| | import torch |
class BertForStorySkillClassification(BertPreTrainedModel):
    """BERT sequence classifier for story-skill questions.

    Feeds the final hidden state of the [CLS] token through a single
    linear layer to produce one logit per class.
    """

    def __init__(self, config):
        """Build the model.

        Args:
            config: HuggingFace model config; must provide ``num_labels``
                and ``hidden_size`` (and ``id2label`` for ``predict``).
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        # Linear classification head on top of the [CLS] representation.
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        # HuggingFace-standard weight initialization / final processing.
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute classification logits, or the loss when labels are given.

        Args:
            input_ids: Token id tensor of shape (batch, seq_len).
            attention_mask: Optional padding mask, same shape as input_ids.
            labels: Optional class-id tensor of shape (batch,). When
                provided, the cross-entropy loss is returned instead of
                the logits (existing caller contract — preserved).
            **kwargs: Extra tokenizer outputs (e.g. token_type_ids) are
                accepted and ignored so ``self(**tokenizer(...))`` works.

        Returns:
            Tensor (batch, num_labels) of raw logits if ``labels`` is
            None; otherwise a scalar cross-entropy loss.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # Hidden state at position 0 ([CLS]) summarizes the sequence.
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(cls_hidden_state)
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return logits

    def predict(
        self,
        texts: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        batch_size: int = 32,
        return_probabilities: bool = False,
        device: Union[str, torch.device, None] = None,
    ) -> List[Dict]:
        """Run classification inference on raw text(s).

        Args:
            texts: A single string or a list of strings, e.g.
                "故事中的角色是谁?" or ["问题1", "问题2"].
            tokenizer: Tokenizer instance compatible with this model.
            batch_size: Number of texts tokenized and scored per batch.
            return_probabilities: If True, include a ``"score"``
                (softmax confidence) entry in each result dict.
            device: Target device ("cuda", "cpu", ...). Defaults to
                None, meaning the model's own device is used.

        Returns:
            One dict per input text:
            ``{"text": ..., "label": ...}`` plus ``"score"`` when
            ``return_probabilities`` is True.
        """
        # BUGFIX: the default used to be 'cpu', which made this
        # auto-detect branch dead code and crashed with a device
        # mismatch when the model lived on GPU. None now genuinely
        # means "use the model's current device".
        if device is None:
            device = self.device

        # Accept a bare string for convenience.
        if isinstance(texts, str):
            texts = [texts]

        # Disable dropout for deterministic inference; restore the
        # previous train/eval mode afterwards so callers are unaffected.
        was_training = self.training
        self.eval()

        predictions: List[Dict] = []
        try:
            with torch.no_grad():
                for start in range(0, len(texts), batch_size):
                    batch_texts = texts[start : start + batch_size]

                    inputs = tokenizer(
                        batch_texts,
                        padding=True,
                        truncation=True,
                        return_tensors="pt",
                        max_length=512,
                    ).to(device)

                    logits = self(**inputs)
                    probs = torch.softmax(logits, dim=-1)
                    scores, class_ids = torch.max(probs, dim=-1)

                    for text, class_id, score in zip(batch_texts, class_ids, scores):
                        label = self.config.id2label[class_id.item()]
                        result = {"text": text, "label": label}
                        if return_probabilities:
                            result["score"] = score.item()
                        predictions.append(result)
        finally:
            self.train(was_training)

        return predictions