| | """ |
| | Context Encoder using pre-trained GuwenBERT RoBERTa. |
| | Implements the textual feature extraction module from the paper. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers import AutoModel, AutoTokenizer |
| |
|
| |
|
class ContextEncoder(nn.Module):
    """
    Context encoder built on a pre-trained GuwenBERT RoBERTa backbone.

    Wraps a HuggingFace ``AutoModel`` and exposes:
      * ``forward`` — full-sequence hidden states,
      * ``extract_mask_features`` — hidden states gathered at mask positions,
      * ``freeze`` / ``unfreeze`` — toggles for staged training.
    """

    def __init__(self, config, pretrained_model_name: str = None):
        """
        Initialize context encoder.

        Args:
            config: Configuration object; ``config.roberta_model`` supplies
                the default HuggingFace checkpoint name.
            pretrained_model_name: Optional HuggingFace model identifier
                that overrides the config default.
        """
        super().__init__()
        self.config = config

        if pretrained_model_name is None:
            pretrained_model_name = config.roberta_model

        from transformers import logging as transformers_logging

        # Temporarily silence transformers' loading warnings (e.g. the
        # untied-embeddings notice triggered by tie_word_embeddings=False).
        # BUG FIX: the previous code restored verbosity with a hard-coded
        # set_verbosity_warning(), clobbering whatever level the caller had
        # configured. Save the current level and restore exactly that.
        previous_verbosity = transformers_logging.get_verbosity()
        transformers_logging.set_verbosity_error()
        try:
            self.encoder = AutoModel.from_pretrained(
                pretrained_model_name, tie_word_embeddings=False
            )
        finally:
            transformers_logging.set_verbosity(previous_verbosity)

        # Hidden size of the loaded backbone (e.g. 1024 for RoBERTa-large).
        self.hidden_dim = self.encoder.config.hidden_size

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the RoBERTa backbone.

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]

        Returns:
            Hidden states [batch_size, seq_len, hidden_dim]
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return outputs.last_hidden_state

    def extract_mask_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        mask_positions: torch.Tensor,
    ) -> torch.Tensor:
        """
        Extract hidden states at masked positions.

        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Attention mask [batch_size, seq_len]
            mask_positions: Positions of masks [batch_size, num_masks];
                each entry must be a valid index into the sequence dimension.

        Returns:
            Features at mask positions [batch_size, num_masks, hidden_dim]
        """
        hidden_states = self.forward(input_ids, attention_mask)

        batch_size, num_masks = mask_positions.shape

        # Broadcast each position index across the hidden dimension so that
        # torch.gather can select full feature vectors along the seq axis.
        gather_index = mask_positions.unsqueeze(-1).expand(
            batch_size, num_masks, self.hidden_dim
        )
        return torch.gather(hidden_states, 1, gather_index)

    def freeze(self):
        """Freeze all parameters (for Phase 2 training)."""
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """Unfreeze all parameters."""
        for param in self.parameters():
            param.requires_grad = True
| |
|