|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import config
|
|
|
class CrossAttentionDelta(nn.Module):
    """
    Computes the update delta (Δ) for an initial hidden state using attention.

    Note: despite the class name, the attention step is *self*-attention over
    ``h0`` — query, key, and value are all the pre-normalized ``h0``. An
    optional reasoning trace is folded in afterwards via a small fusion MLP.

    Architecture:
        1. Pre-norm (LayerNorm before attention) with a residual connection.
        2. Optional integration of the last layer of a reasoning trace.
        3. A deep MLP maps [h0, context] -> delta, followed by a final norm.
    """

    def __init__(self, hidden_dim, num_heads=8, dropout=0.1):
        """
        Args:
            hidden_dim (int): Model hidden size.
            num_heads (int): Number of attention heads.
            dropout (float): Dropout probability used throughout.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Pre-LN transformer style: normalize before attention.
        self.pre_norm = nn.LayerNorm(hidden_dim)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Norm applied after the attention residual.
        self.post_norm = nn.LayerNorm(hidden_dim)

        # Fuses the attended context with the last reasoning-trace layer.
        self.trace_integration = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Maps the concatenation [h0, context] to the delta.
        self.delta_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )

        self.final_norm = nn.LayerNorm(hidden_dim)

    def forward(self, h0, reasoning_trace=None):
        """
        Args:
            h0 (torch.Tensor): Initial hidden states,
                shape (batch_size, seq_len, hidden_dim).
            reasoning_trace (sequence of torch.Tensor, optional): Hidden states
                from a base model; only the LAST entry is used. That tensor
                must be broadcast-compatible with ``h0`` along the last dim
                (expected shape (batch_size, seq_len, hidden_dim)).

        Returns:
            tuple:
                delta (torch.Tensor): The computed update delta,
                    shape (batch_size, seq_len, hidden_dim).
                attn_weights (torch.Tensor): Attention weights averaged over
                    heads, shape (batch_size, seq_len, seq_len).
        """
        h0_norm = self.pre_norm(h0)

        # Self-attention over the normalized hidden states; weights are
        # requested so callers can inspect them.
        attn_output, attn_weights = self.cross_attn(
            query=h0_norm,
            key=h0_norm,
            value=h0_norm,
            need_weights=True,
        )

        # Residual connection followed by normalization.
        c = self.post_norm(h0 + attn_output)

        # Optionally fold in the last layer of the reasoning trace.
        if reasoning_trace is not None and len(reasoning_trace) > 0:
            last_layer = reasoning_trace[-1]
            trace_info = self.trace_integration(
                torch.cat([c, last_layer], dim=-1)
            )
            c = c + trace_info

        # Delta is computed from the raw input plus the attended context.
        delta = self.delta_mlp(torch.cat((h0, c), dim=-1))
        delta = self.final_norm(delta)

        return delta, attn_weights
|
|
|
class GatingMechanism(nn.Module):
    """
    Learned gate that decides how strongly a delta update should be applied.

    Produces a per-position scalar in (0, 1) from the concatenation of the
    initial hidden state and the proposed delta.
    """

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        # Three-layer MLP ending in a sigmoid so the gate lands in (0, 1).
        layers = [
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ]
        self.gate_network = nn.Sequential(*layers)

    def forward(self, h0, delta):
        """
        Args:
            h0 (torch.Tensor): Initial hidden states (batch_size, seq_len, hidden_dim).
            delta (torch.Tensor): Computed delta (batch_size, seq_len, hidden_dim).

        Returns:
            torch.Tensor: Gate values between 0 and 1, shape (batch_size, seq_len, 1).
        """
        combined = torch.cat((h0, delta), dim=-1)
        return self.gate_network(combined)
|
|
|
class EnhancedQAHead(nn.Module):
    """
    Question-answering head with per-token MLP transforms and bilinear scoring.

    Each token is projected twice (once for start, once for end) and scored
    against a single learned global vector via ``nn.Bilinear``.
    """

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        # Separate token-wise transforms for start and end representations.
        self.start_transform = self._projection(hidden_dim, dropout)
        self.end_transform = self._projection(hidden_dim, dropout)

        # Bilinear scorers: (token_rep, global_rep) -> scalar logit.
        self.start_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)
        self.end_bilinear = nn.Bilinear(hidden_dim, hidden_dim, 1)

        # Learned global representation shared across all positions.
        self.global_rep = nn.Parameter(torch.randn(hidden_dim))

    @staticmethod
    def _projection(hidden_dim, dropout):
        # Two-layer MLP applied independently to each token.
        return nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, hidden_states):
        """
        Args:
            hidden_states (torch.Tensor): Hidden states
                (batch_size, seq_len, hidden_dim).

        Returns:
            dict: {"start_logits": Tensor, "end_logits": Tensor}, each of
            shape (batch_size, seq_len).
        """
        bsz, seq_len, _ = hidden_states.shape

        # Broadcast the global vector to every (batch, position) pair.
        global_rep = self.global_rep.expand(bsz, seq_len, -1)

        start_logits = self.start_bilinear(
            self.start_transform(hidden_states), global_rep
        ).squeeze(-1)
        end_logits = self.end_bilinear(
            self.end_transform(hidden_states), global_rep
        ).squeeze(-1)

        return {"start_logits": start_logits, "end_logits": end_logits}
|
|
|