###
# Author: Kai Li
# Date: 2021-06-19 22:34:13
# LastEditors: Kai Li
# LastEditTime: 2021-08-30 20:01:43
###
import cv2
import random
import numpy as np
import torchvision

__all__ = [
    "Compose",
    "Normalize",
    "CenterCrop",
    "RgbToGray",
    "RandomCrop",
    "HorizontalFlip",
]


class Compose(object):
    """Compose several preprocessing transforms together.

    Args:
        preprocess (list of callables): list of preprocessing transforms to compose.
    """

    def __init__(self, preprocess):
        self.preprocess = preprocess

    def __call__(self, sample):
        for t in self.preprocess:
            sample = t(sample)
        return sample

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.preprocess:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string


class RgbToGray(object):
    """Convert RGB frames to grayscale.

    Converts a numpy.ndarray of shape (T x H x W x C) in the range [0, 255] to a
    numpy.ndarray of shape (T x H x W) in the same range.
    """

    def __call__(self, frames):
        """
        Args:
            frames (numpy.ndarray): RGB frames to be converted to grayscale.
        Returns:
            numpy.ndarray: Grayscale frames.
        """
        frames = np.stack(
            [cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) for frame in frames], axis=0
        )
        return frames

    def __repr__(self):
        return self.__class__.__name__ + "()"


class Normalize(object):
    """Normalize an ndarray of frames with mean and standard deviation."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, frames):
        """
        Args:
            frames (numpy.ndarray): Frames to be normalized.
        Returns:
            numpy.ndarray: Normalized frames.
        """
        frames = (frames - self.mean) / self.std
        return frames

    def __repr__(self):
        return self.__class__.__name__ + "(mean={0}, std={1})".format(
            self.mean, self.std
        )


class CenterCrop(object):
    """Crop the given frames at the center."""

    def __init__(self, size):
        self.size = size

    def __call__(self, frames):
        """
        Args:
            frames (numpy.ndarray): Frames of shape (T, H, W) to be cropped.
        Returns:
            numpy.ndarray: Cropped frames.
        """
        t, h, w = frames.shape
        th, tw = self.size
        delta_w = int(round((w - tw) / 2.0))
        delta_h = int(round((h - th) / 2.0))
        frames = frames[:, delta_h : delta_h + th, delta_w : delta_w + tw]
        return frames


class RandomCrop(object):
    """Crop the given frames at a random location."""

    def __init__(self, size):
        self.size = size

    def __call__(self, frames):
        """
        Args:
            frames (numpy.ndarray): Frames of shape (T, H, W) to be cropped.
        Returns:
            numpy.ndarray: Cropped frames.
        """
        t, h, w = frames.shape
        th, tw = self.size
        delta_w = random.randint(0, w - tw)
        delta_h = random.randint(0, h - th)
        frames = frames[:, delta_h : delta_h + th, delta_w : delta_w + tw]
        return frames

    def __repr__(self):
        return self.__class__.__name__ + "(size={0})".format(self.size)


class HorizontalFlip(object):
    """Flip frames horizontally with probability ``flip_ratio``."""

    def __init__(self, flip_ratio):
        self.flip_ratio = flip_ratio

    def __call__(self, frames):
        """
        Args:
            frames (numpy.ndarray): Frames to be flipped with probability flip_ratio.
        Returns:
            numpy.ndarray: Flipped frames.
""" t, h, w = frames.shape if random.random() < self.flip_ratio: for index in range(t): frames[index] = cv2.flip(frames[index], 1) return frames def get_preprocessing_pipelines(): # -- preprocess for the video stream preprocessing = {} # -- LRW config crop_size = (88, 88) (mean, std) = (0.421, 0.165) preprocessing["train"] = Compose( [ Normalize(0.0, 255.0), RandomCrop(crop_size), HorizontalFlip(0.5), Normalize(mean, std), ] ) preprocessing["val"] = Compose( [Normalize(0.0, 255.0), CenterCrop(crop_size), Normalize(mean, std)] ) preprocessing["test"] = preprocessing["val"] return preprocessing def get_preprocessing_opt_pipelines(): preprocessing = {} # -- LRW config crop_size = (88, 88) (mean, std) = (0.421, 0.165) preprocessing["train"] = torchvision.transforms.Compose([ torchvision.transforms.Normalize(0.0, 255.0), torchvision.transforms.RandomCrop(crop_size), torchvision.transforms.RandomHorizontalFlip(0.5), torchvision.transforms.Normalize(mean, std) ]) preprocessing["val"] = torchvision.transforms.Compose([ torchvision.transforms.Normalize(0.0, 255.0), torchvision.transforms.CenterCrop(crop_size), torchvision.transforms.Normalize(mean, std) ]) preprocessing["test"] = preprocessing["val"] return preprocessing