| """ |
| HuggingFace-compatible vec2vec implementation for embedding translation. |
| Based on: "Harnessing the Universal Geometry of Embeddings" (arXiv:2505.12540) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from dataclasses import dataclass |
| from typing import Dict, Optional, List |
| from transformers import PreTrainedModel, PretrainedConfig |
| from transformers.modeling_outputs import ModelOutput |
|
|
|
|
| |
| |
| |
|
|
class Vec2VecConfig(PretrainedConfig):
    """Configuration for Vec2Vec model.

    Architecture fields:
        encoder_names / encoder_dims: parallel lists naming each embedding space
            and its dimensionality (defaults: two 768-d spaces "model_a"/"model_b").
        d_adapter: width of the shared latent space.
        d_hidden: hidden width of the per-encoder adapters.
        d_transform: hidden width of the shared transform.
        adapter_depth / transform_depth: number of residual MLP layers.
        disc_dim / disc_depth: discriminator hidden width and depth.
        weight_init: "kaiming", "xavier" or "orthogonal".
        norm_style: "batch" for BatchNorm1d, anything else for LayerNorm.
        normalize_embeddings: L2-normalise decoded embeddings.

    Training fields:
        noise_level: default noise magnitude used by ``Vec2VecModel.forward``.
        loss_coefficient_* / max_grad_norm: stored on the config; presumably
            consumed by training code outside this module.

    Raises:
        ValueError: if ``encoder_names`` and ``encoder_dims`` differ in length.
    """

    model_type = "vec2vec"

    def __init__(
        self,
        encoder_names: Optional[List[str]] = None,
        encoder_dims: Optional[List[int]] = None,
        d_adapter: int = 1024,
        d_hidden: int = 1024,
        d_transform: int = 1024,
        adapter_depth: int = 3,
        transform_depth: int = 4,
        disc_dim: int = 1024,
        disc_depth: int = 5,
        weight_init: str = "kaiming",
        norm_style: str = "batch",
        normalize_embeddings: bool = True,

        loss_coefficient_rec: float = 1.0,
        loss_coefficient_vsp: float = 1.0,
        loss_coefficient_cc_trans: float = 10.0,
        loss_coefficient_cc_vsp: float = 10.0,
        loss_coefficient_cc_rec: float = 0.0,
        loss_coefficient_gen: float = 1.0,
        loss_coefficient_latent_gen: float = 1.0,
        loss_coefficient_similarity_gen: float = 0.0,
        loss_coefficient_disc: float = 1.0,
        loss_coefficient_r1_penalty: float = 0.0,

        noise_level: float = 0.0,
        max_grad_norm: float = 1000.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Copy the lists: they are appended to in place when new encoders are
        # added, which must not mutate lists owned by the caller.
        self.encoder_names = list(encoder_names or ["model_a", "model_b"])
        self.encoder_dims = list(encoder_dims or [768, 768])
        if len(self.encoder_names) != len(self.encoder_dims):
            # zip() in get_encoder_dims_dict would otherwise silently drop encoders.
            raise ValueError(
                f"encoder_names ({len(self.encoder_names)}) and encoder_dims "
                f"({len(self.encoder_dims)}) must have the same length"
            )
        self.d_adapter = d_adapter
        self.d_hidden = d_hidden
        self.d_transform = d_transform
        self.adapter_depth = adapter_depth
        self.transform_depth = transform_depth
        self.disc_dim = disc_dim
        self.disc_depth = disc_depth
        self.weight_init = weight_init
        self.norm_style = norm_style
        self.normalize_embeddings = normalize_embeddings

        self.loss_coefficient_rec = loss_coefficient_rec
        self.loss_coefficient_vsp = loss_coefficient_vsp
        self.loss_coefficient_cc_trans = loss_coefficient_cc_trans
        self.loss_coefficient_cc_vsp = loss_coefficient_cc_vsp
        self.loss_coefficient_cc_rec = loss_coefficient_cc_rec
        self.loss_coefficient_gen = loss_coefficient_gen
        self.loss_coefficient_latent_gen = loss_coefficient_latent_gen
        self.loss_coefficient_similarity_gen = loss_coefficient_similarity_gen
        self.loss_coefficient_disc = loss_coefficient_disc
        self.loss_coefficient_r1_penalty = loss_coefficient_r1_penalty
        self.noise_level = noise_level
        self.max_grad_norm = max_grad_norm

    def get_encoder_dims_dict(self) -> Dict[str, int]:
        """Return ``{encoder_name: embedding_dim}`` in config order."""
        return dict(zip(self.encoder_names, self.encoder_dims))
|
|
|
|
| |
| |
| |
|
|
@dataclass
class Vec2VecOutput(ModelOutput):
    """Output type for Vec2Vec forward pass.

    All fields are optional; ``Vec2VecModel.forward`` fills only
    ``reconstructions``, ``translations`` and (on request) ``latents``.
    """
    # Not set by Vec2VecModel.forward; left for callers that compute a loss.
    loss: Optional[torch.FloatTensor] = None
    # encoder name -> its input decoded back into the same space.
    reconstructions: Optional[Dict[str, torch.Tensor]] = None
    # target name -> {source name -> source input translated into the target space}.
    translations: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
    # source name -> latent after in-adapter + transform (only when return_latents=True).
    latents: Optional[Dict[str, torch.Tensor]] = None
    # Not set by Vec2VecModel.forward; free-form scalar metrics.
    metrics: Optional[Dict[str, float]] = None
|
|
|
|
| |
| |
| |
|
|
def add_residual(input_x: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Add residual connection with dimension matching.

    Both tensors are indexed as ``(rows, features)``; dimension 1 is matched.
    A narrower ``input_x`` is right-padded with zeros, and a wider one is truncated
    to ``x``'s width before being added to ``x``.

    The padding is created with ``input_x``'s dtype and device, so half-precision
    activations are not silently promoted to float32 (which would then clash with
    the next reduced-precision Linear layer).
    """
    width_gap = x.shape[1] - input_x.shape[1]
    if width_gap > 0:
        padding = input_x.new_zeros((input_x.shape[0], width_gap))
        input_x = torch.cat([input_x, padding], dim=1)
    elif width_gap < 0:
        input_x = input_x[:, :x.shape[1]]
    return x + input_x
|
|
|
|
| class MLPWithResidual(nn.Module): |
| """MLP with residual connections.""" |
| |
| def __init__( |
| self, |
| depth: int, |
| in_dim: int, |
| hidden_dim: int, |
| out_dim: int, |
| norm_style: str = "batch", |
| weight_init: str = "kaiming", |
| ): |
| super().__init__() |
| self.layers = nn.ModuleList() |
| norm_layer = nn.BatchNorm1d if norm_style == "batch" else nn.LayerNorm |
|
|
| for layer_idx in range(depth): |
| if layer_idx == 0: |
| h_dim = out_dim if depth == 1 else hidden_dim |
| self.layers.append(nn.Sequential(nn.Linear(in_dim, h_dim), nn.SiLU())) |
| elif layer_idx < depth - 1: |
| self.layers.append(nn.Sequential( |
| nn.Linear(hidden_dim, hidden_dim), |
| nn.SiLU(), |
| norm_layer(hidden_dim), |
| nn.Dropout(p=0.1), |
| )) |
| else: |
| self.layers.append(nn.Sequential( |
| nn.Linear(hidden_dim, hidden_dim), |
| nn.Dropout(p=0.1), |
| nn.SiLU(), |
| nn.Linear(hidden_dim, out_dim), |
| )) |
| self._initialize_weights(weight_init) |
| |
| def _initialize_weights(self, weight_init: str): |
| for module in self.modules(): |
| if isinstance(module, nn.Linear): |
| if weight_init == "kaiming": |
| nn.init.kaiming_normal_(module.weight, a=0, mode="fan_in", nonlinearity="relu") |
| elif weight_init == "xavier": |
| nn.init.xavier_normal_(module.weight) |
| elif weight_init == "orthogonal": |
| nn.init.orthogonal_(module.weight) |
| module.bias.data.fill_(0) |
| elif isinstance(module, nn.BatchNorm1d): |
| nn.init.normal_(module.weight, mean=1.0, std=0.02) |
| nn.init.normal_(module.bias, mean=0.0, std=0.02) |
| elif isinstance(module, nn.LayerNorm): |
| nn.init.constant_(module.bias, 0) |
| nn.init.constant_(module.weight, 1.0) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| for layer in self.layers: |
| input_x = x |
| x = layer(x) |
| x = add_residual(input_x, x) |
| return x |
|
|
|
|
| class Discriminator(nn.Module): |
| """Discriminator network for adversarial training.""" |
| |
| def __init__( |
| self, |
| latent_dim: int, |
| hidden_dim: int = 1024, |
| depth: int = 5, |
| weight_init: str = "kaiming", |
| ): |
| super().__init__() |
| self.layers = nn.ModuleList() |
| |
| if depth >= 2: |
| layers = [nn.Linear(latent_dim, hidden_dim), nn.Dropout(0.0)] |
| for _ in range(depth - 2): |
| layers.extend([ |
| nn.SiLU(), |
| nn.Linear(hidden_dim, hidden_dim), |
| nn.LayerNorm(hidden_dim), |
| nn.Dropout(0.0), |
| ]) |
| layers.extend([nn.SiLU(), nn.Linear(hidden_dim, 1)]) |
| self.layers.append(nn.Sequential(*layers)) |
| else: |
| self.layers.append(nn.Linear(latent_dim, 1)) |
| |
| self._initialize_weights(weight_init) |
| |
| def _initialize_weights(self, weight_init: str): |
| for module in self.modules(): |
| if isinstance(module, nn.Linear): |
| if weight_init == "kaiming": |
| nn.init.kaiming_normal_(module.weight, a=0, mode="fan_in", nonlinearity="relu") |
| elif weight_init == "xavier": |
| nn.init.xavier_normal_(module.weight) |
| elif weight_init == "orthogonal": |
| nn.init.orthogonal_(module.weight) |
| module.bias.data.fill_(0) |
| elif isinstance(module, nn.LayerNorm): |
| nn.init.constant_(module.bias, 0) |
| nn.init.constant_(module.weight, 1.0) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| for layer in self.layers: |
| x = layer(x) |
| return x |
|
|
|
|
| |
| |
| |
|
|
class Vec2VecModel(PreTrainedModel):
    """
    Vec2Vec model for embedding translation between different spaces.

    Architecture:
        Input -> In Adapter -> Transform -> Out Adapter -> Output

    Each encoder name owns an input adapter (its dim -> ``d_adapter``), an output
    adapter (``d_adapter`` -> its dim) and a discriminator over its space. The
    transform and the ``"latent"`` discriminator are shared by all encoders.
    """

    config_class = Vec2VecConfig
    all_tied_weights_keys = {}

    def __init__(self, config: Vec2VecConfig):
        super().__init__(config)
        self.config = config
        encoder_dims = config.get_encoder_dims_dict()

        # Shared latent -> latent transform.
        self.transform = MLPWithResidual(
            depth=config.transform_depth,
            in_dim=config.d_adapter,
            hidden_dim=config.d_transform,
            out_dim=config.d_adapter,
            norm_style=config.norm_style,
            weight_init=config.weight_init,
        )

        # Per-encoder adapters into and out of the shared latent space.
        self.in_adapters = nn.ModuleDict()
        self.out_adapters = nn.ModuleDict()

        for name, dim in encoder_dims.items():
            self.in_adapters[name] = MLPWithResidual(
                config.adapter_depth, dim, config.d_hidden, config.d_adapter,
                config.norm_style, config.weight_init,
            )
            self.out_adapters[name] = MLPWithResidual(
                config.adapter_depth, config.d_adapter, config.d_hidden, dim,
                config.norm_style, config.weight_init,
            )

        # One discriminator per embedding space plus one over the latent space.
        self.discriminators = nn.ModuleDict()
        for name, dim in encoder_dims.items():
            self.discriminators[name] = Discriminator(
                dim, config.disc_dim, config.disc_depth, config.weight_init
            )
        self.discriminators["latent"] = Discriminator(
            config.d_adapter, config.disc_dim, config.disc_depth, config.weight_init
        )

        self.post_init()

    def add_encoder(self, name: str, dim: int, overwrite: bool = False):
        """Add a new encoder (adapters + discriminator) to the model.

        Args:
            name: encoder name used as the key in the adapter/discriminator dicts.
            dim: embedding dimensionality of that encoder.
            overwrite: replace an existing encoder of the same name instead of skipping.

        The new modules are moved to the device/dtype of the existing parameters,
        and ``config.encoder_names``/``config.encoder_dims`` are kept in sync (the
        stored dim is updated when an existing encoder is overwritten).
        """
        if name in self.in_adapters and not overwrite:
            print(f"Encoder {name} already exists, skipping...")
            return

        cfg = self.config
        in_adapter = MLPWithResidual(
            cfg.adapter_depth, dim, cfg.d_hidden, cfg.d_adapter,
            cfg.norm_style, cfg.weight_init,
        )
        out_adapter = MLPWithResidual(
            cfg.adapter_depth, cfg.d_adapter, cfg.d_hidden, dim,
            cfg.norm_style, cfg.weight_init,
        )
        discriminator = Discriminator(
            dim, cfg.disc_dim, cfg.disc_depth, cfg.weight_init
        )

        # Freshly built modules live on CPU in the default dtype; match the rest
        # of the model so translate()/forward() do not hit device/dtype mismatches.
        reference = next(self.parameters(), None)
        if reference is not None:
            for module in (in_adapter, out_adapter, discriminator):
                module.to(device=reference.device, dtype=reference.dtype)

        self.in_adapters[name] = in_adapter
        self.out_adapters[name] = out_adapter
        self.discriminators[name] = discriminator

        # Keep the config consistent with the weights so save/reload works.
        if name in cfg.encoder_names:
            cfg.encoder_dims[cfg.encoder_names.index(name)] = dim
        else:
            cfg.encoder_names.append(name)
            cfg.encoder_dims.append(dim)

    def _get_latent(self, emb: torch.Tensor, encoder_name: str) -> torch.Tensor:
        """Get latent representation from embedding (in-adapter, then shared transform)."""
        z = self.in_adapters[encoder_name](emb)
        return self.transform(z)

    def _decode(self, latent: torch.Tensor, encoder_name: str) -> torch.Tensor:
        """Decode latent to target embedding space, L2-normalising if configured."""
        out = self.out_adapters[encoder_name](latent)
        if self.config.normalize_embeddings:
            out = F.normalize(out, p=2, dim=1)
        return out

    def translate(self, embeddings: torch.Tensor, src: str, tgt: str) -> torch.Tensor:
        """Translate embeddings from source to target space."""
        latent = self._get_latent(embeddings, src)
        return self._decode(latent, tgt)

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
        noise_level: Optional[float] = None,
        return_latents: bool = False,
    ) -> Vec2VecOutput:
        """
        Forward pass computing reconstructions and translations.

        Args:
            inputs: Dict mapping encoder names to embeddings
            noise_level: Optional noise for training; defaults to ``config.noise_level``.
                Only applied in training mode, after which inputs are re-normalised.
            return_latents: Whether to return latent representations

        Returns:
            Vec2VecOutput with ``reconstructions[src]`` and
            ``translations[tgt][src]`` for every pair of names in ``inputs``.
        """
        noise_level = noise_level if noise_level is not None else self.config.noise_level

        reconstructions = {}
        translations = {}
        latents = {}

        for src_name, emb in inputs.items():

            if self.training and noise_level > 0.0:
                emb = emb + torch.randn_like(emb) * noise_level
                emb = F.normalize(emb, p=2, dim=1)

            latent = self._get_latent(emb, src_name)
            if return_latents:
                latents[src_name] = latent

            # Decode the same latent into every space present in the batch.
            for tgt_name in inputs.keys():
                decoded = self._decode(latent, tgt_name)
                if tgt_name == src_name:
                    reconstructions[src_name] = decoded
                else:
                    if tgt_name not in translations:
                        translations[tgt_name] = {}
                    translations[tgt_name][src_name] = decoded

        return Vec2VecOutput(
            reconstructions=reconstructions,
            translations=translations,
            latents=latents if return_latents else None,
        )
|
|
|
|
| |
| |
| |
|
|
def reconstruction_loss(inputs: Dict[str, torch.Tensor], recons: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Reconstruction loss (1 - cosine similarity), averaged over encoders.

    Args:
        inputs: encoder name -> original embeddings.
        recons: encoder name -> reconstructed embeddings (must contain every key of ``inputs``).

    Returns:
        Mean of ``1 - mean_cos_sim`` across encoders; ``0.0`` when ``inputs`` is empty
        (consistent with ``translation_loss``/``vsp_loss`` instead of dividing by zero).
    """
    loss = sum(1 - F.cosine_similarity(inputs[k], recons[k], dim=1).mean() for k in inputs)
    return loss / max(len(inputs), 1)
|
|
|
|
def translation_loss(inputs: Dict[str, torch.Tensor], translations: Dict[str, Dict[str, torch.Tensor]]) -> torch.Tensor:
    """Translation loss (1 - cosine similarity).

    Args:
        inputs: target name -> ground-truth embeddings in that space.
        translations: target name -> {source name -> embeddings translated into the target}.

    Returns:
        Mean of ``1 - mean_cos_sim`` over every (target, source) pair; ``0.0`` when
        there are no pairs. Targets absent from ``translations`` (e.g. a
        single-encoder batch) are skipped instead of raising ``KeyError``.
    """
    loss = 0.0
    count = 0
    for tgt, emb in inputs.items():
        for trans in translations.get(tgt, {}).values():
            loss += 1 - F.cosine_similarity(emb, trans, dim=1).mean()
            count += 1
    return loss / max(count, 1)
|
|
|
|
def vsp_loss(inputs: Dict[str, torch.Tensor], translations: Dict[str, Dict[str, torch.Tensor]]) -> torch.Tensor:
    """Vector Space Preservation (VSP) loss.

    For each target space and each source translated into it, the pairwise cosine
    similarity matrix of the (detached) target embeddings is compared by mean
    absolute error against
    (a) the similarity matrix of the translations, and
    (b) the cross-similarity between translations and targets.

    Returns:
        The sum (a) + (b), averaged over all (target, source) pairs; ``0.0`` when
        there are none. Targets absent from ``translations`` are skipped.
    """
    loss = 0.0
    count = 0

    for out_name, target in inputs.items():
        per_source = translations.get(out_name, {})
        if not per_source:
            continue
        # Target-side similarities depend only on out_name: compute once.
        B = F.normalize(target.detach(), p=2, dim=1)
        in_sims = B @ B.T

        for translated in per_source.values():
            A = F.normalize(translated, p=2, dim=1)
            out_sims = A @ A.T
            out_sims_reflected = A @ B.T

            loss += (in_sims - out_sims).abs().mean()
            loss += (in_sims - out_sims_reflected).abs().mean()
            count += 1

    return loss / max(count, 1)
|
|
|
|
| from typing import Optional, Union, List, Dict |
| from transformers import AutoModel, AutoTokenizer |
| from .base_tokenizer import BaseSequenceTokenizer |
| from .supported_models import all_presets_with_paths |
|
|
| from pooler import Pooler |
|
|
|
|
# ESM2 model sizes, smallest first; every preset pairs a smaller size with a larger one.
_ESM2_SIZES = ['8', '35', '150', '650', '3B']

# Preset name -> Hub repository of the vec2vec translator for that ESM2 pair.
presets = {
    f'vec2vec-ESM2-{small}-ESM2-{large}': f'Synthyra/ESM2-{small}-ESM2-{large}-sequence-sequence'
    for pos, small in enumerate(_ESM2_SIZES)
    for large in _ESM2_SIZES[pos + 1:]
}
|
|
|
|
class Vec2VecTokenizerWrapper(BaseSequenceTokenizer):
    """Tokenizer adapter that always batches and applies overridable defaults.

    A single string is wrapped into a one-element batch. Unless the caller
    overrides them, ``return_tensors='pt'``, ``padding='longest'`` and
    ``add_special_tokens=True`` are passed to the underlying tokenizer.
    """

    def __init__(self, tokenizer: AutoTokenizer):
        super().__init__(tokenizer)

    def __call__(self, sequences: Union[str, List[str]], **kwargs) -> Dict[str, torch.Tensor]:
        batch = [sequences] if isinstance(sequences, str) else sequences
        call_kwargs = {
            'return_tensors': 'pt',
            'padding': 'longest',
            'add_special_tokens': True,
        }
        # Caller-supplied options win over the defaults above.
        call_kwargs.update(kwargs)
        return self.tokenizer(batch, **call_kwargs)
|
|
|
|
class Vec2VecForEmbedding(nn.Module):
    """Embed sequences with a base model, then translate into another space.

    Pipeline: ``base_model`` hidden states -> ``Pooler(['mean', 'var'])`` ->
    optional L2 normalisation -> ``vec2vec_model.translate(model_name_a -> model_name_b)``.
    """

    def __init__(
        self,
        config: Vec2VecConfig,
        base_model: AutoModel,
        vec2vec_model: Vec2VecModel,
        model_name_a: str,
        model_name_b: str,
    ):
        super().__init__()
        self.base_model = base_model
        self.vec2vec_model = vec2vec_model
        self.config = config
        self.pooler = Pooler(['mean', 'var'])
        self.model_name_a = model_name_a
        self.model_name_b = model_name_b
        self.normalize = config.normalize_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        """Return the pooled base-model embedding translated into ``model_name_b``'s space.

        ``output_attentions``, ``output_hidden_states`` and ``kwargs`` are accepted
        for interface compatibility but not used.
        """
        base_state = self.base_model(input_ids, attention_mask=attention_mask).last_hidden_state
        base_vec = self.pooler(base_state, attention_mask=attention_mask)
        # The base model may run in reduced precision (build_vec2vec_model passes
        # `dtype` only to it) while the vec2vec weights keep their own dtype;
        # align them, otherwise the adapters' Linear layers raise a dtype mismatch.
        vec2vec_dtype = next(self.vec2vec_model.parameters()).dtype
        base_vec = base_vec.to(vec2vec_dtype)
        if self.normalize:
            base_vec = F.normalize(base_vec, p=2, dim=1)
        translated_ab = self.vec2vec_model.translate(base_vec, src=self.model_name_a, tgt=self.model_name_b)
        return translated_ab
|
|
|
|
def get_vec2vec_tokenizer(preset: str, model_path: str = None):
    """Load a tokenizer for ``preset`` (or ``model_path``) wrapped in ``Vec2VecTokenizerWrapper``.

    First tries a tokenizer stored directly at the path. If that fails, it falls back
    to the tokenizer named by ``tokenizer_name`` in the path's config.
    """
    path = model_path or all_presets_with_paths[preset]
    try:
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    except Exception:
        # Only the config's `tokenizer_name` is needed here, so load the config
        # instead of the full model weights. Unlike the former bare `except:`,
        # this does not swallow KeyboardInterrupt/SystemExit.
        from transformers import AutoConfig
        config = AutoConfig.from_pretrained(path, trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
    return Vec2VecTokenizerWrapper(tokenizer)
|
|
|
|
def build_vec2vec_model(preset: str, masked_lm: bool = False, dtype: torch.dtype = None, model_path: str = None, **kwargs):
    """Build a ``Vec2VecForEmbedding`` model and its tokenizer for a vec2vec preset.

    Args:
        preset: key into ``presets`` (ignored when ``model_path`` is given).
        masked_lm: unsupported; must be False.
        dtype: dtype forwarded to the base model's ``from_pretrained``.
        model_path: explicit vec2vec checkpoint path/repo overriding ``preset``.

    Returns:
        ``(model, tokenizer)``. The encoder with the larger (or equal, first wins)
        dimension in the config is used as the base model / translation source.

    Raises:
        ValueError: if ``masked_lm`` is True.
    """
    if masked_lm:
        raise ValueError("Masked LM is not supported for Vec2VecForEmbedding")

    model_path = model_path or presets[preset]
    config = Vec2VecConfig.from_pretrained(model_path)
    names, dims = config.encoder_names, config.encoder_dims

    if dims[0] >= dims[1]:
        model_name_a, model_name_b = names[0], names[1]
    else:
        model_name_a, model_name_b = names[1], names[0]

    base_model = AutoModel.from_pretrained(all_presets_with_paths[model_name_a], dtype=dtype, trust_remote_code=True)
    base_tokenizer = base_model.tokenizer
    # from_pretrained is a classmethod: calling it on the class avoids building
    # (and discarding) a randomly initialised Vec2VecModel first.
    vec2vec_model = Vec2VecModel.from_pretrained(model_path)
    model = Vec2VecForEmbedding(config, base_model, vec2vec_model, model_name_a, model_name_b)
    tokenizer = Vec2VecTokenizerWrapper(base_tokenizer)
    return model, tokenizer
|
|
|
|
def get_vec2vec_for_training(preset: str, tokenwise: bool = False, num_labels: Optional[int] = None, hybrid: bool = False):
    """Placeholder for a trainable vec2vec wrapper; not implemented.

    All arguments are accepted for signature compatibility and ignored.

    Raises:
        ValueError: always.
    """
    message = "Vec2VecForTraining is not supported yet"
    raise ValueError(message)
|
|
|
|
if __name__ == '__main__':
    # Smoke test. `presets` keys carry the 'vec2vec-' prefix; the bare
    # 'ESM2-8-ESM2-35' name raised KeyError in build_vec2vec_model.
    model, tokenizer = build_vec2vec_model('vec2vec-ESM2-8-ESM2-35')
    print(model)
    print(tokenizer)
    print(tokenizer('MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL'))
|
|