# Source code for d3rlpy.augmentation.pipeline

from abc import ABCMeta, abstractmethod


class AugmentationPipeline(metaclass=ABCMeta):
    """ Base class of augmentation pipelines.

    A pipeline holds an ordered list of augmentations and applies them in
    sequence via ``transform``. Subclasses define how augmented inputs are
    consumed by implementing ``process`` and ``get_params``.

    Args:
        augmentations (list(d3rlpy.augmentation.base.Augmentation)):
            augmentations to apply, in order. The sequence is copied, so
            ``append`` never mutates the caller's list. ``None`` is treated
            as an empty list.

    """

    def __init__(self, augmentations):
        # Copy instead of aliasing: otherwise ``append`` would silently grow
        # the list the caller passed in, and tuples would break ``append``.
        if augmentations is None:
            augmentations = []
        self.augmentations = list(augmentations)

    def append(self, augmentation):
        """ Append augmentation to pipeline.

        Args:
            augmentation (d3rlpy.augmentation.base.Augmentation): augmentation.

        """
        self.augmentations.append(augmentation)

    def get_augmentation_types(self):
        """ Returns augmentation types.

        Returns:
            list(str): list of augmentation types, in pipeline order.

        """
        return [aug.get_type() for aug in self.augmentations]

    def get_augmentation_params(self):
        """ Returns augmentation parameters.

        Returns:
            list(dict): list of augmentation parameters, in pipeline order.

        """
        return [aug.get_params() for aug in self.augmentations]

    @abstractmethod
    def get_params(self, deep=False):
        """ Returns pipeline parameters.

        Args:
            deep (bool): flag to deeply copy objects.

        Returns:
            dict: pipeline parameters.

        """
        pass

    def transform(self, x):
        """ Returns observation processed by all augmentations.

        Augmentations are applied in order, each one receiving the output of
        the previous one. With no augmentations, ``x`` is returned unchanged.

        Args:
            x (torch.Tensor): observation tensor.

        Returns:
            torch.Tensor: processed observation tensor.

        """
        for augmentation in self.augmentations:
            x = augmentation.transform(x)
        return x

    @abstractmethod
    def process(self, func, inputs, targets):
        """ Runs a given function while augmenting inputs.

        Args:
            func (callable): function to compute.
            inputs (dict): keyword arguments passed to ``func``.
            targets (list(str)): names of the arguments in ``inputs`` to
                augment.

        Returns:
            torch.Tensor: the computation result.

        """
        pass


class DrQPipeline(AugmentationPipeline):
    """ Data-regularized Q augmentation pipeline.

    ``process`` evaluates ``func`` ``n_mean`` times. Before each evaluation
    it re-applies the pipeline's augmentations to the target inputs, and it
    returns the average of the results.

    References:
        * `Kostrikov et al., Image Augmentation Is All You Need: Regularizing
          Deep Reinforcement Learning from Pixels.
          <https://arxiv.org/abs/2004.13649>`_

    Args:
        augmentations (list(d3rlpy.augmentation.base.Augmentation or str)):
            list of augmentations or augmentation types.
        n_mean (int): the number of computations to average.

    Attributes:
        augmentations (list(d3rlpy.augmentation.base.Augmentation)):
            list of augmentations.
        n_mean (int): the number of computations to average.

    Raises:
        ValueError: if ``n_mean`` is smaller than 1.

    """

    def __init__(self, augmentations=None, n_mean=1):
        if augmentations is None:
            augmentations = []
        # NOTE(review): nothing here converts augmentation type strings into
        # augmentation objects; presumably the caller resolves them before
        # construction -- confirm, since get_augmentation_types() calls
        # get_type() on every entry.
        if n_mean < 1:
            # n_mean == 0 would divide by zero in process(), and a negative
            # value would return 0.0 / n_mean without ever calling func.
            raise ValueError('n_mean must be >= 1, got {}'.format(n_mean))
        super().__init__(augmentations)
        self.n_mean = n_mean

    def get_params(self, deep=False):
        """ Returns pipeline parameters.

        Args:
            deep (bool): flag to deeply copy objects (unused here).

        Returns:
            dict: pipeline parameters (``n_mean``).

        """
        return {'n_mean': self.n_mean}

    def process(self, func, inputs, targets):
        """ Returns the mean of ``func`` over ``n_mean`` augmented evaluations.

        Args:
            func (callable): function to compute.
            inputs (dict): keyword arguments passed to ``func``.
            targets (list(str)): names of the arguments in ``inputs`` that
                are passed through ``transform`` before each evaluation.

        Returns:
            the average of the ``n_mean`` results of ``func``.

        """
        ret = 0.0
        for _ in range(self.n_mean):
            # shallow copy so the caller's ``inputs`` dict is never mutated
            kwargs = dict(inputs)
            for target in targets:
                kwargs[target] = self.transform(kwargs[target])
            ret += func(**kwargs)
        return ret / self.n_mean