from typing import Dict, List, Optional, Tuple, Union, Iterable

import numpy as np
import torch
import transformers
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
    ChannelDimension,
    get_resize_output_image_size,
    rescale,
    resize,
    to_channel_dimension_format,
)
from transformers.image_utils import (
    ImageInput,
    PILImageResampling,
    infer_channel_dimension_format,
    get_channel_dimension_axis,
    make_list_of_images,
    to_numpy_array,
    valid_images,
)
from transformers.utils import is_torch_tensor
| |
|
| |
|
| | class FaceSegformerImageProcessor(BaseImageProcessor): |
| | def __init__(self, **kwargs): |
| | super().__init__(**kwargs) |
| | self.image_size = kwargs.get("image_size", (224, 224)) |
| | self.normalize_mean = kwargs.get("normalize_mean", [0.485, 0.456, 0.406]) |
| | self.normalize_std = kwargs.get("normalize_std", [0.229, 0.224, 0.225]) |
| | self.resample = kwargs.get("resample", PILImageResampling.BILINEAR) |
| | self.data_format = kwargs.get("data_format", ChannelDimension.FIRST) |
| |
|
| | @staticmethod |
| | def normalize( |
| | image: np.ndarray, |
| | mean: Union[float, Iterable[float]], |
| | std: Union[float, Iterable[float]], |
| | max_pixel_value: float = 255.0, |
| | data_format: Optional[ChannelDimension] = None, |
| | input_data_format: Optional[Union[str, ChannelDimension]] = None, |
| | ) -> np.ndarray: |
| | """ |
| | Copied from: |
| | https://github.com/huggingface/transformers/blob/3eddda1111f70f3a59485e08540e8262b927e867/src/transformers/image_transforms.py#L209 |
| | |
| | BUT uses the formula from albumentations: |
| | https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize |
| | |
| | img = (img - mean * max_pixel_value) / (std * max_pixel_value) |
| | """ |
| | if not isinstance(image, np.ndarray): |
| | raise ValueError("image must be a numpy array") |
| |
|
| | if input_data_format is None: |
| | input_data_format = infer_channel_dimension_format(image) |
| | channel_axis = get_channel_dimension_axis( |
| | image, input_data_format=input_data_format |
| | ) |
| | num_channels = image.shape[channel_axis] |
| |
|
| | |
| | |
| | if not np.issubdtype(image.dtype, np.floating): |
| | image = image.astype(np.float32) |
| |
|
| | if isinstance(mean, Iterable): |
| | if len(mean) != num_channels: |
| | raise ValueError( |
| | f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}" |
| | ) |
| | else: |
| | mean = [mean] * num_channels |
| | mean = np.array(mean, dtype=image.dtype) |
| |
|
| | if isinstance(std, Iterable): |
| | if len(std) != num_channels: |
| | raise ValueError( |
| | f"std must have {num_channels} elements if it is an iterable, got {len(std)}" |
| | ) |
| | else: |
| | std = [std] * num_channels |
| | std = np.array(std, dtype=image.dtype) |
| |
|
| | |
| | if input_data_format == ChannelDimension.LAST: |
| | image = (image - mean * max_pixel_value) / (std * max_pixel_value) |
| | else: |
| | image = ((image.T - mean * max_pixel_value) / (std * max_pixel_value)).T |
| |
|
| | image = ( |
| | to_channel_dimension_format(image, data_format, input_data_format) |
| | if data_format is not None |
| | else image |
| | ) |
| | return image |
| |
|
| | def resize( |
| | self, |
| | image: np.ndarray, |
| | size: Dict[str, int], |
| | resample: PILImageResampling = PILImageResampling.BICUBIC, |
| | data_format: Optional[Union[str, ChannelDimension]] = None, |
| | input_data_format: Optional[Union[str, ChannelDimension]] = None, |
| | **kwargs, |
| | ) -> np.ndarray: |
| | """ |
| | Copied from: |
| | https://github.com/huggingface/transformers/blob/3eddda1111f70f3a59485e08540e8262b927e867/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py |
| | """ |
| | default_to_square = True |
| | if "shortest_edge" in size: |
| | size = size["shortest_edge"] |
| | default_to_square = False |
| | elif "height" in size and "width" in size: |
| | size = (size["height"], size["width"]) |
| | else: |
| | raise ValueError( |
| | "Size must contain either 'shortest_edge' or 'height' and 'width'." |
| | ) |
| |
|
| | output_size = get_resize_output_image_size( |
| | image, |
| | size=size, |
| | default_to_square=default_to_square, |
| | input_data_format=input_data_format, |
| | ) |
| | return resize( |
| | image, |
| | size=output_size, |
| | resample=resample, |
| | data_format=data_format, |
| | input_data_format=input_data_format, |
| | **kwargs, |
| | ) |
| |
|
| | def __call__(self, images: ImageInput, masks: ImageInput = None, **kwargs): |
| | """ |
| | Adapted from: |
| | https://github.com/huggingface/transformers/blob/3eddda1111f70f3a59485e08540e8262b927e867/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py |
| | """ |
| | |
| | images = make_list_of_images(images) |
| |
|
| | |
| | if not valid_images(images): |
| | raise ValueError( |
| | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " |
| | "torch.Tensor, tf.Tensor or jax.ndarray." |
| | ) |
| |
|
| | |
| | images = [to_numpy_array(image) for image in images] |
| |
|
| | |
| | input_data_format = kwargs.get("input_data_format") |
| | if input_data_format is None: |
| | |
| | input_data_format = infer_channel_dimension_format(images[0]) |
| |
|
| | |
| | |
| | if kwargs.get("do_training", False) is True: |
| | if mask is None: |
| | raise ValueError("must pass masks if doing training.") |
| | |
| | raise NotImplementedError("not yet implemented.") |
| | |
| | else: |
| | |
| | images = [ |
| | self.resize( |
| | image=image, |
| | size={"height": self.image_size[0], "width": self.image_size[1]}, |
| | resample=kwargs.get("resample") or self.resample, |
| | input_data_format=input_data_format, |
| | ) |
| | for image in images |
| | ] |
| | images = [ |
| | self.normalize( |
| | image=image, |
| | mean=kwargs.get("normalize_mean") or self.normalize_mean, |
| | std=kwargs.get("normalize_std") or self.normalize_std, |
| | input_data_format=input_data_format, |
| | ) |
| | for image in images |
| | ] |
| | |
| | images = [ |
| | to_channel_dimension_format( |
| | image, |
| | kwargs.get("data_format") or self.data_format, |
| | input_channel_dim=input_data_format, |
| | ) |
| | for image in images |
| | ] |
| |
|
| | data = {"pixel_values": images} |
| | return BatchFeature(data=data, tensor_type="pt") |
| |
|
| | |
| | def post_process_semantic_segmentation( |
| | self, outputs, target_sizes: List[Tuple] = None |
| | ): |
| | """ |
| | Converts the output of [`SegformerForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. |
| | |
| | Args: |
| | outputs ([`SegformerForSemanticSegmentation`]): |
| | Raw outputs of the model. |
| | target_sizes (`List[Tuple]` of length `batch_size`, *optional*): |
| | List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, |
| | predictions will not be resized. |
| | |
| | Returns: |
| | semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic |
| | segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is |
| | specified). Each entry of each `torch.Tensor` correspond to a semantic class id. |
| | """ |
| | |
| | logits = outputs.logits |
| |
|
| | |
| | if target_sizes is not None: |
| | if len(logits) != len(target_sizes): |
| | raise ValueError( |
| | "Make sure that you pass in as many target sizes as the batch dimension of the logits" |
| | ) |
| |
|
| | if is_torch_tensor(target_sizes): |
| | target_sizes = target_sizes.numpy() |
| |
|
| | semantic_segmentation = [] |
| |
|
| | for idx in range(len(logits)): |
| | resized_logits = torch.nn.functional.interpolate( |
| | logits[idx].unsqueeze(dim=0), |
| | size=target_sizes[idx], |
| | mode="bilinear", |
| | align_corners=False, |
| | ) |
| | semantic_map = resized_logits[0].argmax(dim=0) |
| | semantic_segmentation.append(semantic_map) |
| | else: |
| | semantic_segmentation = logits.argmax(dim=1) |
| | semantic_segmentation = [ |
| | semantic_segmentation[i] for i in range(semantic_segmentation.shape[0]) |
| | ] |
| |
|
| | return semantic_segmentation |
| |
|