from typing import Any, ClassVar, Dict, Optional, Tuple, cast
import numpy as np
import torch
import torch.nn as nn
import kornia.augmentation as aug
from kornia.color.hsv import hsv_to_rgb, rgb_to_hsv
from .base import Augmentation
class RandomShift(Augmentation):
    """Random shift augmentation.

    The image is padded on every side by replicating its border and then
    randomly cropped back to its original height and width.

    References:
        * `Kostrikov et al., Image Augmentation Is All You Need: Regularizing
          Deep Reinforcement Learning from Pixels.
          <https://arxiv.org/abs/2004.13649>`_

    Args:
        shift_size (int): size to shift image (used as the replication
            padding width).

    """

    TYPE: ClassVar[str] = "random_shift"
    _shift_size: int
    _operation: Optional[nn.Sequential]

    def __init__(self, shift_size: int = 4):
        # The pipeline cannot be built yet: the crop size is only known
        # once a tensor has been seen (see ``_setup``).
        self._operation = None
        self._shift_size = shift_size

    def _setup(self, x: torch.Tensor) -> None:
        """Build the pad-then-crop pipeline for tensors shaped like ``x``.

        Args:
            x: tensor whose last two dimensions give the crop size.

        """
        rows, cols = x.shape[-2:]
        padding = nn.ReplicationPad2d(self._shift_size)
        cropping = aug.RandomCrop((rows, cols))
        self._operation = nn.Sequential(padding, cropping)

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Return the constructor arguments as a dictionary.

        Args:
            deep: accepted for interface compatibility; not used here.

        Returns:
            dictionary with the ``shift_size`` entry.

        """
        return dict(shift_size=self._shift_size)
class Cutout(Augmentation):
    """Cutout augmentation.

    Wraps ``kornia.augmentation.RandomErasing`` with the given probability.

    References:
        * `Kostrikov et al., Image Augmentation Is All You Need: Regularizing
          Deep Reinforcement Learning from Pixels.
          <https://arxiv.org/abs/2004.13649>`_

    Args:
        probability (float): probability to cutout.

    """

    TYPE: ClassVar[str] = "cutout"
    _probability: float
    _operation: aug.RandomErasing

    def __init__(self, probability: float = 0.5):
        self._probability = probability
        eraser = aug.RandomErasing(p=probability)
        self._operation = eraser

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Return the constructor arguments as a dictionary.

        Args:
            deep: accepted for interface compatibility; not used here.

        Returns:
            dictionary with the ``probability`` entry.

        """
        return dict(probability=self._probability)
class HorizontalFlip(Augmentation):
    """Horizontal flip augmentation.

    Wraps ``kornia.augmentation.RandomHorizontalFlip`` with the given
    probability.

    References:
        * `Kostrikov et al., Image Augmentation Is All You Need: Regularizing
          Deep Reinforcement Learning from Pixels.
          <https://arxiv.org/abs/2004.13649>`_

    Args:
        probability (float): probability to flip horizontally.

    """

    TYPE: ClassVar[str] = "horizontal_flip"
    _probability: float
    _operation: aug.RandomHorizontalFlip

    def __init__(self, probability: float = 0.1):
        self._probability = probability
        flipper = aug.RandomHorizontalFlip(p=probability)
        self._operation = flipper

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Return the constructor arguments as a dictionary.

        Args:
            deep: accepted for interface compatibility; not used here.

        Returns:
            dictionary with the ``probability`` entry.

        """
        return dict(probability=self._probability)
class VerticalFlip(Augmentation):
    """Vertical flip augmentation.

    Wraps ``kornia.augmentation.RandomVerticalFlip`` with the given
    probability.

    References:
        * `Kostrikov et al., Image Augmentation Is All You Need: Regularizing
          Deep Reinforcement Learning from Pixels.
          <https://arxiv.org/abs/2004.13649>`_

    Args:
        probability (float): probability to flip vertically.

    """

    TYPE: ClassVar[str] = "vertical_flip"
    _probability: float
    _operation: aug.RandomVerticalFlip

    def __init__(self, probability: float = 0.1):
        self._probability = probability
        flipper = aug.RandomVerticalFlip(p=probability)
        self._operation = flipper

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Return the constructor arguments as a dictionary.

        Args:
            deep: accepted for interface compatibility; not used here.

        Returns:
            dictionary with the ``probability`` entry.

        """
        return dict(probability=self._probability)
class RandomRotation(Augmentation):
    """Random rotation augmentation.

    Wraps ``kornia.augmentation.RandomRotation``; ``degree`` is forwarded
    as its ``degrees`` argument.

    References:
        * `Kostrikov et al., Image Augmentation Is All You Need: Regularizing
          Deep Reinforcement Learning from Pixels.
          <https://arxiv.org/abs/2004.13649>`_

    Args:
        degree (float): range of degrees to rotate image.

    """

    TYPE: ClassVar[str] = "random_rotation"
    _degree: float
    _operation: aug.RandomRotation

    def __init__(self, degree: float = 5.0):
        self._degree = degree
        rotator = aug.RandomRotation(degrees=degree)
        self._operation = rotator

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Return the constructor arguments as a dictionary.

        Args:
            deep: accepted for interface compatibility; not used here.

        Returns:
            dictionary with the ``degree`` entry.

        """
        return dict(degree=self._degree)
class Intensity(Augmentation):
    r"""Intensity augmentation.

    .. math::

        x' = x + n

    where :math:`n \sim N(0, scale)`.

    References:
        * `Kostrikov et al., Image Augmentation Is All You Need: Regularizing
          Deep Reinforcement Learning from Pixels.
          <https://arxiv.org/abs/2004.13649>`_

    Args:
        scale (float): scale of multiplier.

    """

    # NOTE(review): the formula above is additive while the argument is
    # described as a multiplier scale; the transform itself is not defined
    # in this part of the file, so confirm which description is accurate.
    TYPE: ClassVar[str] = "intensity"
    _scale: float

    def __init__(self, scale: float = 0.1):
        self._scale = scale

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Return the constructor arguments as a dictionary.

        Args:
            deep: accepted for interface compatibility; not used here.

        Returns:
            dictionary with the ``scale`` entry.

        """
        return dict(scale=self._scale)
class ColorJitter(Augmentation):
    """Color Jitter augmentation.

    This augmentation modifies the given images in the HSV channel spaces
    as well as a contrast change.
    This augmentation will be useful with the real world images.

    Every helper draws one random factor per index of the first tensor
    dimension, so all images sharing that index receive the same change.

    References:
        * `Laskin et al., Reinforcement Learning with Augmented Data.
          <https://arxiv.org/abs/2004.14990>`_

    Args:
        brightness (tuple): brightness scale range.
        contrast (tuple): contrast scale range.
        saturation (tuple): saturation scale range.
        hue (tuple): hue scale range.

    """

    TYPE: ClassVar[str] = "color_jitter"
    _brightness: Tuple[float, float]
    _contrast: Tuple[float, float]
    _saturation: Tuple[float, float]
    _hue: Tuple[float, float]

    def __init__(
        self,
        brightness: Tuple[float, float] = (0.6, 1.4),
        contrast: Tuple[float, float] = (0.6, 1.4),
        saturation: Tuple[float, float] = (0.6, 1.4),
        hue: Tuple[float, float] = (-0.5, 0.5),
    ):
        self._brightness = brightness
        self._contrast = contrast
        self._saturation = saturation
        self._hue = hue

    def _transform_hue(self, hsv: torch.Tensor) -> torch.Tensor:
        """Shift channel 0 (hue) by a random offset, wrapping into [0, 1).

        Args:
            hsv: 5-D tensor, presumably laid out as
                (batch, stack, HSV, height, width) -- TODO confirm with
                the caller. It is modified in place.

        Returns:
            the same tensor object with its hue channel shifted.

        """
        # One offset per index of dim 0; the (N, 1, 1, 1) shape broadcasts
        # against the 4-D channel slice.
        scale = torch.empty(hsv.shape[0], 1, 1, 1, device=hsv.device)
        # NOTE(review): the 255 / 360 factor is kept as found; presumably it
        # follows the reference implementation -- confirm the hue units.
        scale = scale.uniform_(*self._hue) * 255.0 / 360.0
        hsv[:, :, 0, :, :] = (hsv[:, :, 0, :, :] + scale) % 1
        return hsv

    def _transform_saturate(self, hsv: torch.Tensor) -> torch.Tensor:
        """Multiply channel 1 (saturation) by a random factor.

        Args:
            hsv: 5-D tensor (see ``_transform_hue``). Channel 1 is scaled
                in place.

        Returns:
            a tensor with every channel clamped to [0, 1].

        """
        scale = torch.empty(hsv.shape[0], 1, 1, 1, device=hsv.device)
        scale.uniform_(*self._saturation)
        hsv[:, :, 1, :, :] *= scale
        return hsv.clamp(0, 1)

    def _transform_brightness(self, hsv: torch.Tensor) -> torch.Tensor:
        """Multiply channel 2 (value) by a random factor.

        Args:
            hsv: 5-D tensor (see ``_transform_hue``). Channel 2 is scaled
                in place.

        Returns:
            a tensor with every channel clamped to [0, 1].

        """
        scale = torch.empty(hsv.shape[0], 1, 1, 1, device=hsv.device)
        scale.uniform_(*self._brightness)
        hsv[:, :, 2, :, :] *= scale
        return hsv.clamp(0, 1)

    def _transform_contrast(self, rgb: torch.Tensor) -> torch.Tensor:
        """Scale each image's deviation from its mean by a random factor.

        Args:
            rgb: 5-D tensor; the mean is taken over the last two
                dimensions separately for every leading index.

        Returns:
            a new tensor ``(rgb - mean) * scale + mean`` clamped to [0, 1].

        """
        scale = torch.empty(rgb.shape[0], 1, 1, 1, 1, device=rgb.device)
        scale.uniform_(*self._contrast)
        means = rgb.mean(dim=(3, 4), keepdim=True)
        # The mean must be added back after scaling the deviation, so that
        # a factor of 1.0 leaves the image unchanged. The previous
        # expression multiplied by ``scale + means`` instead.
        return ((rgb - means) * scale + means).clamp(0, 1)

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Return the constructor arguments as a dictionary.

        Args:
            deep: accepted for interface compatibility; not used here.

        Returns:
            dictionary with ``brightness``, ``contrast``, ``saturation``
            and ``hue`` entries.

        """
        return {
            "brightness": self._brightness,
            "contrast": self._contrast,
            "saturation": self._saturation,
            "hue": self._hue,
        }