from typing import Any, Protocol, Sequence, Union
import numpy as np
from ..types import NDArray, Observation, ObservationSequence
from .buffers import BufferProtocol
from .components import Episode, EpisodeBase, Signature
from .utils import get_dtype_from_observation, get_shape_from_observation
# Names exported by ``from <this module> import *``; the private
# ``_ActiveEpisode`` helper is intentionally not listed.
__all__ = [
"WriterPreprocessProtocol",
"BasicWriterPreprocess",
"LastFrameWriterPreprocess",
"ExperienceWriter",
]
class WriterPreprocessProtocol(Protocol):
    r"""Protocol describing the preprocessing hooks used by writers.

    Each hook receives one piece of a step (observation, action or reward)
    and returns the value that should be stored in its place. The methods
    defined here are placeholders that always raise ``NotImplementedError``.
    """

    def process_observation(self, observation: Observation) -> Observation:
        r"""Transforms an observation before it is stored.

        Args:
            observation: Observation.

        Returns:
            Processed observation.
        """
        raise NotImplementedError

    def process_action(self, action: NDArray) -> NDArray:
        r"""Transforms an action before it is stored.

        Args:
            action: Action.

        Returns:
            Processed action.
        """
        raise NotImplementedError

    def process_reward(self, reward: NDArray) -> NDArray:
        r"""Transforms a reward before it is stored.

        Args:
            reward: Reward.

        Returns:
            Processed reward.
        """
        raise NotImplementedError
class BasicWriterPreprocess(WriterPreprocessProtocol):
    """Standard data writer preprocess.

    Every hook is the identity: the value passed in is returned unchanged
    (the very same object, no copy is made).
    """

    def process_observation(self, observation: Observation) -> Observation:
        # Identity hook: hand the observation straight back.
        return observation

    def process_action(self, action: NDArray) -> NDArray:
        # Identity hook: hand the action straight back.
        return action

    def process_reward(self, reward: NDArray) -> NDArray:
        # Identity hook: hand the reward straight back.
        return reward
class LastFrameWriterPreprocess(BasicWriterPreprocess):
    """Data writer preprocess that stores only the last frame.

    For each observation array, the last entry along the first axis is
    selected and a leading axis of length one is put back, so the stored
    value keeps the same number of dimensions as the input.

    This class is designed to be used with ``FrameStackTransitionPicker``.
    """

    def process_observation(self, observation: Observation) -> Observation:
        def keep_last(stacked):
            # Select the final entry on axis 0, then restore that axis.
            return np.expand_dims(stacked[-1], axis=0)

        # Single-array observation: process it directly.
        if not isinstance(observation, (list, tuple)):
            return keep_last(observation)

        # Tuple/list observation: process every component independently.
        return [keep_last(part) for part in observation]
class _ActiveEpisode(EpisodeBase):
    """Episode under construction, backed by preallocated arrays.

    Storage for ``cache_size`` steps is allocated up front and filled one
    step at a time by :meth:`append`. :meth:`shrink` replaces that storage
    with copies trimmed to the recorded steps and freezes the episode so
    that no further steps can be appended.
    """

    _preprocessor: WriterPreprocessProtocol
    _cache_size: int
    _cursor: int
    _observation_signature: Signature
    _action_signature: Signature
    _reward_signature: Signature
    _observations: Sequence[NDArray]
    _actions: NDArray
    _rewards: NDArray
    _terminated: bool
    _frozen: bool

    def __init__(
        self,
        preprocessor: WriterPreprocessProtocol,
        cache_size: int,
        observation_signature: Signature,
        action_signature: Signature,
        reward_signature: Signature,
    ) -> None:
        """Allocates uninitialized storage for up to ``cache_size`` steps.

        Args:
            preprocessor: Preprocess applied to every appended step.
            cache_size: Maximum number of steps this episode can hold.
            observation_signature: Shapes and dtypes of the stored
                observations; one array is allocated per entry.
            action_signature: Shape and dtype of the stored actions (only
                the first entry is used).
            reward_signature: Shape and dtype of the stored rewards (only
                the first entry is used).
        """
        self._preprocessor = preprocessor
        self._cache_size = cache_size
        self._cursor = 0
        shapes = observation_signature.shape
        dtypes = observation_signature.dtype
        self._observations = [
            np.empty((cache_size, *shape), dtype=dtype)
            for shape, dtype in zip(shapes, dtypes)
        ]
        self._actions = np.empty(
            (cache_size, *action_signature.shape[0]),
            dtype=action_signature.dtype[0],
        )
        self._rewards = np.empty(
            (cache_size, *reward_signature.shape[0]),
            dtype=reward_signature.dtype[0],
        )
        self._terminated = False
        self._observation_signature = observation_signature
        self._action_signature = action_signature
        self._reward_signature = reward_signature
        # A new episode accepts steps; only shrink() freezes it. Starting
        # frozen would make the guard in append() impossible to trigger.
        self._frozen = False

    def append(
        self,
        observation: Observation,
        action: Union[int, NDArray],
        reward: Union[float, NDArray],
    ) -> None:
        """Preprocesses one step and stores it at the current cursor.

        Args:
            observation: Observation (single array, or list/tuple of arrays).
            action: Action; scalars are wrapped into a 1-element array.
            reward: Reward; scalars are wrapped into a 1-element array.

        Raises:
            AssertionError: If the episode has been shrunk already, or if
                ``cache_size`` steps have been stored.
        """
        assert not self._frozen, "This episode is already shrinked."
        assert (
            self._cursor < self._cache_size
        ), "episode length exceeds cache_size."

        # Promote Python scalars and 0-d arrays to 1-element arrays using the
        # dtype recorded in the corresponding signature.
        if not isinstance(action, np.ndarray) or action.ndim == 0:
            action = np.array([action], dtype=self._action_signature.dtype[0])
        if not isinstance(reward, np.ndarray) or reward.ndim == 0:
            reward = np.array([reward], dtype=self._reward_signature.dtype[0])

        # preprocess
        observation = self._preprocessor.process_observation(observation)
        action = self._preprocessor.process_action(action)
        reward = self._preprocessor.process_reward(reward)

        # A list/tuple observation is spread over the per-component arrays;
        # a single array goes to the first (only) one.
        if isinstance(observation, (list, tuple)):
            for i, obs in enumerate(observation):
                self._observations[i][self._cursor] = obs
        else:
            self._observations[0][self._cursor] = observation
        self._actions[self._cursor] = action
        self._rewards[self._cursor] = reward
        self._cursor += 1

    def to_episode(self, terminated: bool) -> Episode:
        """Builds an ``Episode`` from copies of the steps recorded so far.

        Args:
            terminated: Termination flag passed to the new episode.

        Returns:
            Episode holding copies of the first ``size()`` steps.
        """
        observations: ObservationSequence
        if len(self._observations) == 1:
            observations = self._observations[0][: self._cursor].copy()
        else:
            observations = [
                obs[: self._cursor].copy() for obs in self._observations
            ]
        return Episode(
            observations=observations,
            actions=self._actions[: self._cursor].copy(),
            rewards=self._rewards[: self._cursor].copy(),
            terminated=terminated,
        )

    def shrink(self, terminated: bool) -> None:
        """Trims storage to the recorded steps and freezes the episode.

        The internal arrays are replaced by the trimmed copies produced by
        :meth:`to_episode`, after which :meth:`append` is rejected.

        Args:
            terminated: Termination flag recorded on this episode.
        """
        episode = self.to_episode(terminated)
        if isinstance(episode.observations, np.ndarray):
            self._observations = [episode.observations]
        else:
            self._observations = episode.observations
        self._actions = episode.actions
        self._rewards = episode.rewards
        self._terminated = terminated
        self._frozen = True

    def size(self) -> int:
        """Returns the number of steps recorded so far."""
        return self._cursor

    @property
    def observations(self) -> ObservationSequence:
        """Recorded observations (slices of the storage, not copies)."""
        if len(self._observations) == 1:
            return self._observations[0][: self._cursor]
        else:
            return [obs[: self._cursor] for obs in self._observations]

    @property
    def actions(self) -> NDArray:
        """Recorded actions (a slice of the storage, not a copy)."""
        return self._actions[: self._cursor]

    @property
    def rewards(self) -> NDArray:
        """Recorded rewards (a slice of the storage, not a copy)."""
        return self._rewards[: self._cursor]

    @property
    def terminated(self) -> bool:
        """Termination flag; ``False`` until set by :meth:`shrink`."""
        return self._terminated

    @property
    def observation_signature(self) -> Signature:
        """Signature the observation storage was allocated with."""
        return self._observation_signature

    @property
    def action_signature(self) -> Signature:
        """Signature the action storage was allocated with."""
        return self._action_signature

    @property
    def reward_signature(self) -> Signature:
        """Signature the reward storage was allocated with."""
        return self._reward_signature

    def compute_return(self) -> float:
        """Returns the sum of all recorded rewards as a float."""
        # ``rewards`` is already limited to the recorded steps.
        return float(np.sum(self.rewards))

    def serialize(self) -> dict[str, Any]:
        """Returns the recorded data and termination flag as a dict."""
        return {
            "observations": self.observations,
            "actions": self.actions,
            "rewards": self.rewards,
            "terminated": self.terminated,
        }

    @classmethod
    def deserialize(cls, serializedData: dict[str, Any]) -> "EpisodeBase":
        """Not supported for active episodes.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("_ActiveEpisode cannot be deserialized.")

    def __len__(self) -> int:
        return self.size()

    @property
    def transition_count(self) -> int:
        """Number of transitions: ``size()`` if terminated, else one less."""
        return self.size() if self.terminated else self.size() - 1
class ExperienceWriter:
    """Experience writer.

    Steps passed to :meth:`write` are collected in an internal active
    episode and forwarded to the buffer; :meth:`clip_episode` closes the
    active episode and starts a new one. The signatures kept by the writer
    describe the data *after* it went through the preprocessor.

    Args:
        buffer: Buffer.
        preprocessor: Writer preprocess.
        observation_signature: Signature of unprocessed observation.
        action_signature: Signature of unprocessed action.
        reward_signature: Signature of unprocessed reward.
        cache_size: Size of data in active episode. This needs to be larger
            than the maximum length of episodes.
        write_at_termination: Flag to write experiences to the buffer at the
            end of an episode all at once.
    """

    _preprocessor: WriterPreprocessProtocol
    _buffer: BufferProtocol
    _cache_size: int
    _write_at_termination: bool
    _observation_signature: Signature
    _action_signature: Signature
    _reward_signature: Signature
    _active_episode: _ActiveEpisode
    _step: int

    def __init__(
        self,
        buffer: BufferProtocol,
        preprocessor: WriterPreprocessProtocol,
        observation_signature: Signature,
        action_signature: Signature,
        reward_signature: Signature,
        cache_size: int = 10000,
        write_at_termination: bool = False,
    ):
        self._buffer = buffer
        self._preprocessor = preprocessor
        self._cache_size = cache_size
        self._write_at_termination = write_at_termination

        # Probe the preprocessor with one sample of each kind, in the order
        # observation -> action -> reward, to learn the processed signatures.
        self._observation_signature = self._processed_observation_signature(
            preprocessor, observation_signature
        )
        self._action_signature = self._processed_array_signature(
            preprocessor.process_action(action_signature.sample()[0])
        )
        self._reward_signature = self._processed_array_signature(
            preprocessor.process_reward(reward_signature.sample()[0])
        )

        self._active_episode = _ActiveEpisode(
            preprocessor,
            cache_size=cache_size,
            observation_signature=self._observation_signature,
            action_signature=self._action_signature,
            reward_signature=self._reward_signature,
        )

    @staticmethod
    def _processed_observation_signature(
        preprocessor: WriterPreprocessProtocol, signature: Signature
    ) -> Signature:
        """Returns the signature of a sampled observation after preprocess."""
        if len(signature.dtype) != 1:
            # Multi-component observation: preprocess the whole sample and
            # read shapes/dtypes back through the observation helpers.
            processed = preprocessor.process_observation(signature.sample())
            shapes = get_shape_from_observation(processed)
            assert isinstance(shapes[0], (list, tuple))
            dtypes = get_dtype_from_observation(processed)
            assert isinstance(dtypes, (list, tuple))
            return Signature(
                shape=shapes,  # type: ignore
                dtype=dtypes,
            )

        # Single-component observation: preprocess the only sampled array.
        processed = preprocessor.process_observation(signature.sample()[0])
        assert isinstance(processed, np.ndarray)
        return Signature(shape=[processed.shape], dtype=[processed.dtype])

    @staticmethod
    def _processed_array_signature(processed: NDArray) -> Signature:
        """Returns a one-entry signature for a processed action or reward."""
        shape: Sequence[int]
        if isinstance(processed, np.ndarray) and processed.ndim != 0:
            shape = processed.shape
        else:
            # Scalars and 0-d arrays are stored as 1-element vectors.
            shape = (1,)
        return Signature(shape=[shape], dtype=[processed.dtype])

    def write(
        self,
        observation: Observation,
        action: Union[int, NDArray],
        reward: Union[float, NDArray],
    ) -> None:
        r"""Writes state tuple to buffer.

        The step is always appended to the active episode. Unless writing
        is deferred to termination, the newest transition is also pushed to
        the buffer once the episode has at least one transition.

        Args:
            observation: Observation.
            action: Action.
            reward: Reward.
        """
        episode = self._active_episode
        episode.append(observation, action, reward)

        if self._write_at_termination:
            return

        transition_count = episode.transition_count
        if transition_count > 0:
            self._buffer.append(episode=episode, index=transition_count - 1)

    def clip_episode(self, terminated: bool) -> None:
        r"""Clips the current episode.

        Does nothing when the active episode has no transition yet.

        Args:
            terminated: Flag to represent environment termination.
        """
        episode = self._active_episode
        if episode.transition_count == 0:
            return

        # Deferred mode: push every transition collected so far.
        if self._write_at_termination:
            for index in range(episode.transition_count):
                self._buffer.append(episode=episode, index=index)

        # shrink heap memory
        episode.shrink(terminated)

        # append terminal state if necessary
        if terminated:
            self._buffer.append(episode, episode.transition_count - 1)

        # prepare next active episode
        self._active_episode = _ActiveEpisode(
            self._preprocessor,
            cache_size=self._cache_size,
            observation_signature=self._observation_signature,
            action_signature=self._action_signature,
            reward_signature=self._reward_signature,
        )