# Source code for d3rlpy.augmentation.vector

from typing import Any, ClassVar, Dict

import torch

from .base import Augmentation


class SingleAmplitudeScaling(Augmentation):
    r"""Single Amplitude Scaling augmentation.

    .. math::

        x' = x \cdot z

    where :math:`z \sim \text{Unif}(minimum, maximum)` is drawn once per
    sample (first dimension of ``x``) and shared by every element of that
    sample.

    References:
        * `Laskin et al., Reinforcement Learning with Augmented Data.
          <https://arxiv.org/abs/2004.14990>`_

    Args:
        minimum (float): minimum amplitude scale.
        maximum (float): maximum amplitude scale.

    """

    TYPE: ClassVar[str] = "single_amplitude_scaling"
    _minimum: float
    _maximum: float

    def __init__(self, minimum: float = 0.8, maximum: float = 1.2):
        self._minimum = minimum
        self._maximum = maximum

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply every sample in the batch by its own random scale.

        Args:
            x: batch tensor whose first dimension indexes samples; it must
                have at least one dimension.

        Returns:
            ``x * z``, with the same shape as ``x``.

        """
        # One factor per sample, shaped (batch, 1, ..., 1) so it broadcasts
        # over all non-batch dimensions whatever the rank of ``x``. For 2-D
        # input this is (batch, 1), exactly as before.
        shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        # Sample in x's own floating dtype to avoid upcasting, e.g., float16
        # inputs; uniform_ is unsupported on integer tensors, so fall back to
        # the default dtype for non-floating inputs.
        dtype = x.dtype if x.is_floating_point() else None
        z = torch.empty(shape, dtype=dtype, device=x.device)
        z.uniform_(self._minimum, self._maximum)
        return x * z

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Return the constructor arguments of this augmentation.

        Args:
            deep: unused; kept for interface compatibility.

        Returns:
            dict with the ``minimum`` and ``maximum`` amplitude scales.

        """
        return {"minimum": self._minimum, "maximum": self._maximum}
class MultipleAmplitudeScaling(SingleAmplitudeScaling):
    r"""Multiple Amplitude Scaling augmentation.

    .. math::

        x' = x \cdot z

    where :math:`z \sim \text{Unif}(minimum, maximum)` and :math:`z` has the
    same shape as :math:`x`, so every element receives an independently
    drawn amplitude scale.

    References:
        * `Laskin et al., Reinforcement Learning with Augmented Data.
          <https://arxiv.org/abs/2004.14990>`_

    Args:
        minimum (float): minimum amplitude scale.
        maximum (float): maximum amplitude scale.

    """

    TYPE: ClassVar[str] = "multiple_amplitude_scaling"

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply every element of ``x`` by its own random scale.

        Args:
            x: tensor to augment.

        Returns:
            ``x * z``, with the same shape as ``x``.

        """
        # Sample in x's own floating dtype to avoid upcasting, e.g., float16
        # inputs; uniform_ is unsupported on integer tensors, so fall back to
        # the default dtype for non-floating inputs.
        dtype = x.dtype if x.is_floating_point() else None
        z = torch.empty(x.shape, dtype=dtype, device=x.device)
        z.uniform_(self._minimum, self._maximum)
        return x * z