import copy
import os
from typing import List

import torch
from torchvision.transforms import transforms

from open_flamingo.eval.eval_model import BaseEvalModel
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle


class EvalModelLLAVA(BaseEvalModel):
    """LLaVA model evaluation.

    Attributes:
        model (nn.Module): Underlying Torch model.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer for the model.
        device: Index of GPU to use, or the string "cpu".
    """

    def __init__(self, model_args):
        super().__init__(model_args)
        disable_torch_init()

        model_path = os.path.expanduser(model_args["model_path"])
        model_name = get_model_name_from_path(model_path)
        self.tokenizer, self.model, self.image_processor, context_len = load_pretrained_model(
            model_path, model_args.get("model_base"), model_name,
            pretrained_rob_path=model_args["vision_encoder_pretrained"],
            dtype=model_args["precision"],
        )

        # We need to normalize in the forward pass, so that the threat model is
        # consistent: perturbations are applied to images in [0, 1] space.
        self.image_processor.do_normalize = False
        self.normalizer = transforms.Normalize(
            mean=self.image_processor.image_mean, std=self.image_processor.image_std
        )

        model_args["temperature"] = float(model_args["temperature"])
        model_args["num_beams"] = int(model_args["num_beams"])
        self.model_args = model_args
        self.conv_mode = "vicuna_v1"

        if model_args["precision"] == "float16":
            self.cast_dtype = torch.float16
        elif model_args["precision"] == "float32":
            self.cast_dtype = torch.float32
        else:
            raise ValueError(f"Unknown dtype: {model_args['precision']}")

        self.dataset_name = model_args.get("dataset_name")
        self.stop_str = (
            conv_templates[self.conv_mode].sep
            if conv_templates[self.conv_mode].sep_style != SeparatorStyle.TWO
            else conv_templates[self.conv_mode].sep2
        )
        self.stop_token_id = self.tokenizer.convert_tokens_to_ids(self.stop_str)

    @torch.no_grad()
    def get_outputs(
        self,
        batch_text,  # List of conversation objects
        batch_images: torch.Tensor,
        min_generation_length: int,
        max_generation_length: int,
        **kwargs,
    ) -> List[str]:
        assert len(batch_text) == 1, "Only batch size 1 is supported (yet)"
        assert 0. <= batch_images.min() and batch_images.max() <= 1., "Images must be in [0, 1] image space"
        # prompt = batch_text.get_prompt()
        input_ids = self._prepare_text(batch_text)
        batch_images = self.normalizer(batch_images)

        output_ids = self.model.generate(
            input_ids,
            images=batch_images.to(dtype=self.cast_dtype, device='cuda', non_blocking=True),
            do_sample=True if self.model_args["temperature"] > 0 else False,
            temperature=self.model_args["temperature"],
            top_p=self.model_args.get("top_p"),
            num_beams=self.model_args["num_beams"],
            min_new_tokens=min_generation_length,
            max_new_tokens=max_generation_length,
            use_cache=False,
        )

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids")
        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(self.stop_str):
            outputs = outputs[:-len(self.stop_str)]
        outputs = outputs.strip()
        return [outputs]

    def __call__(self, images_unnorm):
        # Compute the loss on the (unnormalized) images; the text inputs and
        # labels must have been fixed beforehand via set_inputs().
        assert self.input_ids is not None
        assert self.attention_mask is not None
        assert self.labels is not None
        assert 0. <= images_unnorm.min() and images_unnorm.max() <= 1., "Images must be in [0, 1] image space"
        assert len(images_unnorm.shape) == 4, "[b, c, h, w]"

        out = self.model(
            input_ids=self.input_ids,
            attention_mask=self.attention_mask,
            past_key_values=self.past_key_values,
            inputs_embeds=None,
            labels=self.labels,
            images=self.normalizer(images_unnorm),
        )
        return out.loss.unsqueeze(0)

    def set_inputs(
        self,
        batch_text,
        past_key_values: torch.Tensor = None,
        to_device: bool = False,
    ):
        self.input_ids = self._prepare_text(batch_text)
        # Mask out the context (everything up to and including "ASSISTANT:") so
        # that the loss is computed only on the answer tokens.
        context_only = batch_text[0].get_prompt().split("ASSISTANT:")[0] + "ASSISTANT:"
        context_len = len(self.tokenizer.encode(context_only))
        labels = copy.deepcopy(self.input_ids)
        labels[:, :context_len] = IGNORE_INDEX
        # labels[labels == self.stop_token_id] = IGNORE_INDEX
        # print(batch_text[0].get_prompt())
        # print(self.tokenizer.decode(labels[labels != IGNORE_INDEX]))
        self.labels = labels
        self.attention_mask = self.input_ids.ne(self.tokenizer.pad_token_id)
        self.past_key_values = past_key_values

    def _prepare_images(self, batch: List[List[torch.Tensor]]) -> torch.Tensor:
        assert len(batch) == 1, "Only batch size 1 is supported (yet)"
        image_tensor = process_images(batch[0], self.image_processor, self.model.config)
        return image_tensor

    def _prepare_text(self, convs):
        input_ids = [
            tokenizer_image_token(conv.get_prompt(), self.tokenizer, return_tensors='pt')
            for conv in convs
        ]
        input_ids = torch.stack(input_ids, dim=0).to(device='cuda', non_blocking=True)
        return input_ids

    def get_vqa_prompt(self, question, answer=None) -> str:
        if self.dataset_name == "vizwiz":
            self.prompt_suffix = "\nWhen the provided information is insufficient, respond with 'Unanswerable'.\nAnswer the question using a single word or phrase."
        elif self.dataset_name == "textvqa":
            self.prompt_suffix = "\nAnswer the question using a single word or phrase."
        elif self.dataset_name == "vqav2":
            self.prompt_suffix = "\nAnswer the question using a single word or phrase."
        else:
            self.prompt_suffix = ""
            print(f"Unknown dataset: {self.dataset_name}, using no prompt suffix.")

        qs = question + self.prompt_suffix
        if self.model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], answer)
        return conv

    def get_caption_prompt(self, caption=None) -> str:
        qs = "Provide a short caption for this image."
        if self.model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], caption)
        return conv
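

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the evaluation harness).
# It shows how this wrapper could be driven for generation and for an image-space
# loss evaluation. The checkpoint name, the robust vision-encoder path ("openai"),
# the image file, and the question/answer strings below are placeholder
# assumptions; the real pipeline constructs model_args and batches elsewhere.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    model_args = {
        "model_path": "liuhaotian/llava-v1.5-7b",  # placeholder checkpoint
        "model_base": None,
        "vision_encoder_pretrained": "openai",     # placeholder robust-CLIP path
        "precision": "float16",
        "temperature": 0.0,
        "num_beams": 1,
        "dataset_name": "vqav2",
    }
    eval_model = EvalModelLLAVA(model_args)

    # Images stay unnormalized in [0, 1]; normalization happens inside the model calls.
    image = Image.open("example.jpg").convert("RGB")  # placeholder image
    batch_images = eval_model._prepare_images([[image]])  # [1, c, h, w]

    # Generation on a VQA-style prompt.
    conv = eval_model.get_vqa_prompt("What is shown in the picture?")
    print(eval_model.get_outputs([conv], batch_images, min_generation_length=1, max_generation_length=20))

    # Loss on a ground-truth answer, e.g. as the objective of a gradient-based attack.
    conv_gt = eval_model.get_vqa_prompt("What is shown in the picture?", answer="a cat")
    eval_model.set_inputs([conv_gt])
    images_adv = batch_images.to(device="cuda", dtype=eval_model.cast_dtype).requires_grad_()
    loss = eval_model(images_adv)
    loss.sum().backward()
    print(loss.item(), images_adv.grad.abs().max().item())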