| """ |
| Copyright (c) 2023, salesforce.com, inc. |
| All rights reserved. |
| SPDX-License-Identifier: BSD-3-Clause |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
| """ |
| import contextlib |
| import logging |
| import torch |
| import torch.distributed as dist |
| import torch.nn as nn |
| from torch.cuda.amp import autocast as autocast |
| from torch.nn import functional as F |
|
|
| from lavis.models.blip2_models.blip2 import ( |
| disabled_train, |
| ) |
| from lavis.models.blip_models.blip_outputs import BlipOutput |
| from model.blip2 import Blip2Base |
| from model.help_funcs import pad_and_concat |
| from model.dist_funs import pl_concat_all_gather |
|
|
|
|
|
|
class Blip2Qformer(Blip2Base):
    """
    First-stage Q-Former model that aligns a protein language model (PLM) with text.

    A fixed set of learnable query tokens cross-attends to the PLM token
    embeddings through the Q-Former. Up to three objectives are summed in
    ``forward``:

    - protein-text contrastive loss (always computed),
    - protein-text matching loss with sampled hard negatives (when ``ptm`` is truthy),
    - language-modelling loss on the text conditioned on the queries (when ``lm`` is truthy).

    Args:
        ptm: enables the protein-text matching loss when truthy.
        lm: enables the language-modelling loss when truthy.
        bert_name: forwarded to ``init_Qformer``.
        plm_name: forwarded to ``init_protein_encoder``.
        temperature: divisor applied to the contrastive similarities.
        plm_tune: ``'freeze'`` disables gradients for the PLM, puts it in eval
            mode and detaches its outputs; any other value leaves the PLM as
            returned by ``init_protein_encoder``.
        num_query_token: forwarded to ``init_Qformer``.
        cross_attention_freq: forwarded to ``init_Qformer``.
        embed_dim: output size of the protein and text projection heads.
        pool_size: stored on the instance; not used inside this class.
        load_4bit: forwarded to ``init_protein_encoder``.
    """

    def __init__(
        self,
        ptm,
        lm,
        bert_name,
        plm_name,
        temperature,
        plm_tune='freeze',
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=256,
        pool_size=0,
        load_4bit=False,
    ):
        super().__init__()
        self.ptm = ptm
        self.lm = lm
        self.pool_size = pool_size

        self.plm_tokenizer, self.plm, self.ln_layer = self.init_protein_encoder(plm_name, load_4bit)
        self.plm_tune = plm_tune
        if plm_tune == 'freeze':
            for param in self.plm.parameters():
                param.requires_grad = False
            self.plm = self.plm.eval()
            # Replace ``train`` so later ``model.train()`` calls do not go
            # through the PLM's own ``train`` method.
            self.plm.train = disabled_train
            logging.info("freeze plm")

        self.tokenizer, self.Qformer, self.query_tokens = self.init_Qformer(
            bert_name, num_query_token, self.plm.num_features, cross_attention_freq
        )
        self.Qformer.resize_token_embeddings(len(self.tokenizer))
        # Initialise every parameter whose name contains "_query" from the
        # parameter of the same name with "_query" removed.
        state_dict = self.Qformer.state_dict()
        for name, param in self.Qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

        self.prot_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
        self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
        self.ptm_head = nn.Linear(self.Qformer.config.hidden_size, 2)
        self.temperature = temperature

    @staticmethod
    def _dist_rank():
        """Return the distributed rank, or 0 when no process group is initialised."""
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    def _encode_protein(self, prot_batch):
        """Run the PLM, the layer norm and the query-only Q-Former pass.

        Returns:
            ``(prot_embeds, query_tokens, query_output)`` where ``prot_embeds``
            is the layer-normed PLM output, ``query_tokens`` is the query
            parameter expanded to the batch, and ``query_output`` is the
            Q-Former output (computed with ``use_cache=True``).
        """
        prot_embeds = self.plm(**prot_batch, return_dict=True).last_hidden_state
        if self.plm_tune == 'freeze':
            prot_embeds = prot_embeds.detach()
        prot_embeds = self.ln_layer(prot_embeds)

        query_tokens = self.query_tokens.expand(prot_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=prot_embeds,
            encoder_attention_mask=prot_batch.attention_mask,
            use_cache=True,
            return_dict=True,
        )
        return prot_embeds, query_tokens, query_output

    def _ptm_logits(self, prot_embeds, prot_mask, text_ids, text_mask):
        """Return the two-way matching logits, averaged over the query positions."""
        query_tokens = self.query_tokens.expand(prot_embeds.size(0), -1, -1)
        query_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=prot_embeds.device)
        attention_mask = torch.cat([query_mask, text_mask], dim=1)
        output_ptm = self.Qformer.bert(
            text_ids,
            query_embeds=query_tokens,
            attention_mask=attention_mask,
            encoder_hidden_states=prot_embeds,
            encoder_attention_mask=prot_mask,
            return_dict=True,
        )
        # Only the leading query positions feed the matching head.
        pl_embeddings = output_ptm.last_hidden_state[:, : query_tokens.size(1), :]
        return self.ptm_head(pl_embeddings).mean(dim=1)

    def contrast_global(self, features_graph, features_text, features_graph_all, features_text_all, return_sim=False):
        '''
        Contrastive loss between local features and features gathered from all ranks.

        features_graph: shape = [B, num_qs, D]
        features_text: shape = [B, D]
        features_text_all: shape = [B * num_gpus, D]
        features_graph_all: shape = [B * num_gpus, num_qs, D]

        Returns the scalar loss, or (logits_per_graph, logits_per_text, loss)
        when return_sim is True; both logits have shape [B, B * num_gpus].
        '''
        bs = features_graph.size(0)

        # [B, 1, num_qs, D] @ [B_all, D, 1] -> [B, B_all, num_qs, 1].
        # Squeeze only the trailing axis so size-1 batch/query axes survive.
        sim_q2t = (features_graph.unsqueeze(1) @ features_text_all.unsqueeze(-1)).squeeze(-1)
        sim_g2t, _ = sim_q2t.max(-1)
        logits_per_graph = sim_g2t / self.temperature

        # [B, 1, 1, D] @ [B_all, D, num_qs] -> [B, B_all, 1, num_qs].
        sim_t2q = (features_text.unsqueeze(1).unsqueeze(1) @ features_graph_all.permute(0, 2, 1)).squeeze(-2)
        sim_t2g, _ = sim_t2q.max(-1)
        logits_per_text = sim_t2g / self.temperature

        # The positives of this rank sit at columns [rank * bs, rank * bs + bs).
        rank = self._dist_rank()
        labels = torch.arange(
            rank * bs, rank * bs + bs, dtype=torch.long, device=logits_per_graph.device
        )

        loss_graph = F.cross_entropy(logits_per_graph, labels)
        loss_text = F.cross_entropy(logits_per_text, labels)
        loss = (loss_graph + loss_text) / 2

        if return_sim:
            return logits_per_graph, logits_per_text, loss
        return loss

    def forward(self, batch):
        """Compute the training losses for one ``(prot_batch, text_batch)`` pair.

        ``prot_batch`` is unpacked with ``**`` into the PLM and must expose
        ``attention_mask``; ``text_batch`` is unpacked into the Q-Former text
        encoder and must expose ``input_ids`` and ``attention_mask``
        (presumably tokenizer outputs -- confirm against the data collator).

        Returns:
            BlipOutput with ``loss`` (sum of the three terms), ``loss_itc``
            (contrastive), ``loss_itm`` (matching, 0 when disabled) and
            ``loss_lm`` (language modelling, 0 when disabled).
        """
        prot_batch, text_batch = batch

        prot_embeds, query_tokens, query_output = self._encode_protein(prot_batch)
        batch_size = prot_embeds.shape[0]
        device = prot_embeds.device

        prot_feats = F.normalize(self.prot_proj(query_output.last_hidden_state), p=2, dim=-1)
        text_feats = self.text_forward(text_batch)

        # ---- protein-text contrastive loss ----
        text_feats_all, prot_feats_all = pl_concat_all_gather(text_feats), pl_concat_all_gather(prot_feats)
        sim_p2t, sim_t2p, loss_ptc = self.contrast_global(
            prot_feats, text_feats, prot_feats_all, text_feats_all, return_sim=True
        )

        # ---- protein-text matching loss ----
        loss_ptm = 0
        if self.ptm:
            prot_embeds_world = pl_concat_all_gather(prot_embeds)
            prot_mask_world = pl_concat_all_gather(prot_batch.attention_mask)
            text_ids_world = pl_concat_all_gather(text_batch.input_ids)
            text_mask_world = pl_concat_all_gather(text_batch.attention_mask)

            with torch.no_grad():
                rank = self._dist_rank()
                local = slice(rank * batch_size, rank * batch_size + batch_size)
                # The small constant keeps every candidate samplable; the
                # diagonal of the local slice is the true pair and is excluded.
                weights_t2p = F.softmax(sim_t2p, dim=1) + 1e-4
                weights_t2p[:, local].fill_diagonal_(0)
                weights_p2t = F.softmax(sim_p2t, dim=1) + 1e-4
                weights_p2t[:, local].fill_diagonal_(0)

                # One draw per row, without a per-sample host/device sync.
                neg_prot_idx = torch.multinomial(weights_t2p, 1).squeeze(1)
                neg_text_idx = torch.multinomial(weights_p2t, 1).squeeze(1)

            # A negative protein for each text ...
            prot_embeds_neg = prot_embeds_world[neg_prot_idx]
            prot_mask_neg = prot_mask_world[neg_prot_idx]
            # ... and a negative text for each protein.
            text_ids_neg = text_ids_world[neg_text_idx]
            text_mask_neg = text_mask_world[neg_text_idx]

            # Layout: [positive pairs | (text, negative protein) | (negative text, protein)].
            text_ids_all = torch.cat(
                [text_batch.input_ids, text_batch.input_ids, text_ids_neg], dim=0
            )
            text_mask_all = torch.cat(
                [text_batch.attention_mask, text_batch.attention_mask, text_mask_neg], dim=0,
            )
            prot_embeds_all = torch.cat([prot_embeds, prot_embeds_neg, prot_embeds], dim=0)
            prot_mask_all = torch.cat(
                [prot_batch.attention_mask, prot_mask_neg, prot_batch.attention_mask], dim=0
            )

            logits = self._ptm_logits(prot_embeds_all, prot_mask_all, text_ids_all, text_mask_all)

            ptm_labels = torch.cat(
                [
                    torch.ones(batch_size, dtype=torch.long, device=device),
                    torch.zeros(2 * batch_size, dtype=torch.long, device=device),
                ],
                dim=0,
            )
            loss_ptm = F.cross_entropy(logits, ptm_labels)

        # ---- language-modelling loss ----
        loss_lm = 0
        if self.lm:
            # Autocast is switched on only when the cached keys are fp16.
            # NOTE(review): the autocast target dtype is float32 -- confirm this
            # behaves as intended on the PyTorch version in use.
            enable_autocast = query_output.past_key_values[0][0].dtype == torch.float16
            with torch.cuda.amp.autocast(enable_autocast, dtype=torch.float32):
                decoder_input_ids = text_batch.input_ids.clone()
                decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
                # Padding positions are set to -100 in the labels.
                labels = decoder_input_ids.masked_fill(
                    decoder_input_ids == self.tokenizer.pad_token_id, -100
                )
                query_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=device)
                attention_mask = torch.cat([query_mask, text_batch.attention_mask], dim=1)
                lm_output = self.Qformer(
                    decoder_input_ids,
                    attention_mask=attention_mask,
                    past_key_values=query_output.past_key_values,
                    return_dict=True,
                    labels=labels,
                )
                loss_lm = lm_output.loss

        return BlipOutput(
            loss=loss_ptc + loss_ptm + loss_lm,
            loss_itc=loss_ptc,
            loss_itm=loss_ptm,
            loss_lm=loss_lm,
        )

    def prot_forward(self, prot_batch):
        """Encode proteins.

        Returns:
            ``(prot_feats, prot_embeds)``: the L2-normalised projected query
            outputs and the layer-normed PLM embeddings.
        """
        prot_embeds, _, query_output = self._encode_protein(prot_batch)
        prot_feats = self.prot_proj(query_output.last_hidden_state)
        prot_feats = F.normalize(prot_feats, dim=-1, p=2)
        return prot_feats, prot_embeds

    def text_forward(self, text_batch):
        """Return the L2-normalised projection of the first text token's hidden state."""
        text_output = self.Qformer.bert(**text_batch, return_dict=True)
        text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :])
        text_feats = F.normalize(text_feats, dim=-1, p=2)
        return text_feats

    def compute_ptm(self, prot_embeds, prot_mask, text_ids, text_mask):
        """Return, per pair, the matching-head logit at class index 1.

        Class index 1 is the label ``forward`` assigns to positive pairs.
        """
        ptm_logit = self._ptm_logits(prot_embeds, prot_mask, text_ids, text_mask)
        return ptm_logit[:, 1]
|
|
|
|
|
|
|
|