""" |
|
|
LAPEFT Model Loading Utilities |
|
|
|
|
|
This script provides utilities to load and use the LAPEFT model with all its components. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import BertForSequenceClassification, BertTokenizer |
|
|
from peft import get_peft_model, LoraConfig, TaskType, PeftModel |
|
|
import pickle |
|
|
import os |
|
|
|
|
|
class GatedFusion(nn.Module):
    """Gated fusion of transformer and lexicon representations.

    Both inputs condition a pair of sigmoid gates; each representation is
    projected into the shared hidden space, scaled by its gate, and the two
    are then summed, projected, and layer-normalized.
    """

    def __init__(self, hidden_size, lexicon_size=64):
        super().__init__()
        self.hidden_size = hidden_size
        self.lexicon_size = lexicon_size
        # Gates are computed from the concatenation of both modalities.
        self.transformer_gate = nn.Linear(hidden_size + lexicon_size, hidden_size)
        self.lexicon_gate = nn.Linear(hidden_size + lexicon_size, hidden_size)
        # Per-modality projections into the shared hidden space.
        self.transformer_transform = nn.Linear(hidden_size, hidden_size)
        self.lexicon_transform = nn.Linear(lexicon_size, hidden_size)
        self.output_projection = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, transformer_output, lexicon_features):
        # transformer_output: (batch, hidden_size); lexicon_features: (batch, lexicon_size)
        combined = torch.cat([transformer_output, lexicon_features], dim=1)
        transformer_gate = torch.sigmoid(self.transformer_gate(combined))
        lexicon_gate = torch.sigmoid(self.lexicon_gate(combined))
        transformer_rep = self.transformer_transform(transformer_output)
        lexicon_rep = self.lexicon_transform(lexicon_features)
        # Scale each representation by its gate, then fuse additively.
        gated_transformer = transformer_rep * transformer_gate
        gated_lexicon = lexicon_rep * lexicon_gate
        fused = gated_transformer + gated_lexicon
        output = self.output_projection(fused)
        output = self.layer_norm(output)
        return F.relu(output)
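
# Illustrative GatedFusion usage (comment-only sketch, not part of the
# original script; shapes assume BERT-base's 768-dim hidden states):
#   fusion = GatedFusion(hidden_size=768, lexicon_size=64)
#   out = fusion(torch.randn(2, 768), torch.randn(2, 64))  # out.shape == (2, 768)
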
class MemoryOptimizedLexiconBERT(nn.Module):
    """LoRA-adapted BERT that fuses lexicon features via gated fusion."""

    def __init__(self, base_model, peft_config, num_labels=3):
        super().__init__()
        self.transformer = get_peft_model(base_model, peft_config)
        # Memory optimization: do not keep all intermediate hidden states.
        self.transformer.config.output_hidden_states = False
        self.hidden_size = self.transformer.config.hidden_size
        self.num_labels = num_labels
        # Project the 4 raw lexicon scores into a 64-dim feature vector.
        self.lexicon_projection = nn.Linear(4, 64)
        self.fusion_mechanism = GatedFusion(self.hidden_size, 64)
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(self.hidden_size // 2, num_labels),
        )
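
    # The original class defines no forward pass; the sketch below is a hedged
    # reconstruction of how the components above plug together. It assumes
    # `lexicon_scores` is a (batch, 4) float tensor and that the [CLS] hidden
    # state of the final layer feeds the fusion module.
    def forward(self, input_ids, attention_mask, lexicon_scores):
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,  # override the config for this pass only
        )
        cls_hidden = outputs.hidden_states[-1][:, 0]                 # (batch, hidden_size)
        lexicon_features = self.lexicon_projection(lexicon_scores)   # (batch, 64)
        fused = self.fusion_mechanism(cls_hidden, lexicon_features)  # (batch, hidden_size)
        return self.classifier(fused)                                # (batch, num_labels)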


def load_lapeft_model(model_path):
    """Load the LAPEFT components saved under ``model_path``.

    Returns the LoRA-adapted BERT classifier, its tokenizer, and the pickled
    lexicon analyzer (or None if no analyzer was saved).
    """
    tokenizer = BertTokenizer.from_pretrained(model_path)

    # Rebuild the base model, then attach the saved LoRA adapters. Note that
    # this restores only the adapted BERT classifier; if the GatedFusion and
    # classifier heads of MemoryOptimizedLexiconBERT were saved separately,
    # their state dicts would need to be loaded as well.
    base_model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",
        num_labels=3,
        output_hidden_states=False,
    )
    model = PeftModel.from_pretrained(base_model, model_path)

    # The lexicon analyzer is optional; unpickling assumes a trusted
    # checkpoint, since pickle can execute arbitrary code.
    lexicon_path = os.path.join(model_path, 'lexicon_analyzer.pkl')
    lexicon_analyzer = None
    if os.path.exists(lexicon_path):
        with open(lexicon_path, 'rb') as f:
            lexicon_analyzer = pickle.load(f)

    return model, tokenizer, lexicon_analyzer
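

# Hedged usage sketch: "./lapeft_model" is a hypothetical checkpoint directory
# produced by training; adjust the path and the label mapping to your setup.
if __name__ == "__main__":
    model, tokenizer, lexicon_analyzer = load_lapeft_model("./lapeft_model")
    model.eval()

    text = "The quarterly results exceeded expectations."
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        logits = model(**inputs).logits
    print("Predicted label id:", logits.argmax(dim=-1).item())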