"""
model.py — MATCHA contrastive model architecture.
ContrastiveModel wraps a pretrained language model backbone and adds a
SenseNetwork that decomposes word embeddings into multiple "sense" vectors,
followed by a learned transformation and mean-pooling to produce a single
sentence embedding for contrastive learning.
"""
import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D
from transformers.activations import ACT2FN
from typing import Optional, Tuple
class ContrastiveModel(nn.Module):
    """Top-level model: backbone word embeddings -> SenseNetwork -> projection.

    Token IDs are looked up in the backbone's embedding layer, expanded into
    several sense vectors per token by a SenseNetwork, multiplied by a learned
    square matrix, and mean-pooled into one vector per sequence.

    Args:
        contxtl_model: Pretrained HuggingFace model used only for its embedding layer.
        config: SimpleNamespace with model_type, n_embd, num_senses, etc.

    Raises:
        ValueError: If ``config.model_type`` is not a supported backbone type.
    """

    # Backbones whose embedding layer is obtained via ``get_input_embeddings()``.
    _INPUT_EMBEDDING_MODEL_TYPES = ("gpt2", "gpt_neo", "roberta", "xlm-roberta")
    # Backbones whose embedding layer is read from ``model.embed_tokens``.
    _EMBED_TOKENS_MODEL_TYPES = ("mistral",)

    def __init__(self, contxtl_model, config):
        super().__init__()
        # Resolve the embedding layer first so that an unsupported model type
        # fails here with a clear message instead of surfacing later as a
        # missing ``word_embeddings`` attribute.
        word_embeddings = self._resolve_word_embeddings(contxtl_model, config.model_type)

        self.sense_network = SenseNetwork(config)
        self.contxtl_model = contxtl_model
        self.word_embeddings = word_embeddings
        # Learnable transformation applied to sense vectors before pooling;
        # initialised from a standard normal distribution.
        self.transformation_matrix = nn.Parameter(torch.randn(config.n_embd, config.n_embd))

    @classmethod
    def _resolve_word_embeddings(cls, contxtl_model, model_type):
        """Return the backbone's word-embedding layer for ``model_type``."""
        if model_type in cls._INPUT_EMBEDDING_MODEL_TYPES:
            return contxtl_model.get_input_embeddings()
        if model_type in cls._EMBED_TOKENS_MODEL_TYPES:
            return contxtl_model.model.embed_tokens
        supported = cls._INPUT_EMBEDDING_MODEL_TYPES + cls._EMBED_TOKENS_MODEL_TYPES
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of {list(supported)}"
        )

    def get_model_output(self, input_ids):
        """Compute multi-sense embeddings from token IDs.

        Returns:
            The SenseNetwork output for the embedded tokens, shape (bs, nv, s, d).
        """
        sense_input_embeds = self.word_embeddings(input_ids)  # (bs, s, d)
        senses = self.sense_network(sense_input_embeds)  # (bs, nv, s, d)
        return senses

    def forward(self, input_ids):
        """Produce a single sentence embedding by mean-pooling transformed senses.

        Returns:
            embedding: Tensor of shape (bs, d)

        Raises:
            ValueError: If ``input_ids`` is a floating-point tensor containing NaN.
        """
        # NaN can only occur in floating-point tensors, so the scan (and the
        # tensor-to-bool evaluation it requires) is skipped for integer IDs.
        # An explicit raise is used because ``assert`` is stripped under -O.
        if input_ids.is_floating_point() and torch.isnan(input_ids).any():
            raise ValueError("Input IDs contain NaN values")

        senses = self.get_model_output(input_ids)  # (bs, nv, s, d)
        transformed_senses = senses @ self.transformation_matrix  # (bs, nv, s, d)
        # Average over both the sense axis and the sequence axis.
        embedding = transformed_senses.mean(dim=(1, 2))  # (bs, d)
        return embedding
class MLP(nn.Module):
    """Feed-forward block: linear -> activation -> linear -> dropout.

    Uses HuggingFace's Conv1D (equivalent to a linear layer applied
    along the last dimension) for compatibility with GPT-2 style configs.

    Args:
        embed_dim: Size of the last dimension of the input.
        intermediate_dim: Size of the hidden layer between the two projections.
        out_dim: Size of the last dimension of the output.
        config: Object providing ``activation_function`` (a key of ACT2FN)
            and ``resid_pdrop`` (dropout probability applied to the output).
    """

    def __init__(self, embed_dim, intermediate_dim, out_dim, config):
        super().__init__()
        # Conv1D takes (output features, input features), the reverse of nn.Linear.
        self.c_fc = Conv1D(intermediate_dim, embed_dim)
        self.c_proj = Conv1D(out_dim, intermediate_dim)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward transformation to ``hidden_states``.

        The input is a single tensor (not an optional tuple): it is passed
        straight into the first projection.

        Returns:
            The projected tensor after activation, output projection and dropout.
        """
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
class NoMixBlock(nn.Module):
    """Attention-free transformer-style block (no mixing across tokens).

    The block keeps a running residual stream and updates it twice: once by
    adding the (dropped-out) incoming hidden states, and once by adding the
    (dropped-out) MLP output computed from the normalised stream. The only
    transformation is the MLP, which acts on the last dimension.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        eps = config.layer_norm_epsilon
        drop_p = config.resid_pdrop

        # Registration order is kept as-is so parameter ordering is unchanged.
        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        self.mlp = MLP(width, width * 4, width, config)
        self.resid_dropout1 = nn.Dropout(drop_p)
        self.resid_dropout2 = nn.Dropout(drop_p)

    def forward(self, hidden_states, residual):
        """Run both residual sub-layers and return the normalised result.

        Args:
            hidden_states: Tensor added (after dropout) onto ``residual``.
            residual: Running residual stream entering the block.

        Returns:
            ``ln_2`` applied to the updated residual stream; the un-normalised
            stream itself is not returned.
        """
        # Sub-layer 1: fold the incoming hidden states into the stream.
        stream = self.resid_dropout1(hidden_states) + residual
        # Sub-layer 2: MLP over the normalised stream, added back onto it.
        stream = self.resid_dropout2(self.mlp(self.ln_1(stream))) + stream
        return self.ln_2(stream)
class SenseNetwork(nn.Module):
    """Decomposes token embeddings into multiple sense vectors.

    Each token is mapped from (d,) to (num_senses, d) via a NoMixBlock
    followed by an MLP that expands the embedding dimension and reshapes.

    Input:  (bs, s, d)
    Output: (bs, num_senses, s, d)

    Any leading dimensions before the sequence axis are carried through
    unchanged, so an unbatched (s, d) input yields (num_senses, s, d).

    Args:
        config: Object providing ``num_senses``, ``n_embd``, ``embd_pdrop``,
            ``layer_norm_epsilon`` and ``sense_intermediate_scale``, plus the
            fields read by NoMixBlock and MLP.
        device: Optional device to move the module's parameters to after
            construction. ``None`` leaves placement untouched.
        dtype: Optional floating-point dtype to cast the module's parameters
            to after construction. ``None`` leaves the dtype untouched.
    """

    def __init__(self, config, device=None, dtype=None):
        super().__init__()
        self.num_senses = config.num_senses
        self.n_embd = config.n_embd
        self.dropout = nn.Dropout(config.embd_pdrop)
        self.block = NoMixBlock(config)
        self.ln = nn.LayerNorm(self.n_embd, eps=config.layer_norm_epsilon)
        self.final_mlp = MLP(
            embed_dim=config.n_embd,
            intermediate_dim=config.sense_intermediate_scale * config.n_embd,
            out_dim=config.n_embd * config.num_senses,
            config=config,
        )
        # Honour the factory arguments instead of silently ignoring them.
        # Only explicitly supplied values are forwarded to ``Module.to``.
        placement = {
            key: value
            for key, value in (("device", device), ("dtype", dtype))
            if value is not None
        }
        if placement:
            self.to(**placement)

    def forward(self, input_embeds):
        """Expand each token embedding into ``num_senses`` vectors.

        Args:
            input_embeds: Token embeddings of shape (..., s, d).

        Returns:
            Tensor of shape (..., num_senses, s, d).
        """
        residual = self.dropout(input_embeds)
        hidden_states = self.ln(residual)
        hidden_states = self.block(hidden_states, residual)
        senses = self.final_mlp(hidden_states)  # (..., s, num_senses * d)
        # Split the last axis into (num_senses, d), then move the sense axis
        # in front of the sequence axis:
        # (..., s, num_senses*d) -> (..., s, num_senses, d) -> (..., num_senses, s, d)
        per_sense = senses.reshape(*senses.shape[:-1], self.num_senses, self.n_embd)
        return per_sense.transpose(-3, -2)
|