import numpy as np
import torch
import torch.nn as nn
import kornia.augmentation as aug
from kornia.color.hsv import hsv_to_rgb, rgb_to_hsv
from .base import Augmentation
class RandomShift(Augmentation):
    """Random shift augmentation.

    The image is padded by ``shift_size`` pixels with edge replication and
    then randomly cropped back to its original height and width.

    References:
        * `Kostrikov et al., Image Augmentation Is All You Need: Regularizing
          Deep Reinforcement Learning from Pixels.
          <https://arxiv.org/abs/2004.13649>`_

    Args:
        shift_size (int): size to shift image.

    Attributes:
        shift_size (int): size to shift image.

    """

    def __init__(self, shift_size=4):
        # The pad-and-crop pipeline depends on the spatial size of the input,
        # so it stays unset until ``_setup`` receives a tensor.
        self._operation = None
        self.shift_size = shift_size

    def _setup(self, x):
        """Build the pad-then-crop pipeline sized for ``x``.

        Args:
            x: tensor whose last two dimensions are read as height and width.

        """
        # NOTE(review): the caller of ``_setup`` is not visible in this class;
        # presumably it runs once before ``_operation`` is first applied.
        img_h, img_w = x.shape[-2:]
        padding = nn.ReplicationPad2d(self.shift_size)
        cropping = aug.RandomCrop((img_h, img_w))
        self._operation = nn.Sequential(padding, cropping)

    def get_type(self):
        """Returns augmentation type.

        Returns:
            str: `random_shift`.

        """
        return 'random_shift'

    def get_params(self, deep=False):
        """Returns augmentation parameters.

        Args:
            deep (bool): flag to deeply copy objects (not used by this class).

        Returns:
            dict: augmentation parameters.

        """
        params = {'shift_size': self.shift_size}
        return params

class Cutout(Augmentation):
    """Cutout augmentation.

    Wraps kornia's ``RandomErasing`` with the given application probability.

    References:
        * `Kostrikov et al., Image Augmentation Is All You Need: Regularizing
          Deep Reinforcement Learning from Pixels.
          <https://arxiv.org/abs/2004.13649>`_

    Args:
        probability (float): probability to cutout.

    Attributes:
        probability (float): probability to cutout.

    """

    def __init__(self, probability=0.5):
        eraser = aug.RandomErasing(p=probability)
        self.probability = probability
        self._operation = eraser

    def get_type(self):
        """Returns augmentation type.

        Returns:
            str: `cutout`.

        """
        return 'cutout'

    def get_params(self, deep=False):
        """Returns augmentation parameters.

        Args:
            deep (bool): flag to deeply copy objects (not used by this class).

        Returns:
            dict: augmentation parameters.

        """
        params = {'probability': self.probability}
        return params

class HorizontalFlip(Augmentation):
    """Horizontal flip augmentation.

    Wraps kornia's ``RandomHorizontalFlip`` with the given probability.

    References:
        * `Kostrikov et al., Image Augmentation Is All You Need: Regularizing
          Deep Reinforcement Learning from Pixels.
          <https://arxiv.org/abs/2004.13649>`_

    Args:
        probability (float): probability to flip horizontally.

    Attributes:
        probability (float): probability to flip horizontally.

    """

    def __init__(self, probability=0.1):
        flipper = aug.RandomHorizontalFlip(p=probability)
        self.probability = probability
        self._operation = flipper

    def get_type(self):
        """Returns augmentation type.

        Returns:
            str: `horizontal_flip`.

        """
        return 'horizontal_flip'

    def get_params(self, deep=False):
        """Returns augmentation parameters.

        Args:
            deep (bool): flag to deeply copy objects (not used by this class).

        Returns:
            dict: augmentation parameters.

        """
        params = {'probability': self.probability}
        return params

class VerticalFlip(Augmentation):
    """Vertical flip augmentation.

    Wraps kornia's ``RandomVerticalFlip`` with the given probability.

    References:
        * `Kostrikov et al., Image Augmentation Is All You Need: Regularizing
          Deep Reinforcement Learning from Pixels.
          <https://arxiv.org/abs/2004.13649>`_

    Args:
        probability (float): probability to flip vertically.

    Attributes:
        probability (float): probability to flip vertically.

    """

    def __init__(self, probability=0.1):
        flipper = aug.RandomVerticalFlip(p=probability)
        self.probability = probability
        self._operation = flipper

    def get_type(self):
        """Returns augmentation type.

        Returns:
            str: `vertical_flip`.

        """
        return 'vertical_flip'

    def get_params(self, deep=False):
        """Returns augmentation parameters.

        Args:
            deep (bool): flag to deeply copy objects (not used by this class).

        Returns:
            dict: augmentation parameters.

        """
        params = {'probability': self.probability}
        return params

class RandomRotation(Augmentation):
    """Random rotation augmentation.

    Wraps kornia's ``RandomRotation``, forwarding ``degree`` as its
    ``degrees`` argument.

    References:
        * `Kostrikov et al., Image Augmentation Is All You Need: Regularizing
          Deep Reinforcement Learning from Pixels.
          <https://arxiv.org/abs/2004.13649>`_

    Args:
        degree (float): range of degrees to rotate image.

    Attributes:
        degree (float): range of degrees to rotate image.

    """

    def __init__(self, degree=5.0):
        rotator = aug.RandomRotation(degrees=degree)
        self.degree = degree
        self._operation = rotator

    def get_type(self):
        """Returns augmentation type.

        Returns:
            str: `random_rotation`.

        """
        return 'random_rotation'

    def get_params(self, deep=False):
        """Returns augmentation parameters.

        Args:
            deep (bool): flag to deeply copy objects (not used by this class).

        Returns:
            dict: augmentation parameters.

        """
        params = {'degree': self.degree}
        return params

class Intensity(Augmentation):
    """Intensity augmentation.

    .. math::

        x' = x + n

    where :math:`n \\sim N(0, scale)`.

    NOTE(review): the formula above is additive while ``scale`` is described
    as a multiplier; the transform itself is not defined in this class body,
    so confirm which form the implementation actually uses.

    References:
        * `Kostrikov et al., Image Augmentation Is All You Need: Regularizing
          Deep Reinforcement Learning from Pixels.
          <https://arxiv.org/abs/2004.13649>`_

    Args:
        scale (float): scale of multiplier.

    Attributes:
        scale (float): scale of multiplier.

    """

    def __init__(self, scale=0.1):
        # Only the noise scale is stored; no kornia operation is built here.
        self.scale = scale

    def get_type(self):
        """Returns augmentation type.

        Returns:
            str: `intensity`.

        """
        return 'intensity'

    def get_params(self, deep=False):
        """Returns augmentation parameters.

        Args:
            deep (bool): flag to deeply copy objects (not used by this class).

        Returns:
            dict: augmentation parameters.

        """
        params = {'scale': self.scale}
        return params

class ColorJitter(Augmentation):
    """Color Jitter augmentation.

    This augmentation modifies the given images in the HSV channel spaces
    as well as a contrast change.
    This augmentation will be useful with the real world images.

    The private helpers index their input as ``[:, :, c, :, :]``, i.e. they
    expect 5-D tensors with the color channel on dimension 2 (presumably
    ``(batch, stack, channel, height, width)`` -- confirm against the caller).
    One random factor is drawn per entry of dimension 0.

    References:
        * `Laskin et al., Reinforcement Learning with Augmented Data.
          <https://arxiv.org/abs/2004.14990>`_

    Args:
        brightness (tuple): brightness scale range.
        contrast (tuple): contrast scale range.
        saturation (tuple): saturation scale range.
        hue (tuple): hue scale range.

    Attributes:
        brightness (tuple): brightness scale range.
        contrast (tuple): contrast scale range.
        saturation (tuple): saturation scale range.
        hue (tuple): hue scale range.

    """

    def __init__(self,
                 brightness=(0.6, 1.4),
                 contrast=(0.6, 1.4),
                 saturation=(0.6, 1.4),
                 hue=(-0.5, 0.5)):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def _transform_hue(self, hsv):
        """Shift channel 0 (hue) by a random per-sample offset, modulo 1.

        The input tensor is modified in place and also returned.

        Args:
            hsv (torch.Tensor): 5-D tensor with hue at index 0 of dim 2.

        Returns:
            torch.Tensor: the same tensor with shifted hue.

        """
        scale = torch.empty(hsv.shape[0], 1, 1, 1, device=hsv.device)
        # NOTE(review): the 255/360 factor is kept exactly as in the original;
        # confirm it matches the hue units produced by the caller.
        scale = scale.uniform_(*self.hue) * 255.0 / 360.0
        hsv[:, :, 0, :, :] = (hsv[:, :, 0, :, :] + scale) % 1
        return hsv

    def _transform_saturate(self, hsv):
        """Multiply channel 1 (saturation) by a random per-sample factor.

        Channel 1 of the input is scaled in place; the returned tensor is a
        clamped copy covering all channels.

        Args:
            hsv (torch.Tensor): 5-D tensor with saturation at index 1 of dim 2.

        Returns:
            torch.Tensor: tensor clamped to ``[0, 1]``.

        """
        scale = torch.empty(hsv.shape[0], 1, 1, 1, device=hsv.device)
        scale.uniform_(*self.saturation)
        hsv[:, :, 1, :, :] *= scale
        return hsv.clamp(0, 1)

    def _transform_brightness(self, hsv):
        """Multiply channel 2 (value) by a random per-sample factor.

        Channel 2 of the input is scaled in place; the returned tensor is a
        clamped copy covering all channels.

        Args:
            hsv (torch.Tensor): 5-D tensor with value at index 2 of dim 2.

        Returns:
            torch.Tensor: tensor clamped to ``[0, 1]``.

        """
        scale = torch.empty(hsv.shape[0], 1, 1, 1, device=hsv.device)
        scale.uniform_(*self.brightness)
        hsv[:, :, 2, :, :] *= scale
        return hsv.clamp(0, 1)

    def _transform_contrast(self, rgb):
        """Scale each image's deviation from its spatial mean.

        Args:
            rgb (torch.Tensor): 5-D tensor; the mean is taken over dims 3
                and 4.

        Returns:
            torch.Tensor: contrast-adjusted tensor clamped to ``[0, 1]``.

        """
        scale = torch.empty(rgb.shape[0], 1, 1, 1, 1, device=rgb.device)
        scale.uniform_(*self.contrast)
        means = rgb.mean(dim=(3, 4), keepdim=True)
        # Scale the deviation from the mean and then add the mean back, so a
        # factor of 1 leaves the image unchanged and 0 collapses it to the
        # mean. The previous ``(rgb - means) * (scale + means)`` dropped the
        # mean offset and was not an identity at factor 1.
        return ((rgb - means) * scale + means).clamp(0, 1)

    def get_type(self):
        """Returns augmentation type.

        Returns:
            str: `color_jitter`.

        """
        return 'color_jitter'

    def get_params(self, deep=False):
        """Returns augmentation parameters.

        Args:
            deep (bool): flag to deeply copy objects (not used by this class).

        Returns:
            dict: augmentation parameters.

        """
        return {
            'brightness': self.brightness,
            'contrast': self.contrast,
            'saturation': self.saturation,
            'hue': self.hue
        }