# Modified from https://github.com/haotian-liu/LLaVA
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import os
import shlex
from io import BytesIO

import requests
import torch
from PIL import Image


def load_image_from_base64(image):
    """Decode a base64-encoded image and open it with PIL.

    Args:
        image: Base64 text (or bytes) holding the encoded image file.

    Returns:
        The opened ``PIL.Image.Image``. No mode conversion is applied.
    """
    return Image.open(BytesIO(base64.b64decode(image)))


def load_image(image_file, timeout=None):
    """Load an image from a local path or an HTTP(S) URL as RGB.

    Args:
        image_file: A filesystem path, or a URL starting with ``http://``
            or ``https://``.
        timeout: Optional timeout forwarded to ``requests.get`` for URL
            inputs. ``None`` (the default) leaves the request without a
            timeout.

    Returns:
        The image converted to ``"RGB"`` mode.

    Raises:
        requests.HTTPError: If the URL responds with an error status.
    """
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        # Fail on 4xx/5xx here rather than handing an error page to PIL,
        # which would otherwise surface as an unrelated decode failure.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")
    return Image.open(image_file).convert("RGB")


def load_images(image_files):
    """Load every entry of ``image_files`` with :func:`load_image`.

    Args:
        image_files: Iterable of paths and/or URLs.

    Returns:
        A list of RGB images in the same order as the input.
    """
    return [load_image(image_file) for image_file in image_files]


def vis_images(image_files):
    """Preview one or more image files in the terminal via ``termvisage``.

    A single image is displayed directly. Several images are first joined
    side by side into ``.vis.jpg`` in the current working directory using
    the ``convert`` command (ImageMagick-style options), each resized to a
    height of 500 with a 100-pixel splice before every image after the
    first, and the combined file is then displayed.

    Both external commands are run through ``os.system``; their exit
    status is not inspected.

    Args:
        image_files: Sized sequence of image file paths. An empty sequence
            is a no-op.
    """
    count = len(image_files)
    if count == 0:
        # Nothing to show; avoid running `convert` with no inputs.
        return

    if count == 1:
        # Quote the path so spaces/metacharacters cannot alter the command.
        image = shlex.quote(str(image_files[0]))
        os.system(f"termvisage --query-timeout 1 {image} -H left --height 12")
        return

    # Concatenate the images horizontally before displaying them.
    first_template = " \\( {image} -background none -resize x500 \\) "
    rest_template = (
        " \\( {image} -background none -resize x500 -splice 100x0 \\) "
    )
    parts = ["convert "]
    for index, image in enumerate(image_files):
        template = first_template if index == 0 else rest_template
        parts.append(template.format(image=shlex.quote(str(image))))
    parts.append(" +append .vis.jpg")

    os.system("".join(parts))
    os.system("termvisage --query-timeout 1 .vis.jpg -H left")


def expand2square(pil_img, background_color):
    """Pad an image to a square canvas, keeping it centered.

    Copied from the LLaVA codebase for image preprocessing.

    Args:
        pil_img: Source PIL image.
        background_color: Fill color for the new canvas, in a form accepted
            by ``Image.new`` for ``pil_img.mode``.

    Returns:
        ``pil_img`` itself when it is already square; otherwise a new image
        whose side equals the longer dimension, with the original pasted
        in the middle of the shorter axis.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    if width > height:
        offset = (0, (width - height) // 2)
    else:
        offset = ((height - width) // 2, 0)

    result = Image.new(pil_img.mode, (side, side), background_color)
    result.paste(pil_img, offset)
    return result


def process_images(images, image_processor, model_cfg):
    """Run images through ``image_processor`` according to ``model_cfg``.

    Copied from the LLaVA codebase for image preprocessing.

    When ``model_cfg.image_aspect_ratio`` is ``"pad"``, each image is
    squared with :func:`expand2square` (filled with the processor's
    ``image_mean`` scaled to 0-255) and preprocessed individually.
    Otherwise the whole batch is passed to the processor in one call.

    Processors whose class name contains "intern" (case-insensitive) get
    an extra leading dimension added to every per-image tensor.

    Args:
        images: Sequence of PIL images.
        image_processor: Callable processor exposing ``preprocess`` and
            ``image_mean``.
        model_cfg: Config object; only ``image_aspect_ratio`` is read and
            it may be absent.

    Returns:
        In "pad" mode: a single stacked tensor when all per-image tensors
        share a shape, else a list of tensors. In other modes: the
        processor's ``"pixel_values"`` output, or a list of unsqueezed
        tensors for "intern" processors.
    """
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    # Special case: "intern" processors expect an extra leading dimension.
    is_intern = "intern" in image_processor.__class__.__name__.lower()

    if image_aspect_ratio != "pad":
        ret = image_processor(images, return_tensors="pt")["pixel_values"]
        if is_intern:
            ret = [x.unsqueeze(0) for x in ret]
        return ret

    new_images = []
    if images:
        # The padding color is the same for every image; compute it once.
        background = tuple(int(x * 255) for x in image_processor.image_mean)
        for image in images:
            squared = expand2square(image, background)
            pixels = image_processor.preprocess(squared, return_tensors="pt")[
                "pixel_values"
            ][0]
            new_images.append(pixels.unsqueeze(0) if is_intern else pixels)

    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images