from typing import Any, ClassVar, Dict
import torch
from .base import Augmentation
class SingleAmplitudeScaling(Augmentation):
    r"""Single Amplitude Scaling augmentation.

    Scales the input by a single random amplitude scale :math:`z`:

    .. math::

        x' = x \cdot z

    where :math:`z \sim \text{Unif}(\text{minimum}, \text{maximum})`.

    References:
        * `Laskin et al., Reinforcement Learning with Augmented Data.
          <https://arxiv.org/abs/2004.14990>`_

    Args:
        minimum (float): minimum amplitude scale.
        maximum (float): maximum amplitude scale.

    """

    TYPE: ClassVar[str] = "single_amplitude_scaling"

    # Lower / upper bounds of the uniform distribution that the amplitude
    # scale z is sampled from.
    _minimum: float
    _maximum: float

    def __init__(self, minimum: float = 0.8, maximum: float = 1.2):
        self._minimum = minimum
        self._maximum = maximum

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Return the constructor arguments of this augmentation.

        Args:
            deep: accepted for interface compatibility; it is ignored here.

        Returns:
            dict with the ``"minimum"`` and ``"maximum"`` amplitude scales
            passed to the constructor.

        """
        return {"minimum": self._minimum, "maximum": self._maximum}
class MultipleAmplitudeScaling(SingleAmplitudeScaling):
    r"""Multiple Amplitude Scaling augmentation.

    Unlike the single variant, :math:`z` is a vector holding a separately
    sampled amplitude scale for each dimension of the input:

    .. math::

        x' = x \odot z

    where :math:`\odot` denotes element-wise multiplication and every entry
    of :math:`z` is drawn from
    :math:`\text{Unif}(\text{minimum}, \text{maximum})`.

    References:
        * `Laskin et al., Reinforcement Learning with Augmented Data.
          <https://arxiv.org/abs/2004.14990>`_

    Args:
        minimum (float): minimum amplitude scale.
        maximum (float): maximum amplitude scale.

    """

    # The constructor and get_params are inherited unchanged from
    # SingleAmplitudeScaling; only the type identifier differs here.
    TYPE: ClassVar[str] = "multiple_amplitude_scaling"