# coding=utf-8
"""Image processor class for AndesVL with token-based sizing."""
import math
from typing import Optional, Union
import numpy as np
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
convert_to_rgb,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_flat_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
)
from transformers.utils import TensorType, is_vision_available, logging
logger = logging.get_logger(__name__)
if is_vision_available():
from PIL import Image
class SigLIP2NaViTImageProcessor(BaseImageProcessor):
r"""
Constructs an AndesVL image processor with token-based compute budget.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image based on token budget.
patch_size (`int`, *optional*, defaults to `14`):
The patch size used by the vision encoder.
max_tokens (`int`, *optional*, defaults to `1024`):
Maximum number of vision tokens (controls compute budget).
        min_tokens (`int`, *optional*, defaults to `4`):
            Minimum number of vision tokens (ensures sufficient visual information).
        merge_size (`int`, *optional*, defaults to `2`):
            Spatial merge factor: each `merge_size x merge_size` block of patches is
            merged into a single vision token downstream.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Resampling filter to use if resizing the image.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image.
image_mean (`float` or `list[float]`, *optional*, defaults to CLIP mean):
Mean to use if normalizing the image.
image_std (`float` or `list[float]`, *optional*, defaults to CLIP std):
Standard deviation to use if normalizing the image.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image to align it to the base grid. If `False`, the
            image is stretched to the grid instead.
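
    Example (illustrative; these are the default values, not requirements):

        >>> processor = SigLIP2NaViTImageProcessor(patch_size=14, merge_size=2)
        >>> processor.base  # pixel size of one merged-token cell
        28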
"""
    model_input_names = ["pixel_values", "image_grid_hw"]
def __init__(
self,
do_resize: bool = True,
patch_size: int = 14,
max_tokens: int = 1024,
min_tokens: int = 4,
merge_size: int = 2,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
do_convert_rgb: bool = True,
do_pad: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.do_resize = do_resize
self.patch_size = patch_size
self.max_tokens = max_tokens
self.min_tokens = min_tokens
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_convert_rgb = do_convert_rgb
self.do_pad = do_pad
self.merge_size = merge_size
self._valid_processor_keys = [
"images",
"do_resize",
"patch_size",
"max_tokens",
"min_tokens",
"resample",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_convert_rgb",
"do_pad",
"return_tensors",
"data_format",
"input_data_format",
]
@property
def base(self) -> int:
"""Base grid size (patch_size * 2 for 2x2 token merging)."""
return self.patch_size * self.merge_size
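
    # For the defaults (patch_size=14, merge_size=2) the base cell is 28 px: e.g. a
    # 560x420 image maps to a 20x15 grid of merged tokens.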
def _compute_target_size(
self,
width: int,
height: int,
max_tokens: int,
min_tokens: int,
) -> tuple[tuple[int, int], tuple[int, int]]:
"""
Compute target image size and bounding box size based on token budget.
Returns:
(new_image_size, bounding_box_size): Both as (width, height) tuples
"""
base = self.base
token_size = base * base
max_area = token_size * max_tokens
min_area = token_size * min_tokens
area = width * height
# Determine target area based on constraints
if area > max_area:
target_area = max_area
elif area < min_area:
target_area = min_area
else:
target_area = area
# Compute bounding box from target area while preserving aspect ratio
aspect = width / height
box_w = math.sqrt(target_area * aspect)
box_h = math.sqrt(target_area / aspect)
# Clamp to [base, inf) and align to base multiple
box_w = max(int(box_w), base)
box_h = max(int(box_h), base)
box_w = (box_w + base - 1) // base * base
box_h = (box_h + base - 1) // base * base
# Compute final image size within bounding box
scale = min(box_w / width, box_h / height)
new_w, new_h = round(width * scale), round(height * scale)
return (new_w, new_h), (box_w, box_h)
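
    # Worked example with the defaults (base=28, token_size=784,
    # max_area=784*1024=802816) for a 4000x3000 input:
    #   area=12_000_000 > max_area, so target_area=802816
    #   box_w=int(sqrt(802816*4/3))=1034 -> aligned up to 1036
    #   box_h=int(sqrt(802816*3/4))=775  -> aligned up to 784
    #   scale=min(1036/4000, 784/3000)=0.259 -> image (1036, 777) in a (1036, 784) box
    # Note that aligning the box up to base multiples can land slightly above the
    # nominal budget (here 37*28=1036 merged tokens vs. max_tokens=1024).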
def resize_and_pad(
self,
image: np.ndarray,
max_tokens: int,
min_tokens: int,
do_pad: bool = True,
resample: PILImageResampling = PILImageResampling.BICUBIC,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""
Resize image based on token budget, optionally padding to grid alignment.
Args:
image: Input image as numpy array
max_tokens: Maximum vision tokens
min_tokens: Minimum vision tokens
do_pad: If True, pad to bounding box; if False, stretch to bounding box
resample: Resampling filter
input_data_format: Channel dimension format of input
Returns:
Processed image as numpy array
"""
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
# Get current size
if input_data_format == ChannelDimension.FIRST:
_, height, width = image.shape
else:
height, width, _ = image.shape
# Compute target sizes
(new_w, new_h), (box_w, box_h) = self._compute_target_size(
width, height, max_tokens, min_tokens
)
# Fast path: no change needed
if new_w == width and new_h == height and box_w == width and box_h == height:
return image
# Force resize (stretch) to bounding box
if not do_pad:
return resize(
image,
size=(box_h, box_w), # resize expects (height, width)
resample=resample,
input_data_format=input_data_format,
)
# Resize to new size
resized = resize(
image,
size=(new_h, new_w),
resample=resample,
input_data_format=input_data_format,
)
# No padding needed
if new_w == box_w and new_h == box_h:
return resized
# Create padded canvas with mean color as background
# Check if image is in 0-1 range or 0-255 range
is_scaled = is_scaled_image(resized)
if input_data_format == ChannelDimension.FIRST:
num_channels = resized.shape[0]
canvas = np.zeros((num_channels, box_h, box_w), dtype=resized.dtype)
# Fill with mean values (scaled appropriately based on image range)
for c in range(num_channels):
canvas[c, :, :] = (
self.image_mean[c] if is_scaled else self.image_mean[c] * 255
)
# Paste centered
pad_top = (box_h - new_h) // 2
pad_left = (box_w - new_w) // 2
canvas[:, pad_top : pad_top + new_h, pad_left : pad_left + new_w] = resized
else:
num_channels = resized.shape[-1]
canvas = np.zeros((box_h, box_w, num_channels), dtype=resized.dtype)
# Fill with mean values (scaled appropriately based on image range)
for c in range(num_channels):
canvas[:, :, c] = (
self.image_mean[c] if is_scaled else self.image_mean[c] * 255
)
# Paste centered
pad_top = (box_h - new_h) // 2
pad_left = (box_w - new_w) // 2
canvas[pad_top : pad_top + new_h, pad_left : pad_left + new_w, :] = resized
return canvas
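
    # Padding example (illustrative budget): a 640x480 input with max_tokens=256
    # resizes to 523x392 inside a 532x392 box, so the image is pasted at
    # pad_left=(532-523)//2=4 and the margins are filled with the channel means
    # (scaled by 255 when the image is still in the 0-255 range).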
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
resample: Optional[PILImageResampling] = None,
do_pad: Optional[bool] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
do_convert_rgb: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> BatchFeature:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images.
do_resize (`bool`, *optional*):
Whether to resize the image based on token budget.
max_tokens (`int`, *optional*):
Maximum number of vision tokens.
min_tokens (`int`, *optional*):
Minimum number of vision tokens.
resample (`PILImageResampling`, *optional*):
Resampling filter to use if resizing.
do_pad (`bool`, *optional*):
Whether to pad (True) or stretch (False) to grid alignment.
do_rescale (`bool`, *optional*):
Whether to rescale the image.
rescale_factor (`float`, *optional*):
Rescale factor.
do_normalize (`bool`, *optional*):
Whether to normalize the image.
image_mean (`float` or `list[float]`, *optional*):
Image mean for normalization.
image_std (`float` or `list[float]`, *optional*):
Image std for normalization.
do_convert_rgb (`bool`, *optional*):
Whether to convert to RGB.
return_tensors (`str` or `TensorType`, *optional*):
Type of tensors to return.
data_format (`ChannelDimension`, *optional*):
Output channel dimension format.
input_data_format (`ChannelDimension`, *optional*):
Input channel dimension format.
        Returns:
            `BatchFeature` with `pixel_values` (patch tokens of all images packed along
            the first axis) and `image_grid_hw` (per-image patch grid `(grid_h, grid_w)`).
"""
# Use instance defaults if not specified
do_resize = do_resize if do_resize is not None else self.do_resize
max_tokens = max_tokens if max_tokens is not None else self.max_tokens
min_tokens = min_tokens if min_tokens is not None else self.min_tokens
resample = resample if resample is not None else self.resample
do_pad = do_pad if do_pad is not None else self.do_pad
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = (
rescale_factor if rescale_factor is not None else self.rescale_factor
)
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = (
do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
)
validate_kwargs(
captured_kwargs=kwargs.keys(),
valid_processor_keys=self._valid_processor_keys,
)
images = make_flat_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# Convert to numpy arrays
images = [to_numpy_array(image) for image in images]
if do_rescale and is_scaled_image(images[0]):
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(images[0])
all_images = []
image_sizes = []
for image in images:
if do_resize:
image = self.resize_and_pad(
image=image,
max_tokens=max_tokens,
min_tokens=min_tokens,
do_pad=do_pad,
resample=resample,
input_data_format=input_data_format,
)
if do_rescale:
image = self.rescale(
image=image,
scale=rescale_factor,
input_data_format=input_data_format,
)
if do_normalize:
image = self.normalize(
image=image,
mean=image_mean,
std=image_std,
input_data_format=input_data_format,
)
# Record output size for token count calculation
if input_data_format == ChannelDimension.FIRST:
c, h, w = image.shape
else:
h, w, c = image.shape
image = to_channel_dimension_format(
image, data_format, input_channel_dim=input_data_format
)
            # h and w are base-aligned by resize_and_pad; with do_resize=False the
            # caller must provide sides that are multiples of patch_size * merge_size.
            grid_h = h // self.patch_size
            grid_w = w // self.patch_size
            image_sizes.append((grid_h, grid_w))
            # NaViT patchification (assumes the channels-first default `data_format`):
            # split into (c, gh//m, m, p, gw//m, m, p) so merge blocks stay contiguous.
            patches = image.reshape(
c,
grid_h // self.merge_size,
self.merge_size,
self.patch_size,
grid_w // self.merge_size,
self.merge_size,
self.patch_size,
)
patches = patches.transpose(
1, 4, 2, 5, 0, 3, 6
) # (grid_h//m, grid_w//m, m, m, c, p, p)
image = patches.reshape(
-1, c * self.patch_size * self.patch_size
) # (num_tokens, c*p*p)
all_images.append(image)
pixel_values = np.concatenate(all_images, axis=0)
image_grid_hw = np.array(image_sizes)
return BatchFeature(
data={"pixel_values": pixel_values, "image_grid_hw": image_grid_hw},
tensor_type=return_tensors,
)
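
    # Example packed shapes: the 532x392 image above yields pixel_values of shape
    # (28*38, 3*14*14) = (1064, 588) and image_grid_hw [[28, 38]]; multiple images
    # are concatenated along axis 0.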
def get_num_tokens(self, image_size: tuple[int, int]) -> int:
"""Calculate the number of vision tokens for a given image size."""
h, w = image_size
return (h // self.base) * (w // self.base)
__all__ = ["SigLIP2NaViTImageProcessor"]
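

# Minimal usage sketch (illustrative: the random image and the 256-token budget are
# arbitrary demo choices, not values required by the model).
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    processor = SigLIP2NaViTImageProcessor(max_tokens=256)
    # Random HWC uint8 image standing in for a real photo.
    dummy = rng.integers(0, 256, size=(480, 640, 3), dtype=np.uint8)
    batch = processor.preprocess(dummy, return_tensors="np")
    print(batch["pixel_values"].shape)  # (1064, 588) for this size and budget
    print(batch["image_grid_hw"])  # [[28 38]] -> 28x38 patch grid
    print(processor.get_num_tokens((392, 532)))  # 266 merged tokens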