from typing import List

from PIL import Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration

from open_flamingo.eval.eval_model import BaseEvalModel
from open_flamingo.eval.models.utils import unwrap_model


class EvalModel(BaseEvalModel):
    """BLIP-2 model evaluation.

    Attributes:
        model (nn.Module): Underlying Torch model.
        processor (transformers.Blip2Processor): Image processor and tokenizer
            for the model.
        device: Index of GPU to use, or the string "cpu".
    """

    def __init__(self, model_args):
        assert (
            "processor_path" in model_args
            and "lm_path" in model_args
            and "device" in model_args
        ), "BLIP-2 requires processor_path, lm_path, and device arguments to be specified"

        # Use the requested GPU index, or fall back to CPU for negative indices.
        self.device = (
            int(model_args["device"])
            if ("device" in model_args and model_args["device"] >= 0)
            else "cpu"
        )
        self.processor = Blip2Processor.from_pretrained(model_args["processor_path"])
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            model_args["lm_path"]
        )
        self.model.to(self.device)
        self.model.eval()
        # Left-pad so generation starts immediately after the last prompt token.
        self.processor.tokenizer.padding_side = "left"

    def _prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
        """Preprocess images and stack them.

        Args:
            batch: A list of lists of images.

        Returns:
            A Tensor of shape (batch_size, channels, height, width).
        """
        assert all(
            len(example) == 1 for example in batch
        ), "BLIP-2 only supports one image per example"

        # Preprocess each example's single image and stack into one batch tensor.
        batch_images = torch.cat(
            [
                self.processor.image_processor(example, return_tensors="pt")[
                    "pixel_values"
                ]
                for example in batch
            ],
            dim=0,
        )
        return batch_images

    def get_outputs(
        self,
        batch_text: List[str],
        batch_images: List[List[Image.Image]],
        max_generation_length: int,
        num_beams: int,
        length_penalty: float,
    ) -> List[str]:
        encodings = self.processor.tokenizer(
            batch_text,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        input_ids = encodings["input_ids"]
        attention_mask = encodings["attention_mask"]

        with torch.inference_mode():
            outputs = unwrap_model(self.model).generate(
                self._prepare_images(batch_images).to(self.device),
                input_ids.to(self.device),
                attention_mask=attention_mask.to(self.device),
                max_new_tokens=max_generation_length,
                min_new_tokens=8,
                num_beams=num_beams,
                length_penalty=length_penalty,
            )

        return self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def get_vqa_prompt(self, question, answer=None) -> str:
        return (
            f"Question:{question} Short answer:{answer if answer is not None else ''}"
        )

    def get_caption_prompt(self, caption=None) -> str:
        return f"A photo of {caption if caption is not None else ''}"

    def get_classification_prompt(self, class_str=None) -> str:
        raise NotImplementedError
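

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream module): shows how an evaluation
# harness might drive EvalModel for single-image captioning. The checkpoint
# name "Salesforce/blip2-opt-2.7b" and the image path "example.jpg" are
# placeholders, and the generation parameters are illustrative defaults,
# not tuned values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_args = {
        "processor_path": "Salesforce/blip2-opt-2.7b",  # placeholder checkpoint
        "lm_path": "Salesforce/blip2-opt-2.7b",  # placeholder checkpoint
        "device": 0,  # GPU index; use -1 to fall back to CPU
    }
    eval_model = EvalModel(model_args)

    # Caption a single image using the model's captioning prompt template.
    image = Image.open("example.jpg").convert("RGB")  # placeholder image path
    captions = eval_model.get_outputs(
        batch_text=[eval_model.get_caption_prompt()],
        batch_images=[[image]],
        max_generation_length=20,
        num_beams=3,
        length_penalty=0.0,
    )
    print(captions)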