# Source code for d3rlpy.dynamics.base

import numpy as np

from abc import abstractmethod
from ..base import ImplBase, LearnableBase
from ..dataset import Transition, TransitionMiniBatch


class DynamicsImplBase(ImplBase):
    """ Interface for the backend implementation of a dynamics model.

    ``DynamicsBase`` delegates to these methods through its ``impl``
    attribute.

    """

    @abstractmethod
    def predict(self, x, action, with_variance=False):
        """ Returns predicted observations, rewards and variances.

        Args:
            x: batch of observations.
            action: batch of actions.
            with_variance (bool): prediction-variance flag. Defaults to
                ``False`` so that the two-argument call made by
                ``DynamicsBase.predict`` matches this signature.

        Returns:
            tuple: ``(observations, rewards, variances)`` -- the caller
            ``DynamicsBase.predict`` always unpacks three values.

        """
        # NOTE(review): confirm whether concrete implementations consult
        # ``with_variance``; the visible caller never passes it.

    @abstractmethod
    def generate(self, x, action):
        """ Returns one predicted rollout step.

        Args:
            x: batch of observations.
            action: batch of actions.

        Returns:
            tuple: ``(next_observations, next_rewards)``. ``next_rewards`` is
            indexed as ``next_rewards[i][0]`` by ``DynamicsBase.generate``.

        """


class DynamicsBase(LearnableBase):
    """ Base class for learnable dynamics models.

    Wraps an implementation object (``self.impl``) and exposes prediction
    and model-based rollout generation on top of it.

    Args:
        batch_size (int): forwarded to ``LearnableBase``.
        n_frames (int): forwarded to ``LearnableBase``.
        n_transitions (int): number of start transitions sampled per
            :meth:`generate` call, i.e. the number of parallel rollouts.
        horizon (int): number of model steps rolled out per
            :meth:`generate` call.
        scaler: forwarded to ``LearnableBase``.
        augmentation: forwarded to ``LearnableBase``.
        use_gpu: forwarded to ``LearnableBase``.

    """

    def __init__(self, batch_size, n_frames, n_transitions, horizon, scaler,
                 augmentation, use_gpu):
        super().__init__(batch_size, n_frames, scaler, augmentation, use_gpu)
        self.n_transitions = n_transitions
        self.horizon = horizon

    def predict(self, x, action, with_variance=False):
        """ Returns predicted observation and reward.

        Args:
            x (numpy.ndarray): observation
            action (numpy.ndarray): action
            with_variance (bool): flag to return prediction variance.

        Returns:
            tuple: ``(observations, rewards)``, or
            ``(observations, rewards, variances)`` when ``with_variance`` is
            ``True``.

        """
        # The implementation always returns the variance; it is dropped here
        # unless the caller asked for it.
        observations, rewards, variances = self.impl.predict(x, action)
        if with_variance:
            return observations, rewards, variances
        return observations, rewards

    def generate(self, algo, transitions):
        """ Returns new transitions for data augmentation.

        Samples ``n_transitions`` start transitions uniformly (with
        replacement) from ``transitions`` and rolls the learned model out for
        ``horizon`` steps, choosing actions with ``algo.sample_action``. Each
        step contributes ``n_transitions`` transitions, so the result holds
        ``n_transitions * horizon`` transitions ordered step by step. All
        generated transitions are marked non-terminal.

        Args:
            algo (d3rlpy.algos.base.AlgoBase): algorithm.
            transitions (list(d3rlpy.dataset.Transition)): list of transitions.

        Returns:
            list(d3rlpy.dataset.Transition): list of generated transitions.

        Raises:
            ValueError: if ``transitions`` is empty.

        """
        # Fail with a clear message instead of numpy's "high <= 0" error.
        if len(transitions) == 0:
            raise ValueError('transitions must not be empty.')

        # uniformly sample start transitions (with replacement)
        indices = np.random.randint(len(transitions), size=self.n_transitions)
        init_transitions = [transitions[i] for i in indices]

        observation_shape = transitions[0].get_observation_shape()
        action_size = transitions[0].get_action_size()

        rets = []

        # rollout
        batch = TransitionMiniBatch(init_transitions)
        observations = batch.observations
        actions = algo.sample_action(observations)
        # the first generated transitions carry the sampled dataset rewards
        rewards = batch.rewards
        for _ in range(self.horizon):
            # predict next state and reward with the learned model
            next_observations, next_rewards = self.impl.generate(
                observations, actions)

            # sample policy action at the predicted state
            next_actions = algo.sample_action(next_observations)

            # append one new transition per parallel rollout
            for i in range(self.n_transitions):
                rets.append(
                    Transition(observation_shape=observation_shape,
                               action_size=action_size,
                               observation=observations[i],
                               action=actions[i],
                               reward=float(rewards[i][0]),
                               next_observation=next_observations[i],
                               next_action=next_actions[i],
                               next_reward=float(next_rewards[i][0]),
                               terminal=0.0))

            # the predicted step becomes the start of the next step
            observations = next_observations.copy()
            actions = next_actions.copy()
            rewards = next_rewards.copy()

        return rets