# Source: quantization/llm-awq/tinychat/utils/llava_image_processing.py
# (uploaded by chen459664 with the upload-large-folder tool, commit 02bc7b8)
# Modified from https://github.com/haotian-liu/LLaVA
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import os
import subprocess
import sys
from io import BytesIO

import requests
import torch
from PIL import Image
def load_image_from_base64(image):
    """Decode a base64 payload and open the resulting bytes with PIL.

    The image is returned in whatever mode the encoded file uses; unlike
    ``load_image``, no RGB conversion is applied here.
    """
    decoded = base64.b64decode(image)
    return Image.open(BytesIO(decoded))
def load_image(image_file, *, timeout=30):
    """Load an image from a URL or a local path and return it in RGB mode.

    Args:
        image_file: An ``http://`` / ``https://`` URL, or a filesystem path.
        timeout: Seconds passed to ``requests.get`` when ``image_file`` is a
            URL, so an unresponsive server cannot hang the caller forever.
            ``None`` disables the timeout.

    Returns:
        A new PIL image converted to ``"RGB"``.

    Raises:
        requests.HTTPError: If the URL responds with an error status code.
    """
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL,
        # which would otherwise surface as an opaque "cannot identify image".
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # convert() returns an independent image, so the opened file can be
    # closed deterministically when the with-block exits.
    with Image.open(source) as img:
        return img.convert("RGB")
def load_images(image_files):
    """Load each entry of *image_files* via ``load_image``, preserving order."""
    return [load_image(path) for path in image_files]
def vis_images(image_files):
if len(image_files) == 1:
image = image_files[0]
os.system(f"termvisage --query-timeout 1 {image} -H left --height 12")
else:
# Concat images
system_inst = "convert "
inst_template1 = " \\( {image} -background none -resize x500 \\) "
inst_template2 = " \\( {image} -background none -resize x500 -splice 100x0 \\) "
count = 0
for image in image_files:
count += 1
if count == 1:
system_inst += inst_template1.format(image=image)
else:
system_inst += inst_template2.format(image=image)
system_inst += " +append .vis.jpg"
os.system(system_inst)
os.system(f"termvisage --query-timeout 1 .vis.jpg -H left")
def expand2square(pil_img, background_color):
    """
    Copy from Llava codebase for image preprocessing.

    Pad ``pil_img`` to a square canvas whose side is its longer dimension,
    filled with ``background_color`` and with the original centred along the
    shorter axis. An already-square image is returned unchanged (same object).
    """
    w, h = pil_img.size
    if w == h:
        return pil_img
    side = max(w, h)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Centre vertically for landscape input, horizontally for portrait input.
    offset = (0, (w - h) // 2) if w > h else ((h - w) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas
def process_images(images, image_processor, model_cfg):
    """
    Copy from Llava codebase for image preprocessing.

    Convert a list of PIL images into pixel tensors for the vision tower.

    Args:
        images: Sequence of PIL images.
        image_processor: Processor exposing ``image_mean``,
            ``preprocess(img, return_tensors="pt")`` and
            ``__call__(images, return_tensors="pt")``. Both calls return a
            mapping with a ``"pixel_values"`` entry.
        model_cfg: Config object. Only its optional ``image_aspect_ratio``
            attribute is read.

    Returns:
        * When ``image_aspect_ratio == "pad"``: each image is padded to a
          square with the processor's mean colour and preprocessed
          individually. The per-image tensors are stacked along a new dim 0
          when they all share a shape. Otherwise, including for empty input,
          they are returned as a list.
        * Otherwise: the processor's batched ``"pixel_values"``.
        In both modes, processors whose class name contains "intern" get an
        extra leading dim added to every per-image tensor (upstream special
        case), and the non-pad result is then returned as a list.
    """
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    is_intern = "intern" in image_processor.__class__.__name__.lower()

    if image_aspect_ratio != "pad":
        ret = image_processor(images, return_tensors="pt")["pixel_values"]
        if is_intern:
            ret = [x.unsqueeze(0) for x in ret]
        return ret

    # image_mean is scaled by 255 to get an integer PIL fill colour;
    # computed once rather than per image.
    fill_color = tuple(int(x * 255) for x in image_processor.image_mean)
    new_images = []
    for image in images:
        image = expand2square(image, fill_color)
        pixels = image_processor.preprocess(image, return_tensors="pt")[
            "pixel_values"
        ][0]
        new_images.append(pixels.unsqueeze(0) if is_intern else pixels)

    # all() is vacuously True for an empty list, but torch.stack rejects an
    # empty sequence, so require at least one tensor before stacking.
    if new_images and all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images