import os
from typing import List

import torch
from torch import nn


class ClipTextEncoder(nn.Module):
    def __init__(
        self,
        modelpath: str = "openai/clip-vit-large-patch14",  # or "openai/clip-vit-base-patch32"
        finetune: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()

        from transformers import AutoModel, AutoTokenizer, logging

        logging.set_verbosity_error()

        # Tokenizer
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
        self.text_model = AutoModel.from_pretrained(modelpath)

        # Freeze the text model unless finetuning is requested
        if not finetune:
            self.text_model.eval()
            for p in self.text_model.parameters():
                p.requires_grad = False

        # Model configuration
        self.max_length = self.tokenizer.model_max_length
        self.text_encoded_dim = self.text_model.config.text_config.hidden_size

    def forward(self, texts: List[str]):
        # Tokenize the prompts, padding/truncating to CLIP's maximum context length
        text_inputs = self.tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_model.device)
        txt_att_mask = text_inputs.attention_mask.to(self.text_model.device)

        # Defensive: clamp to the maximum sequence length CLIP can handle
        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]

        # Per-token embeddings from CLIP's inner text transformer.
        # (get_text_features would return only the pooled embedding; calling
        # text_model directly keeps the full sequence.)
        # Shape: (batch_size, seq_length, text_encoded_dim)
        text_embeddings = self.text_model.text_model(
            text_input_ids,
            # attention_mask=txt_att_mask
        ).last_hidden_state

        return text_embeddings, txt_att_mask.bool()
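

# Minimal usage sketch (an assumption for illustration, not part of the original
# module): encode a small batch of prompts with the frozen CLIP text encoder and
# inspect the shapes of the per-token embeddings and the boolean attention mask.
if __name__ == "__main__":
    encoder = ClipTextEncoder()  # downloads openai/clip-vit-large-patch14 on first use
    with torch.no_grad():
        embeddings, mask = encoder(["a person walks forward", "a person jumps"])
    # Expected: embeddings of shape (2, 77, 768) and mask of shape (2, 77), since
    # CLIP ViT-L/14 uses a 77-token context and 768-dim text hidden states.
    print(embeddings.shape, mask.shape)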