# Source code for d3rlpy.algos.base

from ..base import ImplBase, LearnableBase
from abc import ABCMeta, abstractmethod


class AlgoImplBase(ImplBase):
    """ Interface of the implementation object used by algorithms.

    All methods below are declared abstract and carry no behavior of their
    own; concrete implementations are expected to override them.
    ``AlgoBase`` forwards its public API to an object with this interface.

    NOTE(review): ``abstractmethod`` is only enforced at instantiation time
    if ``ImplBase`` uses an ABC metaclass -- confirm in ``..base``.

    """

    @abstractmethod
    def build(self):
        """ Builds the implementation.

        What exactly is constructed is defined by the subclass.

        """

    @abstractmethod
    def save_policy(self, fname, as_onnx):
        """ Saves the policy to a file.

        Invoked by ``AlgoBase.save_policy``.

        Args:
            fname (str): destination file path.
            as_onnx (bool): flag to save as ONNX format.

        """

    @abstractmethod
    def predict_best_action(self, x):
        """ Returns the greedy actions for the given observations.

        Invoked by ``AlgoBase.predict``.

        Args:
            x (numpy.ndarray): observations.

        """

    @abstractmethod
    def predict_value(self, x, action, with_std):
        """ Returns the predicted action-values.

        Invoked by ``AlgoBase.predict_value``.

        Args:
            x (numpy.ndarray): observations.
            action (numpy.ndarray): actions.
            with_std (bool): flag to also return the standard deviation of
                the estimation.

        """

    @abstractmethod
    def sample_action(self, x):
        """ Returns actions sampled for the given observations.

        Invoked by ``AlgoBase.sample_action``.

        Args:
            x (numpy.ndarray): observations.

        """


class AlgoBase(LearnableBase):
    """ Base class of algorithms that act through an implementation object.

    The public prediction and export methods forward to ``self.impl``.
    This class never assigns ``impl`` itself (presumably it is provided by
    ``LearnableBase`` or by a subclass -- confirm).

    Attributes:
        dynamics: the object given to the constructor. When it is truthy,
            ``_generate_new_data`` calls its ``generate`` method.

    """

    def __init__(
            self,
            n_epochs,
            batch_size,
            n_frames,
            scaler,
            augmentation,
            dynamics,
            use_gpu):
        """ Initializes the base class and remembers the dynamics object.

        Every argument except ``dynamics`` is handed over unchanged to
        ``LearnableBase.__init__``.

        Args:
            n_epochs: forwarded to the parent constructor.
            batch_size: forwarded to the parent constructor.
            n_frames: forwarded to the parent constructor.
            scaler: forwarded to the parent constructor.
            augmentation: forwarded to the parent constructor.
            dynamics: stored as ``self.dynamics``; a falsy value disables
                data generation in ``_generate_new_data``.
            use_gpu: forwarded to the parent constructor.

        """
        super().__init__(
            n_epochs,
            batch_size,
            n_frames,
            scaler,
            augmentation,
            use_gpu,
        )
        self.dynamics = dynamics

    def save_policy(self, fname, as_onnx=False):
        """ Exports the greedy policy as a TorchScript or ONNX file.

        .. code-block:: python

            # export as TorchScript
            algo.save_policy('policy.pt')

            # export as ONNX
            algo.save_policy('policy.onnx', as_onnx=True)

        The exported artifacts can be used without d3rlpy, which makes this
        method handy for deploying a learned policy to production
        environments or embedded systems.

        See also

            * https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html (for Python).
            * https://pytorch.org/tutorials/advanced/cpp_export.html (for C++).
            * https://onnx.ai (for ONNX)

        Args:
            fname (str): destination file path.
            as_onnx (bool): flag to save as ONNX format.

        """
        impl = self.impl
        # the result of the implementation call is deliberately not returned
        impl.save_policy(fname, as_onnx)

    def predict(self, x):
        """ Computes greedy actions for the given observations.

        .. code-block:: python

            # 100 observations with shape of (10,)
            x = np.random.random((100, 10))

            actions = algo.predict(x)
            # actions.shape == (100, action size) for continuous control
            # actions.shape == (100,) for discrete control

        Args:
            x (numpy.ndarray): observations

        Returns:
            numpy.ndarray: greedy actions

        """
        greedy_actions = self.impl.predict_best_action(x)
        return greedy_actions

    def predict_value(self, x, action, with_std=False):
        """ Computes predicted action-values for observation-action pairs.

        .. code-block:: python

            # 100 observations with shape of (10,)
            x = np.random.random((100, 10))

            # for continuous control
            # 100 actions with shape of (2,)
            actions = np.random.random((100, 2))

            # for discrete control
            # 100 actions in integer values
            actions = np.random.randint(2, size=100)

            values = algo.predict_value(x, actions)
            # values.shape == (100,)

            values, stds = algo.predict_value(x, actions, with_std=True)
            # stds.shape  == (100,)

        Args:
            x (numpy.ndarray): observations
            action (numpy.ndarray): actions
            with_std (bool): flag to return standard deviation of ensemble
                estimation. This deviation reflects uncertainty for the given
                observations. This uncertainty will be more accurate if you
                enable `bootstrap` flag and increase `n_critics` value.

        Returns:
            numpy.ndarray: predicted action-values. With ``with_std=True``
            the result is unpacked as ``(values, stds)``, as in the example
            above.

        """
        estimates = self.impl.predict_value(x, action, with_std)
        return estimates

    def sample_action(self, x):
        """ Draws actions for the given observations.

        For a deterministic policy the result is identical to what the
        `predict` method returns.

        Args:
            x (numpy.ndarray): observations.

        Returns:
            numpy.ndarray: sampled actions.

        """
        sampled = self.impl.sample_action(x)
        return sampled

    def _generate_new_data(self, transitions):
        """ Collects additional data produced by the dynamics object.

        Args:
            transitions: passed through to ``self.dynamics.generate``
                together with this algorithm instance.

        Returns:
            list: the items yielded by ``self.dynamics.generate``, or an
            empty list when ``self.dynamics`` is falsy.

        """
        dynamics = self.dynamics
        if not dynamics:
            return []
        # always hand back a fresh list, whatever iterable generate() gives
        return list(dynamics.generate(self, transitions))