Source code for d3rlpy.dataset.writers

from typing import Any, Protocol, Sequence, Union

import numpy as np

from ..types import NDArray, Observation, ObservationSequence
from .buffers import BufferProtocol
from .components import Episode, EpisodeBase, Signature
from .utils import get_dtype_from_observation, get_shape_from_observation

# Public API of d3rlpy.dataset.writers; ``from ... import *`` exports only these.
__all__ = ["WriterPreprocessProtocol", "BasicWriterPreprocess",
           "LastFrameWriterPreprocess", "ExperienceWriter"]


class WriterPreprocessProtocol(Protocol):
    r"""Interface of WriterPreprocess.

    Each hook receives one piece of a step before it is stored and returns
    the value that should actually be written. Implementations must
    override all three hooks; the defaults here only raise.
    """

    def process_observation(self, observation: Observation) -> Observation:
        r"""Processes observation.

        Args:
            observation: Observation as received by the writer.

        Returns:
            Observation to store.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError()

    def process_action(self, action: NDArray) -> NDArray:
        r"""Processes action.

        Args:
            action: Action as received by the writer.

        Returns:
            Action to store.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError()

    def process_reward(self, reward: NDArray) -> NDArray:
        r"""Processes reward.

        Args:
            reward: Reward as received by the writer.

        Returns:
            Reward to store.

        Raises:
            NotImplementedError: Always, unless overridden.
        """
        raise NotImplementedError()
class BasicWriterPreprocess(WriterPreprocessProtocol):
    """Standard data writer preprocess.

    Every hook is the identity: observations, actions and rewards are
    written exactly as they are received.
    """

    def process_observation(self, observation: Observation) -> Observation:
        """Returns ``observation`` unchanged."""
        return observation

    def process_action(self, action: NDArray) -> NDArray:
        """Returns ``action`` unchanged."""
        return action

    def process_reward(self, reward: NDArray) -> NDArray:
        """Returns ``reward`` unchanged."""
        return reward
class LastFrameWriterPreprocess(BasicWriterPreprocess):
    """Data writer preprocess that keeps only the last frame of observation.

    The last element along the first axis is kept, and that axis is
    restored with length 1. This class is designed to be used with
    ``FrameStackTransitionPicker``.
    """

    def process_observation(self, observation: Observation) -> Observation:
        """Reduces each observation array to its last entry along axis 0.

        Args:
            observation: A single array, or a list/tuple of arrays.

        Returns:
            An array of leading size 1, or a list of such arrays when
            ``observation`` is a list or tuple.
        """

        def keep_last(frames: NDArray) -> NDArray:
            # Slice off the final entry, then re-add the leading axis.
            return np.expand_dims(frames[-1], axis=0)

        if not isinstance(observation, (list, tuple)):
            return keep_last(observation)
        return list(map(keep_last, observation))
class _ActiveEpisode(EpisodeBase):
    """Episode that is currently being recorded by ``ExperienceWriter``.

    Steps are written into arrays preallocated with ``cache_size`` rows, and
    only the first ``size()`` rows are exposed. ``shrink`` replaces those
    arrays with right-sized copies, after which no more steps may be
    appended.
    """

    _preprocessor: WriterPreprocessProtocol
    _cache_size: int
    _cursor: int  # number of steps written so far
    _observation_signature: Signature
    _action_signature: Signature
    _reward_signature: Signature
    _observations: Sequence[NDArray]  # one array per observation component
    _actions: NDArray
    _rewards: NDArray
    _terminated: bool
    _frozen: bool  # True once shrink() has run; append() is refused afterwards

    def __init__(
        self,
        preprocessor: WriterPreprocessProtocol,
        cache_size: int,
        observation_signature: Signature,
        action_signature: Signature,
        reward_signature: Signature,
    ) -> None:
        self._preprocessor = preprocessor
        self._cache_size = cache_size
        self._cursor = 0
        shapes = observation_signature.shape
        dtypes = observation_signature.dtype
        self._observations = [
            np.empty((cache_size, *shape), dtype=dtype)
            for shape, dtype in zip(shapes, dtypes)
        ]
        self._actions = np.empty(
            (cache_size, *action_signature.shape[0]),
            dtype=action_signature.dtype[0],
        )
        self._rewards = np.empty(
            (cache_size, *reward_signature.shape[0]),
            dtype=reward_signature.dtype[0],
        )
        self._terminated = False
        self._observation_signature = observation_signature
        self._action_signature = action_signature
        self._reward_signature = reward_signature
        # Previously initialised to True and never flipped, so the guard in
        # append() could never fire after shrink().
        self._frozen = False

    def append(
        self,
        observation: Observation,
        action: Union[int, NDArray],
        reward: Union[float, NDArray],
    ) -> None:
        """Preprocesses one step and writes it at the current cursor.

        Args:
            observation: Observation (single array or list/tuple of arrays).
            action: Action. Non-arrays and 0-d arrays are wrapped into a
                1-element array of the action signature's dtype.
            reward: Reward. Wrapped the same way as ``action``.
        """
        assert not self._frozen, "This episode is already shrinked."
        assert (
            self._cursor < self._cache_size
        ), "episode length exceeds cache_size."

        if not isinstance(action, np.ndarray) or action.ndim == 0:
            action = np.array([action], dtype=self._action_signature.dtype[0])
        if not isinstance(reward, np.ndarray) or reward.ndim == 0:
            reward = np.array([reward], dtype=self._reward_signature.dtype[0])

        # preprocess
        observation = self._preprocessor.process_observation(observation)
        action = self._preprocessor.process_action(action)
        reward = self._preprocessor.process_reward(reward)

        if isinstance(observation, (list, tuple)):
            for i, obs in enumerate(observation):
                self._observations[i][self._cursor] = obs
        else:
            self._observations[0][self._cursor] = observation
        self._actions[self._cursor] = action
        self._rewards[self._cursor] = reward

        self._cursor += 1

    def to_episode(self, terminated: bool) -> Episode:
        """Returns an ``Episode`` holding copies of the recorded steps."""
        observations: ObservationSequence
        if len(self._observations) == 1:
            observations = self._observations[0][: self._cursor].copy()
        else:
            observations = [
                obs[: self._cursor].copy() for obs in self._observations
            ]
        return Episode(
            observations=observations,
            actions=self._actions[: self._cursor].copy(),
            rewards=self._rewards[: self._cursor].copy(),
            terminated=terminated,
        )

    def shrink(self, terminated: bool) -> None:
        """Replaces the cache arrays with right-sized copies and freezes.

        Args:
            terminated: Termination flag stored on this episode.
        """
        episode = self.to_episode(terminated)
        if isinstance(episode.observations, np.ndarray):
            self._observations = [episode.observations]
        else:
            self._observations = episode.observations
        self._actions = episode.actions
        self._rewards = episode.rewards
        self._terminated = terminated
        self._frozen = True

    def size(self) -> int:
        """Returns the number of steps written so far."""
        return self._cursor

    @property
    def observations(self) -> ObservationSequence:
        if len(self._observations) == 1:
            return self._observations[0][: self._cursor]
        else:
            return [obs[: self._cursor] for obs in self._observations]

    @property
    def actions(self) -> NDArray:
        return self._actions[: self._cursor]

    @property
    def rewards(self) -> NDArray:
        return self._rewards[: self._cursor]

    @property
    def terminated(self) -> bool:
        return self._terminated

    @property
    def observation_signature(self) -> Signature:
        return self._observation_signature

    @property
    def action_signature(self) -> Signature:
        return self._action_signature

    @property
    def reward_signature(self) -> Signature:
        return self._reward_signature

    def compute_return(self) -> float:
        """Returns the sum of the recorded rewards."""
        return float(np.sum(self.rewards))

    def serialize(self) -> dict[str, Any]:
        return {
            "observations": self.observations,
            "actions": self.actions,
            "rewards": self.rewards,
            "terminated": self.terminated,
        }

    @classmethod
    def deserialize(cls, serializedData: dict[str, Any]) -> "EpisodeBase":
        raise NotImplementedError("_ActiveEpisode cannot be deserialized.")

    def __len__(self) -> int:
        return self.size()

    @property
    def transition_count(self) -> int:
        """Every step if terminated, otherwise one fewer than ``size()``."""
        return self.size() if self.terminated else self.size() - 1


class ExperienceWriter:
    """Experience writer.

    Records steps into an active episode and forwards transitions to
    ``buffer``.

    Args:
        buffer: Buffer.
        preprocessor: Writer preprocess.
        observation_signature: Signature of unprocessed observation.
        action_signature: Signature of unprocessed action.
        reward_signature: Signature of unprocessed reward.
        cache_size: Number of steps the active episode can hold. This must
            be at least the maximum length of episodes.
        write_at_termination: Flag to write experiences to the buffer at the
            end of an episode all at once.
    """

    _preprocessor: WriterPreprocessProtocol
    _buffer: BufferProtocol
    _cache_size: int
    _write_at_termination: bool
    _observation_signature: Signature
    _action_signature: Signature
    _reward_signature: Signature
    _active_episode: _ActiveEpisode
    _step: int

    def __init__(
        self,
        buffer: BufferProtocol,
        preprocessor: WriterPreprocessProtocol,
        observation_signature: Signature,
        action_signature: Signature,
        reward_signature: Signature,
        cache_size: int = 10000,
        write_at_termination: bool = False,
    ):
        self._buffer = buffer
        self._preprocessor = preprocessor
        self._cache_size = cache_size
        self._write_at_termination = write_at_termination

        # The stored arrays hold *preprocessed* data, so derive the
        # signatures by running a sample through the preprocessor.
        self._observation_signature = self._processed_observation_signature(
            preprocessor, observation_signature
        )
        self._action_signature = self._single_array_signature(
            preprocessor.process_action(action_signature.sample()[0])
        )
        self._reward_signature = self._single_array_signature(
            preprocessor.process_reward(reward_signature.sample()[0])
        )

        self._active_episode = self._new_active_episode()

    @staticmethod
    def _processed_observation_signature(
        preprocessor: WriterPreprocessProtocol, signature: Signature
    ) -> Signature:
        """Returns the signature of a sample observation after preprocessing."""
        if len(signature.dtype) == 1:
            processed = preprocessor.process_observation(signature.sample()[0])
            assert isinstance(processed, np.ndarray)
            return Signature(shape=[processed.shape], dtype=[processed.dtype])

        # Tuple observations: preprocess every component together.
        processed = preprocessor.process_observation(signature.sample())
        observation_shape = get_shape_from_observation(processed)
        assert isinstance(observation_shape[0], (list, tuple))
        observation_dtype = get_dtype_from_observation(processed)
        assert isinstance(observation_dtype, (list, tuple))
        return Signature(
            shape=observation_shape,  # type: ignore
            dtype=observation_dtype,
        )

    @staticmethod
    def _single_array_signature(processed: NDArray) -> Signature:
        """Builds a one-entry signature for a preprocessed action/reward.

        Values that are not ndarrays (e.g. numpy scalars) or are 0-d are
        recorded with shape ``(1,)``, matching how ``_ActiveEpisode.append``
        wraps them.
        """
        shape: Sequence[int]
        if not isinstance(processed, np.ndarray) or processed.ndim == 0:
            shape = (1,)
        else:
            shape = processed.shape
        return Signature(shape=[shape], dtype=[processed.dtype])

    def _new_active_episode(self) -> _ActiveEpisode:
        """Creates an empty active episode from the processed signatures."""
        return _ActiveEpisode(
            self._preprocessor,
            cache_size=self._cache_size,
            observation_signature=self._observation_signature,
            action_signature=self._action_signature,
            reward_signature=self._reward_signature,
        )

    def write(
        self,
        observation: Observation,
        action: Union[int, NDArray],
        reward: Union[float, NDArray],
    ) -> None:
        r"""Writes state tuple to buffer.

        Unless ``write_at_termination`` is set, the newest complete
        transition is appended to the buffer right away.

        Args:
            observation: Observation.
            action: Action.
            reward: Reward.
        """
        self._active_episode.append(observation, action, reward)
        if (
            not self._write_at_termination
            and self._active_episode.transition_count > 0
        ):
            self._buffer.append(
                episode=self._active_episode,
                index=self._active_episode.transition_count - 1,
            )

    def clip_episode(self, terminated: bool) -> None:
        r"""Clips the current episode.

        Args:
            terminated: Flag to represent environment termination.
        """
        episode = self._active_episode

        # Guard on the step count, not transition_count: before shrink() the
        # episode is unterminated, so transition_count is size - 1. That made
        # 1-step episodes return without being reset (leaking into the next
        # episode) and let empty ones through to append index -1.
        if episode.size() == 0:
            return

        if self._write_at_termination:
            # Still unterminated here, so this covers the non-terminal
            # transitions only.
            for i in range(episode.transition_count):
                self._buffer.append(episode=episode, index=i)

        # shrink heap memory
        episode.shrink(terminated)

        # Once marked terminated, the episode gains one final transition.
        if terminated:
            self._buffer.append(episode, episode.transition_count - 1)

        # prepare next active episode
        self._active_episode = self._new_active_episode()