d3rlpy.dataset.TransitionMiniBatch

class d3rlpy.dataset.TransitionMiniBatch

Mini-batch of Transition objects.

This class holds d3rlpy.dataset.Transition objects so that they can be passed to algorithms during fitting.

If the observations are images, you can stack an arbitrary number of frames via n_frames.

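from d3rlpy.dataset import TransitionMiniBatch

# `transitions` is a list of d3rlpy.dataset.Transition objects
# whose observations are 3-channel 84x84 images
assert transitions[0].observation.shape == (3, 84, 84)

batch_size = len(transitions)

# stack 4 frames
batch = TransitionMiniBatch(transitions, n_frames=4)

# 4 frames x 3 channels
assert batch.observations.shape == (batch_size, 12, 84, 84)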

Frame stacking is implemented by tracing previous transitions through the prev_transition property.
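The tracing idea can be sketched without d3rlpy itself. The snippet below is a rough illustration, not the library's implementation: FakeTransition and stack_frames are hypothetical names, and leaving the missing earlier frames zero-filled at the start of an episode is an assumption made for the example.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class FakeTransition:
    """Hypothetical stand-in for d3rlpy.dataset.Transition (illustration only)."""

    observation: np.ndarray  # image of shape (C, H, W)
    prev_transition: Optional["FakeTransition"] = None


def stack_frames(transition: FakeTransition, n_frames: int) -> np.ndarray:
    """Stack the current frame with up to n_frames - 1 previous frames."""
    c, h, w = transition.observation.shape
    # Assumption: slots with no earlier frame available stay zero-filled.
    stacked = np.zeros((n_frames * c, h, w), dtype=transition.observation.dtype)
    t = transition
    # Fill from the newest slot (last) back towards the oldest slot (first).
    for i in range(n_frames - 1, -1, -1):
        if t is None:
            break
        stacked[i * c:(i + 1) * c] = t.observation
        t = t.prev_transition
    return stacked


# Build a 3-step episode whose frames are filled with the values 1, 2, 3.
prev = None
episode = []
for value in (1, 2, 3):
    prev = FakeTransition(np.full((3, 84, 84), value, dtype=np.uint8), prev)
    episode.append(prev)

batch = np.stack([stack_frames(t, n_frames=4) for t in episode])
print(batch.shape)  # (3, 12, 84, 84)
```

In this sketch the newest frame occupies the last 3 channels of each stacked observation, and older frames sit in the channels before it.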

Parameters:
  • transitions (list(d3rlpy.dataset.Transition)) – mini-batch of transitions.
  • n_frames (int) – the number of frames to stack for image observation.
  • n_steps (int) – length of N-step sampling.
  • gamma (float) – discount factor for N-step calculation.

Methods

size()

Returns size of mini-batch.

Returns: mini-batch size.
Return type: int

Attributes

actions

Returns mini-batch of actions at t.

Returns: actions at t.
Return type: numpy.ndarray
n_steps

Returns mini-batch of the number of steps before next observations.

This will contain only ones if n_steps=1. If n_steps is greater than 1, the values depend on the length of the episode each transition belongs to.

Returns: the number of steps before next observations.
Return type: numpy.ndarray
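The dependence on episode length can be illustrated with a small sketch. It assumes that the N-step window is cut short when fewer than n_steps transitions remain in the episode, and that gamma discounts the rewards inside that window; n_step_window is a hypothetical helper, not part of d3rlpy, and the library's exact computation may differ.

```python
def n_step_window(episode_rewards, t, n_steps, gamma):
    """Return (steps actually taken, discounted reward sum) starting at index t.

    Assumption: the window is truncated at the end of the episode.
    """
    steps = min(n_steps, len(episode_rewards) - t)
    ret = sum((gamma ** i) * episode_rewards[t + i] for i in range(steps))
    return steps, ret


# One episode with 5 transitions, each with reward 1.0.
episode_rewards = [1.0, 1.0, 1.0, 1.0, 1.0]

steps_n1 = [n_step_window(episode_rewards, t, 1, 0.5)[0] for t in range(5)]
steps_n3 = [n_step_window(episode_rewards, t, 3, 0.5)[0] for t in range(5)]

print(steps_n1)  # [1, 1, 1, 1, 1]
print(steps_n3)  # [3, 3, 3, 2, 1]
print(n_step_window(episode_rewards, 0, 3, 0.5))  # (3, 1.75)
```

With n_steps=1 every entry is 1, while with n_steps=3 the last transitions of the episode report fewer steps because the episode ends before the full window is reached.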
next_actions

Returns mini-batch of actions at t+n.

Returns: actions at t+n.
Return type: numpy.ndarray
next_observations

Returns mini-batch of observations at t+n.

Returns: observations at t+n.
Return type: numpy.ndarray or torch.Tensor
next_rewards

Returns mini-batch of rewards at t+n.

Returns: rewards at t+n.
Return type: numpy.ndarray
observations

Returns mini-batch of observations at t.

Returns: observations at t.
Return type: numpy.ndarray or torch.Tensor
rewards

Returns mini-batch of rewards at t.

Returns: rewards at t.
Return type: numpy.ndarray
terminals

Returns mini-batch of terminal flags at t+n.

Returns: terminal flags at t+n.
Return type: numpy.ndarray
transitions

Returns the transitions held in the mini-batch.

Returns: list of transitions.
Return type: list(d3rlpy.dataset.Transition)