"""
Latent Attention Pooling implementation for LLM2Vec4CXR.
Vendored to make the model self-contained (no external llm2vec dependency required).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class LatentAttentionPooling(nn.Module):
"""
Latent attention pooling layer that uses a trainable latent dictionary
to aggregate token embeddings into a fixed-size representation.
"""

    def __init__(self, d_model, num_latents=512, num_heads=8):
"""
Args:
            d_model: Hidden size of the model (e.g., 2048 for Llama-3.2-1B)
num_latents: Number of learnable latent vectors (default: 512)
num_heads: Number of attention heads (default: 8)
"""
super().__init__()
self.num_latents = num_latents
self.d_model = d_model
# Trainable latent dictionary (used as both keys and values)
self.latents = nn.Parameter(torch.randn(num_latents, d_model))
# Multihead attention layer
# batch_first=True means input shape is (batch, seq_length, hidden_size)
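        # Note: nn.MultiheadAttention requires embed_dim (d_model) to be
        # divisible by num_heads.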
self.multihead_attn = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=num_heads,
batch_first=True
)
# Simple MLP: Linear -> GELU -> Linear
self.mlp = nn.Sequential(
nn.Linear(d_model, d_model),
nn.GELU(),
nn.Linear(d_model, d_model)
)

    def forward(self, hidden_states, attention_mask=None):
"""
Apply latent attention pooling to hidden states.
Args:
hidden_states: Token embeddings of shape (batch_size, seq_len, d_model)
attention_mask: Optional mask of shape (batch_size, seq_len)
Returns:
Pooled embeddings of shape (batch_size, d_model)
"""
batch_size, seq_len, d_model = hidden_states.shape
device = hidden_states.device
# Ensure the module is on the same device as input
if next(self.parameters()).device != device:
self.to(device)
# Expand latents to match batch size: (batch_size, num_latents, d_model)
latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)
# Apply multihead attention
        # Use hidden_states as queries and the latent dictionary as keys/values.
        # Per attention head this computes O = softmax((QK^T)/√d_k)V,
        # where d_k = d_model / num_heads.
attn_output, _ = self.multihead_attn(
query=hidden_states,
key=latents,
value=latents
)
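        # attn_output keeps the query's shape: (batch_size, seq_len, d_model)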
# Apply MLP to attention output
mlp_output = self.mlp(attn_output)
# Mean pool over sequence dimension
if attention_mask is not None:
# Mask out padding tokens before pooling
mask_expanded = attention_mask.unsqueeze(-1).expand(mlp_output.size()).float()
sum_embeddings = torch.sum(mlp_output * mask_expanded, dim=1)
sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
pooled = sum_embeddings / sum_mask
else:
# Simple mean pooling if no mask provided
pooled = mlp_output.mean(dim=1)
return pooled
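

if __name__ == "__main__":
    # Minimal smoke test, illustrative only (not part of the vendored module's
    # API). The batch size, sequence length, and hidden size below are
    # arbitrary example values; 2048 matches a Llama-3.2-1B-sized backbone.
    batch_size, seq_len, d_model = 2, 16, 2048
    pool = LatentAttentionPooling(d_model=d_model, num_latents=512, num_heads=8)

    hidden_states = torch.randn(batch_size, seq_len, d_model)
    # Treat the last 4 positions of the second sequence as padding.
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[1, -4:] = 0

    pooled = pool(hidden_states, attention_mask)
    print(pooled.shape)  # expected: torch.Size([2, 2048])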