lukeingawesome committed
Commit f1a89e8 · verified · 1 Parent(s): d3e466f

Update modeling_llm2vec4cxr.py with helper methods and vendored pooling

Files changed (1)
  1. modeling_llm2vec4cxr.py +77 -1
modeling_llm2vec4cxr.py CHANGED
@@ -3,9 +3,12 @@ Custom model class for LLM2Vec4CXR that properly handles latent attention poolin
 """
 
 from llm2vec.models.bidirectional_llama import LlamaBiModel
-from llm2vec.pooling import LatentAttentionPooling
+# from llm2vec.pooling import LatentAttentionPooling
+from .pooling_latent import LatentAttentionPooling
+from transformers import AutoTokenizer
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class LLM2Vec4CXRModel(LlamaBiModel):
@@ -49,6 +52,79 @@ class LLM2Vec4CXRModel(LlamaBiModel):
 
         return outputs.last_hidden_state
 
+    # --- Convenience tokenizer (lazy) -------------------------------------
+    def _get_tokenizer(self):
+        if not hasattr(self, "_hf_tokenizer"):
+            tok = AutoTokenizer.from_pretrained(getattr(self.config, "_name_or_path", "lukeingawesome/llm2vec4cxr"))
+            if tok.pad_token is None:
+                tok.pad_token = tok.eos_token
+            tok.padding_side = "left"
+            self._hf_tokenizer = tok
+        return self._hf_tokenizer
+
+    # --- Ensure latent_attn follows .to(device/dtype) ----------------------
+    def to(self, *args, **kwargs):
+        m = super().to(*args, **kwargs)
+        if hasattr(self, "latent_attn") and self.latent_attn is not None:
+            # Align latent_attn with the base weights' device & dtype
+            try:
+                device = next(p.device for p in self.parameters() if p is not None)
+                dtype = next((p.dtype for p in self.parameters() if p.is_floating_point()), None)
+                self.latent_attn = self.latent_attn.to(device=device, dtype=dtype)
+            except StopIteration:
+                pass
+        return m
+
+    # --- Simple text encoding (no instruction) ----------------------------
+    @torch.no_grad()
+    def encode_text(self, texts, max_length: int = 512):
+        tok = self._get_tokenizer()
+        enc = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
+        # For simple encoding we embed over all non-pad tokens
+        enc["embed_mask"] = enc["attention_mask"].clone()
+        dev = next(self.parameters()).device
+        enc = {k: v.to(dev) for k, v in enc.items()}
+        return self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], embed_mask=enc["embed_mask"])
+
+    # --- Instruction/text encoding with separator -------------------------
+    def _build_separator_inputs(self, texts, max_length: int, separator: str):
+        tok = self._get_tokenizer()
+        # Split into [instruction | text]; we embed only the trailing "text" part.
+        parts_after_sep = []
+        original = []
+        for t in texts:
+            parts = t.split(separator)
+            parts_after_sep.append(parts[1] if len(parts) > 1 else "")
+            original.append("".join(parts))
+
+        tokenized = tok(original, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
+        # Build an embed_mask that lights up only the trailing "text" span
+        embed_mask = None
+        for i, t in enumerate(parts_after_sep):
+            sub = tok([t], return_tensors="pt", padding=True, truncation=True, max_length=max_length, add_special_tokens=False)
+            m = torch.zeros_like(tokenized["attention_mask"][i])
+            if len(sub["input_ids"][0]) > 0:
+                m[-len(sub["input_ids"][0]):] = 1
+            embed_mask = m.unsqueeze(0) if embed_mask is None else torch.cat([embed_mask, m.unsqueeze(0)], dim=0)
+
+        tokenized["embed_mask"] = embed_mask
+        return tokenized
+
+    @torch.no_grad()
+    def encode_with_separator(self, texts, separator: str = "!@#$%^&*()", max_length: int = 512):
+        enc = self._build_separator_inputs(texts, max_length=max_length, separator=separator)
+        dev = next(self.parameters()).device
+        enc = {k: v.to(dev) for k, v in enc.items()}
+        return self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], embed_mask=enc["embed_mask"])
+
+    # --- One-liner cosine similarity over instruction+text ----------------
+    @torch.no_grad()
+    def compute_similarities(self, query_text: str, candidate_texts, separator: str = "!@#$%^&*()", max_length: int = 512):
+        all_texts = [query_text] + list(candidate_texts)
+        embs = self.encode_with_separator(all_texts, separator=separator, max_length=max_length)
+        # embs: [N, 2048]; compare query vs candidates
+        return F.cosine_similarity(embs[0], embs[1:], dim=1)
+
 
 # Register the model for auto loading
 from transformers import AutoModel
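
For reference, a minimal usage sketch of the helpers added in this commit. Everything outside the method names is an assumption rather than part of the diff: loading via AutoModel with trust_remote_code=True and bfloat16, the "lukeingawesome/llm2vec4cxr" repo id (which appears in the diff only as the tokenizer fallback), and the example report and option strings. Candidate texts are prefixed with the separator so that the span after it is what gets embedded by _build_separator_inputs.

# Usage sketch (assumptions, not part of the commit): loading options, repo id,
# and the example strings below are illustrative only.
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained(
    "lukeingawesome/llm2vec4cxr",      # same repo id as the _get_tokenizer fallback
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device).eval()                    # the overridden .to() also moves latent_attn

sep = "!@#$%^&*()"                     # default separator of encode_with_separator

# Each input is "<instruction><sep><text>"; only the span after the separator is
# embedded, so the candidates use an empty instruction before the separator.
query = (
    "Determine the change or the status of the pleural effusion." + sep +
    "There is a small increase in the left pleural effusion."
)
candidates = [sep + t for t in [
    "No pleural effusion.",
    "Pleural effusion is improving.",
    "Pleural effusion is worsening.",
]]

scores = model.compute_similarities(query, candidates)   # cosine similarities, one per candidate
print(scores)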