import torch
from .base import Augmentation
class SingleAmplitudeScaling(Augmentation):
    """ Single Amplitude Scaling augmentation.

    Scales the input by one random factor shared by all of its elements.

    .. math::

        x' = z \\cdot x

    where :math:`z \\sim \\text{Unif}(minimum, maximum)` is a scalar.

    References:
        * `Laskin et al., Reinforcement Learning with Augmented Data.
          <https://arxiv.org/abs/2004.14990>`_

    Args:
        minimum (float): minimum amplitude scale.
        maximum (float): maximum amplitude scale.

    Attributes:
        minimum (float): minimum amplitude scale.
        maximum (float): maximum amplitude scale.

    """

    def __init__(self, minimum=0.8, maximum=1.2):
        # Lower / upper bounds of the uniform distribution the scale is
        # drawn from. The defaults bracket 1.0, i.e. "no scaling".
        self.minimum = minimum
        self.maximum = maximum

    def get_type(self):
        """ Returns augmentation type.

        Returns:
            str: `single_amplitude_scaling`.

        """
        return 'single_amplitude_scaling'

    def get_params(self, deep=False):
        """ Returns augmentation parameters.

        Args:
            deep (bool): flag to deeply copy objects. Accepted for interface
                compatibility; it is not used by this implementation.

        Returns:
            dict: augmentation parameters (``minimum`` and ``maximum``).

        """
        return {'minimum': self.minimum, 'maximum': self.maximum}
class MultipleAmplitudeScaling(SingleAmplitudeScaling):
    """ Multiple Amplitude Scaling augmentation.

    Like :class:`SingleAmplitudeScaling`, but each element of the input gets
    its own independently drawn scale factor.

    .. math::

        x' = z \\odot x

    where :math:`z \\sim \\text{Unif}(minimum, maximum)` and :math:`z`
    is a vector holding a different amplitude scale for each element of
    :math:`x` (:math:`\\odot` denotes element-wise multiplication).

    References:
        * `Laskin et al., Reinforcement Learning with Augmented Data.
          <https://arxiv.org/abs/2004.14990>`_

    Args:
        minimum (float): minimum amplitude scale.
        maximum (float): maximum amplitude scale.

    Attributes:
        minimum (float): minimum amplitude scale.
        maximum (float): maximum amplitude scale.

    """

    def get_type(self):
        """ Returns augmentation type.

        Returns:
            str: `multiple_amplitude_scaling`.

        """
        return 'multiple_amplitude_scaling'