import numpy as np
from abc import ABCMeta, abstractmethod
from collections import deque
from ..dataset import Transition, TransitionMiniBatch
from .utility import get_action_size_from_env
class Buffer(metaclass=ABCMeta):
    """Abstract interface shared by experience buffers.

    Concrete buffers must implement every method below; the class cannot be
    instantiated until all of them are overridden.
    """

    @abstractmethod
    def append(self, observation, action, reward, terminal):
        """Append observation, action, reward and terminal flag to buffer.

        The arguments describe a single environment step. ``terminal`` marks
        the end of the current episode; how steps are turned into stored
        transitions is up to the concrete buffer.

        Args:
            observation (numpy.ndarray): observation.
            action (numpy.ndarray or int): action.
            reward (float): reward.
            terminal (bool or float): terminal flag.

        """

    @abstractmethod
    def append_episode(self, episode):
        """Append Episode object to buffer.

        Args:
            episode (d3rlpy.dataset.Episode): episode.

        """

    @abstractmethod
    def sample(self, batch_size, n_frames=1):
        """Returns sampled mini-batch of transitions.

        If observation is image, you can stack arbitrary frames via
        ``n_frames``.

        .. code-block:: python

            buffer.observation_shape == (3, 84, 84)

            # stack 4 frames
            batch = buffer.sample(batch_size=32, n_frames=4)

            batch.observations.shape == (32, 12, 84, 84)

        Args:
            batch_size (int): mini-batch size.
            n_frames (int):
                the number of frames to stack for image observation.

        Returns:
            d3rlpy.dataset.TransitionMiniBatch: mini-batch.

        """

    @abstractmethod
    def size(self):
        """Returns the number of appended elements in buffer.

        Returns:
            int: the number of elements in buffer.

        """
class ReplayBuffer(Buffer):
    """Standard Replay Buffer.

    Transitions are kept in a ``collections.deque`` created with ``maxlen``,
    so once the buffer is full every newly appended transition discards the
    oldest one.

    Args:
        maxlen (int): the maximum number of data length.
        env (gym.Env): gym-like environment to extract shape information.
        episodes (list(d3rlpy.dataset.Episode)): list of episodes to
            initialize buffer.

    Attributes:
        prev_observation (numpy.ndarray): previously appended observation.
        prev_action (numpy.ndarray or int): previously appended action.
        prev_reward (float): previously appended reward.
        prev_transition (d3rlpy.dataset.Transition):
            previously appended transition.
        transitions (collections.deque): list of transitions.
        observation_shape (tuple): observation shape.
        action_size (int): action size.

    """

    def __init__(self, maxlen, env, episodes=None):
        # cache of the most recent step only; ``append`` pairs it with the
        # following step to build a Transition, and clears it at episode end
        self.prev_observation = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_transition = None

        self.transitions = deque(maxlen=maxlen)

        # extract shape information
        self.observation_shape = env.observation_space.shape
        self.action_size = get_action_size_from_env(env)

        # add initial transitions
        if episodes:
            for episode in episodes:
                self.append_episode(episode)

    def append(self, observation, action, reward, terminal):
        """Append one environment step to the buffer.

        The first step after construction or after a terminal step is only
        cached. Every later step is paired with the cached one to create a
        ``Transition`` that is linked to the previous transition of the same
        episode and stored. A truthy ``terminal`` clears the cache so the
        next call starts a fresh episode.

        Args:
            observation (numpy.ndarray): observation.
            action (numpy.ndarray or int): action.
            reward (float): reward.
            terminal (bool or float): terminal flag.

        Raises:
            AssertionError: if the observation shape or the action does not
                match the shape information taken from the environment.

        """
        # validation
        assert observation.shape == self.observation_shape, \
            'observation shape does not match the buffer'
        if isinstance(action, np.ndarray):
            assert action.shape[0] == self.action_size, \
                'action size does not match the buffer'
        else:
            action = int(action)
            # a negative index is as invalid as one past the last action
            assert 0 <= action < self.action_size, \
                'discrete action is out of range'

        # create Transition object
        if self.prev_observation is not None:
            # numpy booleans are not ``bool`` instances, so both are checked
            if isinstance(terminal, (bool, np.bool_)):
                terminal = 1.0 if terminal else 0.0

            transition = Transition(observation_shape=self.observation_shape,
                                    action_size=self.action_size,
                                    observation=self.prev_observation,
                                    action=self.prev_action,
                                    reward=self.prev_reward,
                                    next_observation=observation,
                                    next_action=action,
                                    next_reward=reward,
                                    terminal=terminal,
                                    prev_transition=self.prev_transition)

            # complete the forward link of the in-episode chain
            if self.prev_transition is not None:
                self.prev_transition.next_transition = transition

            self.transitions.append(transition)

            self.prev_transition = transition

        self.prev_observation = observation
        self.prev_action = action
        self.prev_reward = reward

        # episode finished: forget the cached step so that the next call is
        # not paired with data from this episode
        if terminal:
            self.prev_observation = None
            self.prev_action = None
            self.prev_reward = None
            self.prev_transition = None

    def append_episode(self, episode):
        """Append every transition of an episode to the buffer.

        Args:
            episode (d3rlpy.dataset.Episode): episode.

        Raises:
            AssertionError: if the observation shape or action size of the
                episode differs from the buffer's.

        """
        assert episode.get_observation_shape() == self.observation_shape, \
            'observation shape does not match the buffer'
        assert episode.get_action_size() == self.action_size, \
            'action size does not match the buffer'
        self.transitions.extend(episode.transitions)

    def sample(self, batch_size, n_frames=1):
        """Returns sampled mini-batch of transitions.

        Indices are drawn independently with ``numpy.random.randint``, so
        the same transition may appear more than once in a batch.

        Args:
            batch_size (int): mini-batch size.
            n_frames (int): the number of frames to stack, forwarded to
                ``TransitionMiniBatch``.

        Returns:
            d3rlpy.dataset.TransitionMiniBatch: mini-batch.

        """
        indices = np.random.randint(self.size(), size=batch_size)
        transitions = [self.transitions[index] for index in indices]
        return TransitionMiniBatch(transitions, n_frames)

    def size(self):
        """Returns the number of stored transitions.

        Returns:
            int: the number of elements in buffer.

        """
        return len(self.transitions)

    def __len__(self):
        """Returns the same value as ``size`` so that ``len(buffer)`` works."""
        return self.size()