import copy
import os
from typing import List

import torch
from torchvision.transforms import transforms

from open_flamingo.eval.eval_model import BaseEvalModel
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle

class EvalModelLLAVA(BaseEvalModel):
    """LLaVA model evaluation.

    Attributes:
        model (nn.Module): Underlying Torch model.
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer for model.
        device: Index of GPU to use, or the string "CPU".
    """
    def __init__(self, model_args):
        super().__init__(model_args)
        disable_torch_init()
        model_path = os.path.expanduser(model_args["model_path"])
        model_name = get_model_name_from_path(model_path)
        self.model, self.image_processor, self.tokenizer, context_len = load_pretrained_model(
            model_path, model_args.get("model_base"), model_name,
            pretrained_rob_path=model_args["vision_encoder_pretrained"],
            dtype=model_args["precision"],
        )
        self.image_processor.do_normalize = False
        self.normalizer = transforms.Normalize(
            mean=self.image_processor.image_mean, std=self.image_processor.image_std
        )  # we need to normalize in the forward pass, so that the threat model is consistent
        model_args["temperature"] = float(model_args["temperature"])
        model_args["num_beams"] = int(model_args["num_beams"])
        self.model_args = model_args
        self.conv_mode = "vicuna_v1"

        if model_args["precision"] == "float16":
            self.cast_dtype = torch.float16
        elif model_args["precision"] == "float32":
            self.cast_dtype = torch.float32
        else:
            raise ValueError(f"Unknown dtype: {model_args['precision']}")

        self.dataset_name = model_args.get("dataset_name")
        self.stop_str = (
            conv_templates[self.conv_mode].sep
            if conv_templates[self.conv_mode].sep_style != SeparatorStyle.TWO
            else conv_templates[self.conv_mode].sep2
        )
        self.stop_token_id = self.tokenizer.convert_tokens_to_ids(self.stop_str)
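
    # Generation interface: `get_outputs` takes a list of conversation objects (built via
    # `get_vqa_prompt` / `get_caption_prompt`) and a batch of unnormalized images in [0, 1],
    # and returns the decoded model responses.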
    def get_outputs(
        self,
        batch_text,  # List[conv object]
        batch_images: torch.Tensor,
        min_generation_length: int,
        max_generation_length: int,
        **kwargs,
    ) -> List[str]:
        assert len(batch_text) == 1, "Only batch size 1 is supported for now"
        assert 0. <= batch_images.min() and batch_images.max() <= 1., "Images must be in image space"

        # prompt = batch_text.get_prompt()
        input_ids = self._prepare_text(batch_text)
        batch_images = self.normalizer(batch_images)

        output_ids = self.model.generate(
            input_ids,
            images=batch_images.to(dtype=self.cast_dtype, device='cuda', non_blocking=True),
            do_sample=True if self.model_args["temperature"] > 0 else False,
            temperature=self.model_args["temperature"],
            top_p=self.model_args.get("top_p"),
            num_beams=self.model_args["num_beams"],
            min_new_tokens=min_generation_length,
            max_new_tokens=max_generation_length,
            use_cache=False,
        )

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids")
        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(self.stop_str):
            outputs = outputs[:-len(self.stop_str)]
        outputs = outputs.strip()
        return [outputs]
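
    # Loss interface: `set_inputs` caches the tokenized prompt, labels and attention mask once,
    # and `__call__` then returns the language-modelling loss for a batch of *unnormalized*
    # images, so gradients can be taken w.r.t. pixel values (e.g. for adversarial evaluation).
    # Illustrative usage, assuming `eval_model` and `conv` already exist:
    #     eval_model.set_inputs([conv])
    #     loss = eval_model(images_unnorm.requires_grad_())  # images in [0, 1], shape [b, c, h, w]
    #     loss.sum().backward()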
    def __call__(self, images_unnorm):
        assert self.input_ids is not None
        assert self.attention_mask is not None
        assert self.labels is not None
        assert 0. <= images_unnorm.min() and images_unnorm.max() <= 1., "Images must be in image space"
        assert len(images_unnorm.shape) == 4, "[b, c, h, w]"
        out = self.model(
            input_ids=self.input_ids,
            attention_mask=self.attention_mask,
            past_key_values=self.past_key_values,
            inputs_embeds=None,
            labels=self.labels,
            images=self.normalizer(images_unnorm),
        )
        return out.loss.unsqueeze(0)
    def set_inputs(
        self,
        batch_text,
        past_key_values: torch.Tensor = None,
        to_device: bool = False,
    ):
        self.input_ids = self._prepare_text(batch_text)
        context_only = batch_text[0].get_prompt().split("ASSISTANT:")[0] + "ASSISTANT:"
        context_len = len(self.tokenizer.encode(context_only))
        labels = copy.deepcopy(self.input_ids)
        # mask out the context tokens, so the loss is computed only on the target answer
        labels[:, :context_len] = IGNORE_INDEX
        # labels[labels == self.stop_token_id] = IGNORE_INDEX
        # print(batch_text[0].get_prompt())
        # print(self.tokenizer.decode(labels[labels != IGNORE_INDEX]))
        self.labels = labels
        self.attention_mask = self.input_ids.ne(self.tokenizer.pad_token_id)
        self.past_key_values = past_key_values
    def _prepare_images(self, batch: List[List[torch.Tensor]]) -> torch.Tensor:
        assert len(batch) == 1, "Only batch size 1 is supported for now"
        image_tensor = process_images(batch[0], self.image_processor, self.model.config)
        return image_tensor

    def _prepare_text(self, convs):
        input_ids = [
            tokenizer_image_token(conv.get_prompt(), self.tokenizer, return_tensors='pt') for conv in convs
        ]
        input_ids = torch.stack(input_ids, dim=0).to(device='cuda', non_blocking=True)
        return input_ids
    def get_vqa_prompt(self, question, answer=None):
        if self.dataset_name == "vizwiz":
            self.prompt_suffix = "\nWhen the provided information is insufficient, respond with 'Unanswerable'.\nAnswer the question using a single word or phrase."
        elif self.dataset_name == "textvqa":
            self.prompt_suffix = "\nAnswer the question using a single word or phrase."
        elif self.dataset_name == "vqav2":
            self.prompt_suffix = "\nAnswer the question using a single word or phrase."
        else:
            # raise ValueError(f"Unknown dataset: {self.dataset_name}")
            self.prompt_suffix = ""
            print(f"Unknown dataset: {self.dataset_name}, using no prompt suffix.")

        qs = question + self.prompt_suffix
        if self.model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], answer)
        return conv
    def get_caption_prompt(self, caption=None):
        qs = "Provide a short caption for this image."
        if self.model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], caption)
        return conv
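

# Minimal usage sketch: the checkpoint id, vision-encoder weights and image size below are
# illustrative placeholders, not values prescribed by this module; adjust them to your setup.
if __name__ == "__main__":
    model_args = {
        "model_path": "liuhaotian/llava-v1.5-7b",   # example checkpoint
        "model_base": None,
        "vision_encoder_pretrained": "openai",      # or a path to (robust) CLIP vision-encoder weights
        "precision": "float16",
        "temperature": 0.0,
        "num_beams": 1,
        "dataset_name": "vqav2",
    }
    eval_model = EvalModelLLAVA(model_args)
    conv = eval_model.get_vqa_prompt("What is shown in the image?")
    images = torch.rand(1, 3, 336, 336)  # pre-processed but unnormalized images in [0, 1]
    print(eval_model.get_outputs([conv], images, min_generation_length=1, max_generation_length=20))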