# Source code for d3rlpy.dataset.writers

from typing import Any, Dict, Sequence, Union

import numpy as np
from typing_extensions import Protocol

from ..types import NDArray, Observation, ObservationSequence
from .buffers import BufferProtocol
from .components import Episode, EpisodeBase, Signature
from .utils import get_dtype_from_observation, get_shape_from_observation

# Public API of this module.
__all__ = [
    "WriterPreprocessProtocol",
    "BasicWriterPreprocess",
    "LastFrameWriterPreprocess",
    "ExperienceWriter",
]


class WriterPreprocessProtocol(Protocol):
    r"""Interface of WriterPreprocess.

    Implementations may transform each element of an experience tuple
    (observation, action, reward) before it is written to a buffer.
    """

    def process_observation(self, observation: Observation) -> Observation:
        r"""Processes observation.

        Args:
            observation: Observation.

        Returns:
            Processed observation.
        """
        raise NotImplementedError

    def process_action(self, action: NDArray) -> NDArray:
        r"""Processes action.

        Args:
            action: Action.

        Returns:
            Processed action.
        """
        raise NotImplementedError

    def process_reward(self, reward: NDArray) -> NDArray:
        r"""Processes reward.

        Args:
            reward: Reward.

        Returns:
            Processed reward.
        """
        raise NotImplementedError
class BasicWriterPreprocess(WriterPreprocessProtocol):
    """Standard data writer.

    This class implements the identity preprocess: every value is
    returned exactly as given.
    """

    def process_observation(self, observation: Observation) -> Observation:
        """Returns the observation unchanged."""
        return observation

    def process_action(self, action: NDArray) -> NDArray:
        """Returns the action unchanged."""
        return action

    def process_reward(self, reward: NDArray) -> NDArray:
        """Returns the reward unchanged."""
        return reward
class LastFrameWriterPreprocess(BasicWriterPreprocess):
    """Data writer that writes the last channel of observation.

    This class is designed to be used with ``FrameStackTransitionPicker``.
    """

    def process_observation(self, observation: Observation) -> Observation:
        """Keeps only the last frame of a (stacked) observation.

        The leading axis is re-added with size 1 so the result stays
        rank-consistent with the input.
        """
        if isinstance(observation, (list, tuple)):
            # tuple observations: apply the same slicing per component
            return [np.expand_dims(obs[-1], axis=0) for obs in observation]
        return np.expand_dims(observation[-1], axis=0)
class _ActiveEpisode(EpisodeBase):
    """Episode under construction backed by pre-allocated caches.

    Steps are appended into fixed-size arrays until ``shrink`` trims them
    to the actual episode length and freezes the episode.
    """

    _preprocessor: WriterPreprocessProtocol
    _cache_size: int
    _cursor: int
    _observation_signature: Signature
    _action_signature: Signature
    _reward_signature: Signature
    _observations: Sequence[NDArray]
    _actions: NDArray
    _rewards: NDArray
    _terminated: bool
    _frozen: bool

    def __init__(
        self,
        preprocessor: WriterPreprocessProtocol,
        cache_size: int,
        observation_signature: Signature,
        action_signature: Signature,
        reward_signature: Signature,
    ) -> None:
        self._preprocessor = preprocessor
        self._cache_size = cache_size
        self._cursor = 0
        shapes = observation_signature.shape
        dtypes = observation_signature.dtype
        # one cache array per observation component (supports tuple obs)
        self._observations = [
            np.empty((cache_size, *shape), dtype=dtype)
            for shape, dtype in zip(shapes, dtypes)
        ]
        self._actions = np.empty(
            (cache_size, *action_signature.shape[0]),
            dtype=action_signature.dtype[0],
        )
        self._rewards = np.empty(
            (cache_size, *reward_signature.shape[0]),
            dtype=reward_signature.dtype[0],
        )
        self._terminated = False
        self._observation_signature = observation_signature
        self._action_signature = action_signature
        self._reward_signature = reward_signature
        # BUGFIX: a fresh episode must start unfrozen; it becomes frozen
        # only after ``shrink`` is called. The original code set this to
        # True, which (combined with the inverted assert in ``append``)
        # made the freeze guard a no-op.
        self._frozen = False

    def append(
        self,
        observation: Observation,
        action: Union[int, NDArray],
        reward: Union[float, NDArray],
    ) -> None:
        """Appends one step, coercing scalars to 1-element arrays.

        Raises:
            AssertionError: If the episode was shrunk or the cache is full.
        """
        # BUGFIX: the original assertion was ``assert self._frozen`` and
        # could never fire; appends after ``shrink`` must be rejected.
        assert not self._frozen, "This episode is already shrinked."
        assert (
            self._cursor < self._cache_size
        ), "episode length exceeds cache_size."
        # coerce scalar action/reward to 1-element vectors with the
        # signature's dtype
        if not isinstance(action, np.ndarray) or action.ndim == 0:
            action = np.array([action], dtype=self._action_signature.dtype[0])
        if not isinstance(reward, np.ndarray) or reward.ndim == 0:
            reward = np.array([reward], dtype=self._reward_signature.dtype[0])
        # preprocess
        observation = self._preprocessor.process_observation(observation)
        action = self._preprocessor.process_action(action)
        reward = self._preprocessor.process_reward(reward)
        if isinstance(observation, (list, tuple)):
            for i, obs in enumerate(observation):
                self._observations[i][self._cursor] = obs
        else:
            self._observations[0][self._cursor] = observation
        self._actions[self._cursor] = action
        self._rewards[self._cursor] = reward
        self._cursor += 1

    def to_episode(self, terminated: bool) -> Episode:
        """Returns an immutable ``Episode`` copy of the data written so far."""
        observations: ObservationSequence
        if len(self._observations) == 1:
            observations = self._observations[0][: self._cursor].copy()
        else:
            observations = [
                obs[: self._cursor].copy() for obs in self._observations
            ]
        return Episode(
            observations=observations,
            actions=self._actions[: self._cursor].copy(),
            rewards=self._rewards[: self._cursor].copy(),
            terminated=terminated,
        )

    def shrink(self, terminated: bool) -> None:
        """Trims caches to the written length and freezes the episode."""
        episode = self.to_episode(terminated)
        if isinstance(episode.observations, np.ndarray):
            self._observations = [episode.observations]
        else:
            self._observations = episode.observations
        self._actions = episode.actions
        self._rewards = episode.rewards
        self._terminated = terminated
        self._frozen = True

    def size(self) -> int:
        """Returns the number of steps appended so far."""
        return self._cursor

    @property
    def observations(self) -> ObservationSequence:
        # single-component observations are exposed as a plain array
        if len(self._observations) == 1:
            return self._observations[0][: self._cursor]
        else:
            return [obs[: self._cursor] for obs in self._observations]

    @property
    def actions(self) -> NDArray:
        return self._actions[: self._cursor]

    @property
    def rewards(self) -> NDArray:
        return self._rewards[: self._cursor]

    @property
    def terminated(self) -> bool:
        return self._terminated

    @property
    def observation_signature(self) -> Signature:
        return self._observation_signature

    @property
    def action_signature(self) -> Signature:
        return self._action_signature

    @property
    def reward_signature(self) -> Signature:
        return self._reward_signature

    def compute_return(self) -> float:
        """Returns the undiscounted sum of rewards written so far."""
        return float(np.sum(self.rewards[: self._cursor]))

    def serialize(self) -> Dict[str, Any]:
        return {
            "observations": self.observations,
            "actions": self.actions,
            "rewards": self.rewards,
            "terminated": self.terminated,
        }

    @classmethod
    def deserialize(cls, serializedData: Dict[str, Any]) -> "EpisodeBase":
        raise NotImplementedError("_ActiveEpisode cannot be deserialized.")

    def __len__(self) -> int:
        return self.size()

    @property
    def transition_count(self) -> int:
        # the final step only yields a transition if the episode terminated
        return self.size() if self.terminated else self.size() - 1


class ExperienceWriter:
    """Experience writer.

    Args:
        buffer: Buffer.
        preprocessor: Writer preprocess.
        observation_signature: Signature of unprocessed observation.
        action_signature: Signature of unprocessed action.
        reward_signature: Signature of unprocessed reward.
        cache_size: Size of data in active episode. This needs to be larger
            than the maximum length of episodes.
        write_at_termination: Flag to write experiences to the buffer at the
            end of an episode all at once.
    """

    _preprocessor: WriterPreprocessProtocol
    _buffer: BufferProtocol
    _cache_size: int
    _write_at_termination: bool
    _observation_signature: Signature
    _action_signature: Signature
    _reward_signature: Signature
    _active_episode: _ActiveEpisode
    _step: int

    def __init__(
        self,
        buffer: BufferProtocol,
        preprocessor: WriterPreprocessProtocol,
        observation_signature: Signature,
        action_signature: Signature,
        reward_signature: Signature,
        cache_size: int = 10000,
        write_at_termination: bool = False,
    ):
        self._buffer = buffer
        self._preprocessor = preprocessor
        self._cache_size = cache_size
        self._write_at_termination = write_at_termination
        # Recompute signatures from *preprocessed* samples so the active
        # episode allocates caches with the post-preprocess shapes/dtypes.
        if len(observation_signature.dtype) == 1:
            processed_observation = preprocessor.process_observation(
                observation_signature.sample()[0]
            )
            assert isinstance(processed_observation, np.ndarray)
            observation_signature = Signature(
                shape=[processed_observation.shape],
                dtype=[processed_observation.dtype],
            )
        else:
            processed_observation = preprocessor.process_observation(
                observation_signature.sample()
            )
            observation_shape = get_shape_from_observation(
                processed_observation
            )
            assert isinstance(observation_shape[0], (list, tuple))
            observation_dtype = get_dtype_from_observation(
                processed_observation
            )
            assert isinstance(observation_dtype, (list, tuple))
            observation_signature = Signature(
                shape=observation_shape,  # type: ignore
                dtype=observation_dtype,
            )
        processed_action = preprocessor.process_action(
            action_signature.sample()[0]
        )
        action_shape: Sequence[int]
        if (
            not isinstance(processed_action, np.ndarray)
            or processed_action.ndim == 0
        ):
            # scalar actions are stored as 1-element vectors
            action_shape = (1,)
        else:
            action_shape = processed_action.shape
        action_signature = Signature(
            shape=[action_shape],
            dtype=[processed_action.dtype],
        )
        processed_reward = preprocessor.process_reward(
            reward_signature.sample()[0]
        )
        reward_shape: Sequence[int]
        if (
            not isinstance(processed_reward, np.ndarray)
            or processed_reward.ndim == 0
        ):
            # scalar rewards are stored as 1-element vectors
            reward_shape = (1,)
        else:
            reward_shape = processed_reward.shape
        reward_signature = Signature(
            shape=[reward_shape],
            dtype=[processed_reward.dtype],
        )
        self._observation_signature = observation_signature
        self._action_signature = action_signature
        self._reward_signature = reward_signature
        self._active_episode = _ActiveEpisode(
            preprocessor,
            cache_size=cache_size,
            observation_signature=observation_signature,
            action_signature=action_signature,
            reward_signature=reward_signature,
        )

    def write(
        self,
        observation: Observation,
        action: Union[int, NDArray],
        reward: Union[float, NDArray],
    ) -> None:
        r"""Writes state tuple to buffer.

        Args:
            observation: Observation.
            action: Action.
            reward: Reward.
        """
        self._active_episode.append(observation, action, reward)
        # unless deferring writes to episode end, push the newest complete
        # transition immediately
        if (
            not self._write_at_termination
            and self._active_episode.transition_count > 0
        ):
            self._buffer.append(
                episode=self._active_episode,
                index=self._active_episode.transition_count - 1,
            )

    def clip_episode(self, terminated: bool) -> None:
        r"""Clips the current episode.

        Args:
            terminated: Flag to represent environment termination.
        """
        if self._active_episode.transition_count == 0:
            return
        # deferred mode: flush the whole episode at once
        if self._write_at_termination:
            for i in range(self._active_episode.transition_count):
                self._buffer.append(episode=self._active_episode, index=i)
        # shrink heap memory
        self._active_episode.shrink(terminated)
        # append terminal state if necessary
        if terminated:
            self._buffer.append(
                self._active_episode,
                self._active_episode.transition_count - 1,
            )
        # prepare next active episode
        self._active_episode = _ActiveEpisode(
            self._preprocessor,
            cache_size=self._cache_size,
            observation_signature=self._observation_signature,
            action_signature=self._action_signature,
            reward_signature=self._reward_signature,
        )