Source code for d3rlpy.online.buffers

from abc import ABCMeta, abstractmethod
from typing import (
    Callable,
    Generic,
    Iterator,
    List,
    Optional,
    TypeVar,
    Sequence,
    cast,
)

import numpy as np
import gym

from ..envs import BatchEnv
from ..dataset import (
    Episode,
    MDPDataset,
    Transition,
    TransitionMiniBatch,
    trace_back_and_clear,
)
from .utility import get_action_size_from_env

T = TypeVar("T")


class FIFOQueue(Generic[T]):
    """Simple FIFO queue implementation.

    Random access to this queue is O(1).

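    A minimal usage sketch (the string payload and the printing callback are
    illustrative, not part of the library):

    .. code-block:: python

        queue: FIFOQueue[str] = FIFOQueue(
            maxlen=3, drop_callback=lambda item: print("dropped", item)
        )

        for item in ["a", "b", "c", "d", "e"]:
            queue.append(item)  # "a" and "b" are dropped via the callback

        len(queue)  # == 3
        queue[0]    # O(1) random access into the underlying list
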
    """

    _maxlen: int
    _drop_callback: Optional[Callable[[T], None]]
    _buffer: List[Optional[T]]
    _cursor: int
    _size: int
    _index: int

    def __init__(
        self, maxlen: int, drop_callback: Optional[Callable[[T], None]] = None
    ):
        self._maxlen = maxlen
        self._drop_callback = drop_callback
        self._buffer = [None for _ in range(maxlen)]
        self._cursor = 0
        self._size = 0
        self._index = 0

    def append(self, item: T) -> None:
        # call drop callback if necessary
        cur_item = self._buffer[self._cursor]
        if cur_item and self._drop_callback:
            self._drop_callback(cur_item)

        self._buffer[self._cursor] = item

        # increment cursor
        self._cursor += 1
        if self._cursor == self._maxlen:
            self._cursor = 0
        self._size = min(self._size + 1, self._maxlen)

    def __getitem__(self, index: int) -> T:
        assert index < self._size

        # handle negative indexing
        if index < 0:
            index = self._size + index

        item = self._buffer[index]
        assert item is not None
        return item

    def __len__(self) -> int:
        return self._size

    def __iter__(self) -> Iterator[T]:
        self._index = 0
        return self

    def __next__(self) -> T:
        if self._index >= self._size:
            raise StopIteration
        item = self._buffer[self._index]
        assert item is not None
        self._index += 1
        return item


class _Buffer(metaclass=ABCMeta):

    _transitions: FIFOQueue[Transition]
    _observation_shape: Sequence[int]
    _action_size: int
    _create_mask: bool
    _mask_size: int

    def __init__(
        self,
        maxlen: int,
        env: Optional[gym.Env] = None,
        episodes: Optional[List[Episode]] = None,
        create_mask: bool = False,
        mask_size: int = 1,
    ):
        def drop_callback(transition: Transition) -> None:
            # remove links when dropping the last transition
            if transition.next_transition is None:
                trace_back_and_clear(transition)

        self._transitions = FIFOQueue(maxlen, drop_callback)

        # extract shape information
        if env:
            observation_shape = env.observation_space.shape
            action_size = get_action_size_from_env(env)
        elif episodes:
            observation_shape = episodes[0].get_observation_shape()
            action_size = episodes[0].get_action_size()
        else:
            raise ValueError("env or episodes are required to determine shape.")

        self._observation_shape = observation_shape
        self._action_size = action_size
        self._create_mask = create_mask
        self._mask_size = mask_size

        # add initial transitions
        if episodes:
            for episode in episodes:
                self.append_episode(episode)

    def append_episode(self, episode: Episode) -> None:
        """Append Episode object to buffer.

        Args:
            episode: episode.

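        A minimal sketch, assuming ``dataset`` is an existing
        ``d3rlpy.dataset.MDPDataset`` whose shapes match this buffer:

        .. code-block:: python

            for episode in dataset.episodes:
                buffer.append_episode(episode)
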
        """
        assert episode.get_observation_shape() == self._observation_shape
        assert episode.get_action_size() == self._action_size
        for transition in episode.transitions:
            self._transitions.append(transition)
            # add mask if necessary
            if self._create_mask and transition.mask is None:
                transition.mask = np.random.randint(2, size=self._mask_size)

    @abstractmethod
    def sample(
        self,
        batch_size: int,
        n_frames: int = 1,
        n_steps: int = 1,
        gamma: float = 0.99,
    ) -> TransitionMiniBatch:
        """Returns sampled mini-batch of transitions.

        If observation is image, you can stack arbitrary frames via
        ``n_frames``.

        .. code-block:: python

            buffer.observation_shape == (3, 84, 84)

            # stack 4 frames
            batch = buffer.sample(batch_size=32, n_frames=4)

            batch.observations.shape == (32, 12, 84, 84)

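        ``n_steps`` and ``gamma`` control N-step return computation in the
        sampled mini-batch. A sketch with illustrative values:

        .. code-block:: python

            # 3-step returns discounted with gamma=0.99
            batch = buffer.sample(batch_size=32, n_steps=3, gamma=0.99)
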
        Args:
            batch_size: mini-batch size.
            n_frames: the number of frames to stack for image observation.
            n_steps: the number of steps before the next observation.
            gamma: discount factor used in N-step return calculation.

        Returns:
            mini-batch.

        """

    def size(self) -> int:
        """Returns the number of appended elements in buffer.

        Returns:
            the number of elements in buffer.

        """
        return len(self._transitions)

    def to_mdp_dataset(self) -> MDPDataset:
        """Convert replay data into static dataset.

        The length of the dataset can be longer than the length of the replay
        buffer because this conversion is done by tracing ``Transition``
        objects.

        Returns:
            MDPDataset object.

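        A minimal sketch (the save path is illustrative):

        .. code-block:: python

            dataset = buffer.to_mdp_dataset()

            # e.g. persist the collected experience for offline training
            dataset.dump("online_experience.h5")
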
        """
        # get the last transitions
        tail_transitions: List[Transition] = []
        for transition in self._transitions:
            if transition.next_transition is None:
                tail_transitions.append(transition)

        observations = []
        actions = []
        rewards = []
        terminals = []
        episode_terminals = []
        for transition in tail_transitions:

            # trace transition to the beginning
            episode_transitions: List[Transition] = []
            while True:
                episode_transitions.append(transition)
                if transition.prev_transition is None:
                    break
                transition = transition.prev_transition
            episode_transitions.reverse()

            # stack data
            for episode_transition in episode_transitions:
                observations.append(episode_transition.observation)
                actions.append(episode_transition.action)
                rewards.append(episode_transition.reward)
                terminals.append(0.0)
                episode_terminals.append(0.0)
            observations.append(episode_transitions[-1].next_observation)
            actions.append(episode_transitions[-1].next_action)
            rewards.append(episode_transitions[-1].next_reward)
            terminals.append(episode_transitions[-1].terminal)
            episode_terminals.append(1.0)

        if len(self._observation_shape) == 3:
            observations = np.asarray(observations, dtype=np.uint8)
        else:
            observations = np.asarray(observations, dtype=np.float32)

        return MDPDataset(
            observations=observations,
            actions=actions,
            rewards=rewards,
            terminals=terminals,
            episode_terminals=episode_terminals,
            create_mask=self._create_mask,
            mask_size=self._mask_size,
        )

    def __len__(self) -> int:
        return self.size()

    @property
    def transitions(self) -> FIFOQueue[Transition]:
        """Returns a FIFO queue of transitions.

        Returns:
            d3rlpy.online.buffers.FIFOQueue: FIFO queue of transitions.

        """
        return self._transitions


class Buffer(_Buffer):
    @abstractmethod
    def append(
        self,
        observation: np.ndarray,
        action: np.ndarray,
        reward: float,
        terminal: float,
        clip_episode: Optional[bool] = None,
    ) -> None:
        """Append observation, action, reward and terminal flag to buffer.

        If the terminal flag is True, Monte-Carlo returns will be computed with
        an entire episode and the whole transitions will be appended.

        Args:
            observation: observation.
            action: action.
            reward: reward.
            terminal: terminal flag.
            clip_episode: flag to clip the current episode. If ``None``, the
                episode is clipped based on ``terminal``.

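        A minimal usage sketch with the concrete :class:`ReplayBuffer`
        implementation (the environment and loop length are illustrative):

        .. code-block:: python

            env = gym.make("CartPole-v0")
            buffer = ReplayBuffer(maxlen=100000, env=env)

            observation = env.reset()
            for _ in range(1000):
                action = env.action_space.sample()
                next_observation, reward, terminal, _ = env.step(action)
                buffer.append(observation, action, reward, terminal)
                observation = env.reset() if terminal else next_observation
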
        """


class BatchBuffer(_Buffer):
    @abstractmethod
    def append(
        self,
        observations: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        terminals: np.ndarray,
        clip_episodes: Optional[np.ndarray] = None,
    ) -> None:
        """Append observation, action, reward and terminal flag to buffer.

        If the terminal flag is True, Monte-Carlo returns will be computed with
        an entire episode and the whole transitions will be appended.

        Args:
            observations: observation.
            actions: action.
            rewards: reward.
            terminals: terminal flag.
            clip_episodes: flag to clip the current episode. If ``None``, the
                episode is clipped based on ``terminal``.

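        A minimal usage sketch with the concrete :class:`BatchReplayBuffer`,
        assuming ``env`` is a ``d3rlpy.envs.BatchEnv`` wrapping four
        ``CartPole-v0`` environments (construction omitted; the random action
        sampling is illustrative):

        .. code-block:: python

            buffer = BatchReplayBuffer(maxlen=100000, env=env)

            observations = env.reset()
            for _ in range(1000):
                # CartPole-v0 has two discrete actions
                actions = np.random.randint(2, size=len(env))
                next_observations, rewards, terminals, _ = env.step(actions)
                buffer.append(
                    observations,
                    actions,
                    np.asarray(rewards, dtype=np.float32),
                    np.asarray(terminals, dtype=np.float32),
                )
                observations = next_observations
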
        """


class BasicSampleMixin:

    _transitions: FIFOQueue[Transition]

    def sample(
        self,
        batch_size: int,
        n_frames: int = 1,
        n_steps: int = 1,
        gamma: float = 0.99,
    ) -> TransitionMiniBatch:
        indices = np.random.choice(len(self._transitions), batch_size)
        transitions = [self._transitions[index] for index in indices]
        batch = TransitionMiniBatch(transitions, n_frames, n_steps, gamma)
        return batch


class ReplayBuffer(BasicSampleMixin, Buffer):
    """Standard Replay Buffer.

    Args:
        maxlen (int): the maximum number of stored transitions.
        env (gym.Env): gym-like environment to extract shape information.
        episodes (list(d3rlpy.dataset.Episode)): list of episodes to
            initialize buffer.
        create_mask (bool): flag to create bootstrapping mask.
        mask_size (int): ensemble size for binary mask.

    """

    _prev_observation: Optional[np.ndarray]
    _prev_action: Optional[np.ndarray]
    _prev_reward: float
    _prev_transition: Optional[Transition]

    def __init__(
        self,
        maxlen: int,
        env: Optional[gym.Env] = None,
        episodes: Optional[List[Episode]] = None,
        create_mask: bool = False,
        mask_size: int = 1,
    ):
        super().__init__(maxlen, env, episodes, create_mask, mask_size)
        self._prev_observation = None
        self._prev_action = None
        self._prev_reward = 0.0
        self._prev_transition = None

    def append(
        self,
        observation: np.ndarray,
        action: np.ndarray,
        reward: float,
        terminal: float,
        clip_episode: Optional[bool] = None,
    ) -> None:
        # if None, use terminal
        if clip_episode is None:
            clip_episode = bool(terminal)

        # validation
        assert observation.shape == self._observation_shape
        if isinstance(action, np.ndarray):
            assert action.shape[0] == self._action_size
        else:
            action = int(action)
            assert action < self._action_size
        # not allow terminal=True and clip_episode=False
        assert not (terminal and not clip_episode)

        # create Transition object
        if self._prev_observation is not None:
            if isinstance(terminal, bool):
                terminal = 1.0 if terminal else 0.0

            # create binary mask
            if self._create_mask:
                mask = np.random.randint(2, size=self._mask_size)
            else:
                mask = None

            transition = Transition(
                observation_shape=self._observation_shape,
                action_size=self._action_size,
                observation=self._prev_observation,
                action=self._prev_action,
                reward=self._prev_reward,
                next_observation=observation,
                next_action=action,
                next_reward=reward,
                terminal=terminal,
                mask=mask,
                prev_transition=self._prev_transition,
            )

            if self._prev_transition:
                self._prev_transition.next_transition = transition

            self._transitions.append(transition)

            self._prev_transition = transition

        self._prev_observation = observation
        self._prev_action = action
        self._prev_reward = reward

        if clip_episode:
            self._prev_observation = None
            self._prev_action = None
            self._prev_reward = 0.0
            self._prev_transition = None


class BatchReplayBuffer(BasicSampleMixin, BatchBuffer):
    """Standard Replay Buffer for batch training.

    Args:
        maxlen (int): the maximum number of stored transitions.
        n_envs (int): the number of environments.
        env (gym.Env): gym-like environment to extract shape information.
        episodes (list(d3rlpy.dataset.Episode)): list of episodes to
            initialize buffer.
        create_mask (bool): flag to create bootstrapping mask.
        mask_size (int): ensemble size for binary mask.

    """

    _n_envs: int
    _prev_observations: List[Optional[np.ndarray]]
    _prev_actions: List[Optional[np.ndarray]]
    _prev_rewards: List[Optional[np.ndarray]]
    _prev_transitions: List[Optional[Transition]]

    def __init__(
        self,
        maxlen: int,
        env: BatchEnv,
        episodes: Optional[List[Episode]] = None,
        create_mask: bool = False,
        mask_size: int = 1,
    ):
        super().__init__(maxlen, env, episodes, create_mask, mask_size)
        self._n_envs = len(env)
        self._prev_observations = [None for _ in range(len(env))]
        self._prev_actions = [None for _ in range(len(env))]
        self._prev_rewards = [None for _ in range(len(env))]
        self._prev_transitions = [None for _ in range(len(env))]

    def append(
        self,
        observations: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        terminals: np.ndarray,
        clip_episodes: Optional[np.ndarray] = None,
    ) -> None:
        # if None, use terminals
        if clip_episodes is None:
            clip_episodes = terminals

        # validation
        assert observations.shape == (self._n_envs, *self._observation_shape)
        if actions.ndim == 2:
            assert actions.shape == (self._n_envs, self._action_size)
        else:
            assert actions.shape == (self._n_envs,)
        assert rewards.shape == (self._n_envs,)
        assert terminals.shape == (self._n_envs,)
        # not allow terminal=True and clip_episode=False
        assert np.all(terminals - clip_episodes < 1)

        # create Transition objects
        for i in range(self._n_envs):
            if self._prev_observations[i] is not None:
                prev_observation = self._prev_observations[i]
                prev_action = self._prev_actions[i]
                prev_reward = cast(np.ndarray, self._prev_rewards[i])
                prev_transition = self._prev_transitions[i]

                # create binary mask
                if self._create_mask:
                    mask = np.random.randint(2, size=self._mask_size)
                else:
                    mask = None

                transition = Transition(
                    observation_shape=self._observation_shape,
                    action_size=self._action_size,
                    observation=prev_observation,
                    action=prev_action,
                    reward=float(prev_reward),
                    next_observation=observations[i],
                    next_action=actions[i],
                    next_reward=float(rewards[i]),
                    terminal=float(terminals[i]),
                    mask=mask,
                    prev_transition=prev_transition,
                )

                if prev_transition:
                    prev_transition.next_transition = transition

                self._transitions.append(transition)

                self._prev_transitions[i] = transition

            self._prev_observations[i] = observations[i]
            self._prev_actions[i] = actions[i]
            self._prev_rewards[i] = rewards[i]

            if clip_episodes[i]:
                self._prev_observations[i] = None
                self._prev_actions[i] = None
                self._prev_rewards[i] = None
                self._prev_transitions[i] = None