# Dolphin / look2hear / datas / transform.py
# (upload residue from JusperLee's repo, commit 0cd6025, "clean repo without raw binaries")
###
# Author: Kai Li
# Date: 2021-06-19 22:34:13
# LastEditors: Kai Li
# LastEditTime: 2021-08-30 20:01:43
###
import cv2
import random
import numpy as np
import torchvision
# Public API of this module: the transform classes plus the pipeline
# factory functions defined below (previously missing from __all__).
__all__ = [
    "Compose",
    "Normalize",
    "CenterCrop",
    "RgbToGray",
    "RandomCrop",
    "HorizontalFlip",
    "get_preprocessing_pipelines",
    "get_preprocessing_opt_pipelines",
]
class Compose(object):
    """Chain several preprocessing callables into one transform.

    Args:
        preprocess (list of ``Preprocess`` objects): transforms applied
            in order to each sample.
    """

    def __init__(self, preprocess):
        self.preprocess = preprocess

    def __call__(self, sample):
        result = sample
        for transform in self.preprocess:
            result = transform(result)
        return result

    def __repr__(self):
        # One transform per line, indented, inside the parentheses.
        body = "".join("\n    {0}".format(t) for t in self.preprocess)
        return self.__class__.__name__ + "(" + body + "\n)"
class RgbToGray(object):
    """Convert a sequence of RGB frames to grayscale.

    Takes an iterable of numpy.ndarray RGB images (H x W x 3) and returns a
    single stacked numpy.ndarray of shape (T, H, W) — one single-channel
    grayscale image per input frame, using OpenCV's RGB->GRAY weighting.
    No value rescaling is performed here (see ``Normalize`` for that).
    """

    def __call__(self, frames):
        """
        Args:
            frames (iterable of numpy.ndarray): RGB images to be converted.
        Returns:
            numpy.ndarray: grayscale frames stacked along a new first
                (time) axis, shape (T, H, W).
        """
        frames = np.stack(
            [cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) for frame in frames], axis=0
        )
        return frames

    def __repr__(self):
        return self.__class__.__name__ + "()"
class Normalize(object):
    """Standardize an ndarray with a fixed mean and standard deviation.

    Applies ``(frames - mean) / std`` elementwise.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, frames):
        """
        Args:
            frames (numpy.ndarray): array to be normalized.
        Returns:
            numpy.ndarray: the normalized array.
        """
        return (frames - self.mean) / self.std

    def __repr__(self):
        return "{0}(mean={1}, std={2})".format(
            self.__class__.__name__, self.mean, self.std
        )
class CenterCrop(object):
    """Crop every frame of a clip to ``size`` around the spatial center.

    Expects ``frames`` of shape (T, H, W); returns the centered
    (T, th, tw) sub-array.
    """

    def __init__(self, size):
        # size: (th, tw) target height and width.
        self.size = size

    def __call__(self, frames):
        """
        Args:
            frames (numpy.ndarray): clip of shape (T, H, W) to be cropped.
        Returns:
            numpy.ndarray: center-cropped clip of shape (T, th, tw).
        """
        t, h, w = frames.shape
        th, tw = self.size
        # Original wrote int(round((w - tw)) / 2.0): round() wrapped the
        # integer difference (a no-op) instead of the quotient, so the
        # expression was really just truncating floor division — say so.
        delta_w = (w - tw) // 2
        delta_h = (h - th) // 2
        return frames[:, delta_h : delta_h + th, delta_w : delta_w + tw]

    def __repr__(self):
        # Matches RandomCrop's repr style for consistency.
        return self.__class__.__name__ + "(size={0})".format(self.size)
class RandomCrop(object):
    """Crop every frame of a clip to ``size`` at a random position.

    One offset is drawn per call and applied to all frames, so the clip
    stays spatially aligned across time.  (Docstring fixed: this class
    does NOT crop at the center — that is ``CenterCrop``.)
    """

    def __init__(self, size):
        # size: (th, tw) target height and width.
        self.size = size

    def __call__(self, frames):
        """
        Args:
            frames (numpy.ndarray): clip of shape (T, H, W) to be cropped.
        Returns:
            numpy.ndarray: randomly cropped clip of shape (T, th, tw).
        """
        t, h, w = frames.shape
        th, tw = self.size
        delta_w = random.randint(0, w - tw)
        delta_h = random.randint(0, h - th)
        return frames[:, delta_h : delta_h + th, delta_w : delta_w + tw]

    def __repr__(self):
        return self.__class__.__name__ + "(size={0})".format(self.size)
class HorizontalFlip(object):
    """Horizontally flip an entire clip with probability ``flip_ratio``.

    Either every frame is flipped or none is, so the clip stays
    temporally consistent.  The input array is modified in place and
    also returned.  (Docstring fixed: the result is flipped, not
    "cropped".)
    """

    def __init__(self, flip_ratio):
        # flip_ratio: probability in [0, 1] of flipping the clip.
        self.flip_ratio = flip_ratio

    def __call__(self, frames):
        """
        Args:
            frames (numpy.ndarray): clip of shape (T, H, W).
        Returns:
            numpy.ndarray: the same array, horizontally flipped in place
                with probability ``flip_ratio``.
        """
        if random.random() < self.flip_ratio:
            # Reversing the width axis equals cv2.flip(frame, 1) applied
            # frame-by-frame; .copy() avoids aliasing the overlapping view
            # during the in-place assignment.
            frames[:] = frames[:, :, ::-1].copy()
        return frames

    def __repr__(self):
        return self.__class__.__name__ + "(flip_ratio={0})".format(self.flip_ratio)
def get_preprocessing_pipelines():
    """Build train/val/test preprocessing pipelines for the video stream.

    Returns:
        dict: keys "train", "val", "test" mapping to ``Compose`` pipelines
            ("test" shares the "val" pipeline object).
    """
    # -- LRW config
    crop_size = (88, 88)
    mean, std = 0.421, 0.165
    train_pipeline = Compose(
        [
            Normalize(0.0, 255.0),
            RandomCrop(crop_size),
            HorizontalFlip(0.5),
            Normalize(mean, std),
        ]
    )
    eval_pipeline = Compose(
        [Normalize(0.0, 255.0), CenterCrop(crop_size), Normalize(mean, std)]
    )
    return {"train": train_pipeline, "val": eval_pipeline, "test": eval_pipeline}
def get_preprocessing_opt_pipelines():
    """Build torchvision-based train/val/test pipelines (LRW config).

    Returns:
        dict: keys "train", "val", "test" mapping to
            ``torchvision.transforms.Compose`` pipelines ("test" shares
            the "val" pipeline object).
    """
    crop_size = (88, 88)
    mean, std = 0.421, 0.165
    tvt = torchvision.transforms
    train_pipeline = tvt.Compose(
        [
            tvt.Normalize(0.0, 255.0),
            tvt.RandomCrop(crop_size),
            tvt.RandomHorizontalFlip(0.5),
            tvt.Normalize(mean, std),
        ]
    )
    eval_pipeline = tvt.Compose(
        [
            tvt.Normalize(0.0, 255.0),
            tvt.CenterCrop(crop_size),
            tvt.Normalize(mean, std),
        ]
    )
    return {"train": train_pipeline, "val": eval_pipeline, "test": eval_pipeline}