"""Replay buffer for d3rlpy (module ``d3rlpy.dataset.replay_buffer``)."""

from typing import Any, BinaryIO, List, Optional, Sequence, Type, Union

import gym
import numpy as np

from .buffers import BufferProtocol, FIFOBuffer, InfiniteBuffer
from .components import (
    Episode,
    EpisodeBase,
    PartialTrajectory,
    Signature,
    Transition,
)
from .episode_generator import EpisodeGeneratorProtocol
from .io import dump, load
from .mini_batch import TrajectoryMiniBatch, TransitionMiniBatch
from .trajectory_slicers import BasicTrajectorySlicer, TrajectorySlicerProtocol
from .transition_pickers import BasicTransitionPicker, TransitionPickerProtocol
from .types import Observation
from .writers import (
    BasicWriterPreprocess,
    ExperienceWriter,
    WriterPreprocessProtocol,
)

__all__ = [
    "ReplayBuffer",
    "create_fifo_replay_buffer",
    "create_infinite_replay_buffer",
]


class ReplayBuffer:
    r"""Replay buffer for experience replay.

    This replay buffer implementation is used for both online and offline
    training in d3rlpy. To determine shapes of observations, actions and
    rewards, one of ``episodes``, ``env`` and signatures must be provided.
    Explicitly given signatures always take precedence; any signature left
    unspecified is derived from ``episodes`` (preferred) or ``env``.

    .. code-block::

        from d3rlpy.dataset import FIFOBuffer, ReplayBuffer, Signature

        buffer = FIFOBuffer(limit=1000000)

        # initialize with pre-collected episodes
        replay_buffer = ReplayBuffer(buffer=buffer, episodes=<episodes>)

        # initialize with Gym
        replay_buffer = ReplayBuffer(buffer=buffer, env=<env>)

        # initialize with manually specified signatures
        replay_buffer = ReplayBuffer(
            buffer=buffer,
            observation_signature=Signature(dtype=[<dtype>], shape=[<shape>]),
            action_signature=Signature(dtype=[<dtype>], shape=[<shape>]),
            reward_signature=Signature(dtype=[<dtype>], shape=[<shape>]),
        )

    Args:
        buffer (d3rlpy.dataset.BufferProtocol): Buffer implementation.
        transition_picker (Optional[d3rlpy.dataset.TransitionPickerProtocol]):
            Transition picker implementation for Q-learning-based algorithms.
            If ``None`` is given, ``BasicTransitionPicker`` is used by default.
        trajectory_slicer (Optional[d3rlpy.dataset.TrajectorySlicerProtocol]):
            Trajectory slicer implementation for Transformer-based algorithms.
            If ``None`` is given, ``BasicTrajectorySlicer`` is used by default.
        writer_preprocessor (Optional[d3rlpy.dataset.WriterPreprocessProtocol]):
            Writer preprocessor implementation. If ``None`` is given,
            ``BasicWriterPreprocess`` is used by default.
        episodes (Optional[Sequence[d3rlpy.dataset.EpisodeBase]]):
            List of episodes to initialize replay buffer.
        env (Optional[gym.Env]): Gym environment to extract shapes of
            observations and action.
        observation_signature (Optional[d3rlpy.dataset.Signature]):
            Signature of observation.
        action_signature (Optional[d3rlpy.dataset.Signature]):
            Signature of action.
        reward_signature (Optional[d3rlpy.dataset.Signature]):
            Signature of reward.
        cache_size (int): Size of cache to record active episode history used
            for online training. ``cache_size`` needs to be greater than the
            maximum possible episode length.

    Raises:
        ValueError: If some signature is not specified and neither
            ``episodes`` nor ``env`` is available to derive it from.
    """

    _buffer: BufferProtocol
    _transition_picker: TransitionPickerProtocol
    _trajectory_slicer: TrajectorySlicerProtocol
    _writer: ExperienceWriter

    def __init__(
        self,
        buffer: BufferProtocol,
        transition_picker: Optional[TransitionPickerProtocol] = None,
        trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
        writer_preprocessor: Optional[WriterPreprocessProtocol] = None,
        episodes: Optional[Sequence[EpisodeBase]] = None,
        env: Optional[gym.Env[np.ndarray, Any]] = None,
        observation_signature: Optional[Signature] = None,
        action_signature: Optional[Signature] = None,
        reward_signature: Optional[Signature] = None,
        cache_size: int = 10000,
    ):
        transition_picker = transition_picker or BasicTransitionPicker()
        trajectory_slicer = trajectory_slicer or BasicTrajectorySlicer()
        writer_preprocessor = writer_preprocessor or BasicWriterPreprocess()

        if not (
            observation_signature and action_signature and reward_signature
        ):
            if episodes:
                # Derive only the missing signatures from the first episode so
                # that explicitly supplied signatures are not discarded.
                first_episode = episodes[0]
                if not observation_signature:
                    observation_signature = first_episode.observation_signature
                if not action_signature:
                    action_signature = first_episode.action_signature
                if not reward_signature:
                    reward_signature = first_episode.reward_signature
            elif env:
                # Derive only the missing signatures from the env spaces.
                if not observation_signature:
                    observation_signature = Signature(
                        dtype=[env.observation_space.dtype],
                        shape=[env.observation_space.shape],  # type: ignore
                    )
                if not action_signature:
                    action_signature = Signature(
                        dtype=[env.action_space.dtype],
                        shape=[env.action_space.shape],  # type: ignore
                    )
                if not reward_signature:
                    # Rewards recorded from an env are scalar float32 values.
                    reward_signature = Signature(
                        dtype=[np.dtype(np.float32)],
                        shape=[[1]],
                    )
            else:
                raise ValueError(
                    "Either episodes or env must be provided for signatures, "
                    "or all of observation_signature, action_signature and "
                    "reward_signature must be specified"
                )

        self._buffer = buffer
        self._writer = ExperienceWriter(
            buffer,
            writer_preprocessor,
            observation_signature=observation_signature,
            action_signature=action_signature,
            reward_signature=reward_signature,
            cache_size=cache_size,
        )
        self._transition_picker = transition_picker
        self._trajectory_slicer = trajectory_slicer

        if episodes:
            for episode in episodes:
                self.append_episode(episode)

    def append(
        self,
        observation: Observation,
        action: Union[int, np.ndarray],
        reward: Union[float, np.ndarray],
    ) -> None:
        r"""Appends observation, action and reward to buffer.

        Args:
            observation: Observation.
            action: Action.
            reward: Reward.
        """
        self._writer.write(observation, action, reward)

    def append_episode(self, episode: EpisodeBase) -> None:
        r"""Appends episode to buffer.

        Each transition of the episode is registered with the underlying
        buffer as an ``(episode, transition_index)`` pair.

        Args:
            episode: Episode.
        """
        for i in range(episode.transition_count):
            self._buffer.append(episode, i)

    def clip_episode(self, terminated: bool) -> None:
        r"""Clips current episode.

        Args:
            terminated: Flag to represent environmental termination. This flag
                should be ``False`` if the episode is terminated by timeout.
        """
        self._writer.clip_episode(terminated)

    def _sample_index(self) -> int:
        """Draws a uniformly random global transition index from the buffer.

        Returns:
            Index in ``[0, transition_count)``.

        Raises:
            ValueError: If the buffer holds no transitions.
        """
        transition_count = self._buffer.transition_count
        if transition_count <= 0:
            # np.random.randint(0) would fail with an opaque numpy message.
            raise ValueError("Cannot sample from an empty replay buffer.")
        return np.random.randint(transition_count)

    def sample_transition(self) -> Transition:
        r"""Samples a transition.

        Returns:
            Transition.

        Raises:
            ValueError: If the buffer is empty.
        """
        episode, transition_index = self._buffer[self._sample_index()]
        return self._transition_picker(episode, transition_index)

    def sample_transition_batch(self, batch_size: int) -> TransitionMiniBatch:
        r"""Samples a mini-batch of transitions.

        Args:
            batch_size: Mini-batch size.

        Returns:
            Mini-batch.
        """
        return TransitionMiniBatch.from_transitions(
            [self.sample_transition() for _ in range(batch_size)]
        )

    def sample_trajectory(self, length: int) -> PartialTrajectory:
        r"""Samples a partial trajectory.

        Args:
            length: Length of partial trajectory.

        Returns:
            Partial trajectory.

        Raises:
            ValueError: If the buffer is empty.
        """
        episode, transition_index = self._buffer[self._sample_index()]
        return self._trajectory_slicer(episode, transition_index, length)

    def sample_trajectory_batch(
        self, batch_size: int, length: int
    ) -> TrajectoryMiniBatch:
        r"""Samples a mini-batch of partial trajectories.

        Args:
            batch_size: Mini-batch size.
            length: Length of partial trajectories.

        Returns:
            Mini-batch.
        """
        return TrajectoryMiniBatch.from_partial_trajectories(
            [self.sample_trajectory(length) for _ in range(batch_size)]
        )

    def dump(self, f: BinaryIO) -> None:
        """Dumps buffer data.

        .. code-block:: python

            with open('dataset.h5', 'wb') as f:
                replay_buffer.dump(f)

        Args:
            f: IO object to write to.
        """
        dump(self._buffer.episodes, f)

    @classmethod
    def from_episode_generator(
        cls,
        episode_generator: EpisodeGeneratorProtocol,
        buffer: BufferProtocol,
        transition_picker: Optional[TransitionPickerProtocol] = None,
        trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
        writer_preprocessor: Optional[WriterPreprocessProtocol] = None,
        cache_size: int = 10000,
    ) -> "ReplayBuffer":
        """Builds ReplayBuffer from episode generator.

        Args:
            episode_generator: Episode generator implementation.
            buffer: Buffer implementation.
            transition_picker: Transition picker implementation for
                Q-learning-based algorithms.
            trajectory_slicer: Trajectory slicer implementation for
                Transformer-based algorithms.
            writer_preprocessor: Writer preprocessor implementation.
            cache_size: Size of cache to record active episode history used
                for online training.

        Returns:
            Replay buffer.
        """
        return cls(
            buffer,
            episodes=episode_generator(),
            transition_picker=transition_picker,
            trajectory_slicer=trajectory_slicer,
            writer_preprocessor=writer_preprocessor,
            cache_size=cache_size,
        )

    @classmethod
    def load(
        cls,
        f: BinaryIO,
        buffer: BufferProtocol,
        episode_cls: Type[EpisodeBase] = Episode,
        transition_picker: Optional[TransitionPickerProtocol] = None,
        trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
        writer_preprocessor: Optional[WriterPreprocessProtocol] = None,
        cache_size: int = 10000,
    ) -> "ReplayBuffer":
        """Builds ReplayBuffer from dumped data.

        This method reconstructs replay buffer dumped by ``dump`` method.

        .. code-block:: python

            with open('dataset.h5', 'rb') as f:
                replay_buffer = ReplayBuffer.load(f, buffer)

        Args:
            f: IO object to read from.
            buffer: Buffer implementation.
            episode_cls: Episode class used to reconstruct data.
            transition_picker: Transition picker implementation for
                Q-learning-based algorithms.
            trajectory_slicer: Trajectory slicer implementation for
                Transformer-based algorithms.
            writer_preprocessor: Writer preprocessor implementation.
            cache_size: Size of cache to record active episode history used
                for online training.

        Returns:
            Replay buffer.
        """
        # ``load`` here resolves to the module-level loader, not this method.
        return cls(
            buffer,
            episodes=load(episode_cls, f),
            transition_picker=transition_picker,
            trajectory_slicer=trajectory_slicer,
            writer_preprocessor=writer_preprocessor,
            cache_size=cache_size,
        )

    @property
    def episodes(self) -> Sequence[EpisodeBase]:
        """Returns sequence of episodes.

        Returns:
            Sequence of episodes.
        """
        return self._buffer.episodes

    def size(self) -> int:
        """Returns number of episodes.

        Returns:
            Number of episodes.
        """
        return len(self._buffer.episodes)

    @property
    def buffer(self) -> BufferProtocol:
        """Returns buffer.

        Returns:
            Buffer.
        """
        return self._buffer

    @property
    def transition_count(self) -> int:
        """Returns number of transitions.

        Returns:
            Number of transitions.
        """
        return self._buffer.transition_count

    @property
    def transition_picker(self) -> TransitionPickerProtocol:
        """Returns transition picker.

        Returns:
            Transition picker.
        """
        return self._transition_picker

    @property
    def trajectory_slicer(self) -> TrajectorySlicerProtocol:
        """Returns trajectory slicer.

        Returns:
            Trajectory slicer.
        """
        return self._trajectory_slicer
def create_fifo_replay_buffer(
    limit: int,
    episodes: Optional[Sequence[EpisodeBase]] = None,
    transition_picker: Optional[TransitionPickerProtocol] = None,
    trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
    writer_preprocessor: Optional[WriterPreprocessProtocol] = None,
    env: Optional[gym.Env[np.ndarray, Any]] = None,
    cache_size: int = 10000,
) -> ReplayBuffer:
    """Builds FIFO replay buffer.

    This function is a shortcut alias to build replay buffer with
    ``FIFOBuffer``.

    Args:
        limit: Maximum capacity of FIFO buffer.
        episodes: List of episodes to initialize replay buffer.
        transition_picker: Transition picker implementation for
            Q-learning-based algorithms. If ``None`` is given,
            ``BasicTransitionPicker`` is used by default.
        trajectory_slicer: Trajectory slicer implementation for
            Transformer-based algorithms. If ``None`` is given,
            ``BasicTrajectorySlicer`` is used by default.
        writer_preprocessor: Writer preprocessor implementation. If ``None``
            is given, ``BasicWriterPreprocess`` is used by default.
        env: Gym environment to extract shapes of observations and action.
        cache_size: Size of cache to record active episode history used for
            online training. It needs to be greater than the maximum possible
            episode length.

    Returns:
        Replay buffer.
    """
    buffer = FIFOBuffer(limit)
    return ReplayBuffer(
        buffer,
        episodes=episodes,
        transition_picker=transition_picker,
        trajectory_slicer=trajectory_slicer,
        writer_preprocessor=writer_preprocessor,
        env=env,
        cache_size=cache_size,
    )
def create_infinite_replay_buffer(
    episodes: Optional[Sequence[EpisodeBase]] = None,
    transition_picker: Optional[TransitionPickerProtocol] = None,
    trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
    writer_preprocessor: Optional[WriterPreprocessProtocol] = None,
    env: Optional[gym.Env[np.ndarray, Any]] = None,
    cache_size: int = 10000,
) -> ReplayBuffer:
    """Builds infinite replay buffer.

    This function is a shortcut alias to build replay buffer with
    ``InfiniteBuffer``.

    Args:
        episodes: List of episodes to initialize replay buffer.
        transition_picker: Transition picker implementation for
            Q-learning-based algorithms. If ``None`` is given,
            ``BasicTransitionPicker`` is used by default.
        trajectory_slicer: Trajectory slicer implementation for
            Transformer-based algorithms. If ``None`` is given,
            ``BasicTrajectorySlicer`` is used by default.
        writer_preprocessor: Writer preprocessor implementation. If ``None``
            is given, ``BasicWriterPreprocess`` is used by default.
        env: Gym environment to extract shapes of observations and action.
        cache_size: Size of cache to record active episode history used for
            online training. It needs to be greater than the maximum possible
            episode length.

    Returns:
        Replay buffer.
    """
    buffer = InfiniteBuffer()
    return ReplayBuffer(
        buffer,
        episodes=episodes,
        transition_picker=transition_picker,
        trajectory_slicer=trajectory_slicer,
        writer_preprocessor=writer_preprocessor,
        env=env,
        cache_size=cache_size,
    )