Source code for d3rlpy.algos.transformer.base

import dataclasses
from abc import abstractmethod
from collections import defaultdict, deque
from typing import Callable, Deque, Dict, Generic, Optional, TypeVar, Union

import numpy as np
import torch
from tqdm.auto import tqdm
from typing_extensions import Self

from ...base import ImplBase, LearnableBase, LearnableConfig, save_config
from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from ...dataset import ReplayBuffer, TrajectoryMiniBatch
from ...envs import GymEnv
from ...logging import (
    LOG,
    D3RLPyLogger,
    FileAdapterFactory,
    LoggerAdapterFactory,
)
from ...metrics import evaluate_transformer_with_environment
from ...torch_utility import TorchTrajectoryMiniBatch, train_api
from ...types import NDArray, Observation
from ..utility import (
    assert_action_space_with_dataset,
    build_scalers_with_trajectory_slicer,
)
from .action_samplers import (
    IdentityTransformerActionSampler,
    SoftmaxTransformerActionSampler,
    TransformerActionSampler,
)
from .inputs import TorchTransformerInput, TransformerInput

__all__ = [
    "TransformerAlgoImplBase",
    "StatefulTransformerWrapper",
    "TransformerConfig",
    "TransformerAlgoBase",
]


class TransformerAlgoImplBase(ImplBase):
    """Base implementation class for Transformer-based algorithms.

    Concrete implementations supply ``predict`` and ``inner_update``.
    ``update`` is the training entry point: it forwards to ``inner_update``
    through the ``train_api`` decorator.
    """

    @abstractmethod
    def predict(self, inpt: TorchTransformerInput) -> torch.Tensor:
        """Computes the model's action output for a sequence input.

        Args:
            inpt: Tensor-converted sequence input (observations, actions,
                rewards, returns-to-go and timesteps).

        Returns:
            Action output tensor. Presumably continuous actions or, for
            discrete action spaces, per-action scores consumed by an action
            sampler. TODO confirm against concrete implementations.
        """

    @train_api
    def update(
        self, batch: TorchTrajectoryMiniBatch, grad_step: int
    ) -> Dict[str, float]:
        """Runs one training update, wrapped by ``train_api``.

        Args:
            batch: Mini-batch of trajectory slices.
            grad_step: Index of the current gradient step.

        Returns:
            Dictionary of metrics reported by ``inner_update``.
        """
        metrics = self.inner_update(batch, grad_step)
        return metrics

    @abstractmethod
    def inner_update(
        self, batch: TorchTrajectoryMiniBatch, grad_step: int
    ) -> Dict[str, float]:
        """Algorithm-specific parameter update.

        Args:
            batch: Mini-batch of trajectory slices.
            grad_step: Index of the current gradient step.

        Returns:
            Dictionary mapping metric names to values.
        """


@dataclasses.dataclass
class TransformerConfig(LearnableConfig):
    """Configuration fields shared by Transformer-based algorithms.

    Args:
        context_size (int): Length of the sequence window. It is used as the
            trajectory slice length during training and as the history
            length kept by ``StatefulTransformerWrapper``.
        max_timestep (int): Upper bound of the timestep counter fed to the
            model during stateful inference.
    """

    # Sequence window length (trajectory slice length / history buffer size).
    context_size: int = 20
    # The stateful wrapper clamps its timestep counter to this value.
    max_timestep: int = 1000


# Type parameter for the implementation object held by a Transformer
# algorithm; restricted to subclasses of TransformerAlgoImplBase.
TTransformerImpl = TypeVar("TTransformerImpl", bound=TransformerAlgoImplBase)
# Type parameter for the configuration object of a Transformer algorithm;
# restricted to subclasses of TransformerConfig.
TTransformerConfig = TypeVar("TTransformerConfig", bound=TransformerConfig)


class StatefulTransformerWrapper(Generic[TTransformerImpl, TTransformerConfig]):
    r"""A stateful wrapper for inference of Transformer-based algorithms.

    This wrapper class provides an interface similar to that of
    Q-learning-based algorithms, which is especially useful when you evaluate
    Transformer-based algorithms such as Decision Transformer.

    .. code-block:: python

        from d3rlpy.algos import DecisionTransformerConfig
        from d3rlpy.algos import StatefulTransformerWrapper

        dt = DecisionTransformerConfig().create()
        dt.create_impl(<observation_shape>, <action_size>)
        # initialize wrapper with a target return of 1000
        wrapper = StatefulTransformerWrapper(dt, target_return=1000)
        # shortcut is also available
        wrapper = dt.as_stateful_wrapper(target_return=1000)

        # predict next action to achieve the return of 1000 in the end
        action = wrapper.predict(<observation>, <reward>)

        # clear stateful information
        wrapper.reset()

    Args:
        algo (TransformerAlgoBase): Transformer-based algorithm. It must
            already be built (its ``impl`` must exist).
        target_return (float): Target return.
        action_sampler (d3rlpy.algos.TransformerActionSampler): Action
            sampler. If ``None``, ``IdentityTransformerActionSampler`` is used
            for continuous action spaces and
            ``SoftmaxTransformerActionSampler`` otherwise.
    """

    _algo: "TransformerAlgoBase[TTransformerImpl, TTransformerConfig]"
    _target_return: float
    _action_sampler: TransformerActionSampler
    # Return still to be collected; starts at target_return and is reduced
    # by every reward passed to predict().
    _return_rest: float
    _observations: Deque[Observation]
    # Always holds one trailing pad action as the placeholder for the action
    # that the next predict() call will produce.
    _actions: Deque[Union[NDArray, int]]
    _rewards: Deque[float]
    _returns_to_go: Deque[float]
    _timesteps: Deque[int]
    _timestep: int

    def __init__(
        self,
        algo: "TransformerAlgoBase[TTransformerImpl, TTransformerConfig]",
        target_return: float,
        action_sampler: Optional[TransformerActionSampler] = None,
    ):
        assert algo.impl, "algo must be built."
        self._algo = algo
        self._target_return = target_return
        if action_sampler is None:
            # Same default selection as TransformerAlgoBase.as_stateful_wrapper.
            if algo.get_action_type() == ActionSpace.CONTINUOUS:
                action_sampler = IdentityTransformerActionSampler()
            else:
                action_sampler = SoftmaxTransformerActionSampler()
        self._action_sampler = action_sampler
        self._return_rest = target_return

        # Every history buffer is bounded by the model's context window, so
        # the oldest entries are dropped automatically.
        context_size = algo.config.context_size
        self._observations = deque([], maxlen=context_size)
        self._actions = deque([self._get_pad_action()], maxlen=context_size)
        self._rewards = deque([], maxlen=context_size)
        self._returns_to_go = deque([], maxlen=context_size)
        self._timesteps = deque([], maxlen=context_size)
        self._timestep = 1

    def predict(self, x: Observation, reward: float) -> Union[NDArray, int]:
        r"""Returns action.

        Args:
            x: Observation.
            reward: Last reward.

        Returns:
            Action.
        """
        self._observations.append(x)
        self._rewards.append(reward)
        self._returns_to_go.append(self._return_rest - reward)
        self._timesteps.append(self._timestep)
        inpt = TransformerInput(
            observations=np.array(self._observations),
            actions=np.array(self._actions),
            rewards=np.array(self._rewards).reshape((-1, 1)),
            returns_to_go=np.array(self._returns_to_go).reshape((-1, 1)),
            timesteps=np.array(self._timesteps),
        )
        action = self._action_sampler(self._algo.predict(inpt))
        # Replace the placeholder with the chosen action, then push a fresh
        # placeholder for the next call.
        self._actions[-1] = action
        self._actions.append(self._get_pad_action())
        self._timestep = min(self._timestep + 1, self._algo.config.max_timestep)
        self._return_rest -= reward
        return action

    def reset(self) -> None:
        """Clears stateful information."""
        self._observations.clear()
        self._actions.clear()
        self._rewards.clear()
        self._returns_to_go.clear()
        self._timesteps.clear()
        self._actions.append(self._get_pad_action())
        self._timestep = 1
        self._return_rest = self._target_return

    @property
    def algo(
        self,
    ) -> "TransformerAlgoBase[TTransformerImpl, TTransformerConfig]":
        return self._algo

    def _get_pad_action(self) -> Union[int, NDArray]:
        """Returns a placeholder action.

        The placeholder is a float32 zero vector of size ``action_size`` for
        continuous action spaces, and ``0`` otherwise.
        """
        assert self._algo.impl
        pad_action: Union[int, NDArray]
        if self._algo.get_action_type() == ActionSpace.CONTINUOUS:
            pad_action = np.zeros(self._algo.impl.action_size, dtype=np.float32)
        else:
            pad_action = 0
        return pad_action


class TransformerAlgoBase(
    Generic[TTransformerImpl, TTransformerConfig],
    LearnableBase[TTransformerImpl, TTransformerConfig],
):
    """Base class of Transformer-based algorithms.

    Provides batched prediction, the offline training loop (``fit``), a
    single-step ``update`` and a factory for ``StatefulTransformerWrapper``.
    """

    def predict(self, inpt: TransformerInput) -> NDArray:
        """Returns action.

        This is for internal use. For evaluation, use
        ``StatefulTransformerWrapper`` instead.

        Args:
            inpt: Sequence input.

        Returns:
            Action.
        """
        assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
        with torch.no_grad():
            torch_inpt = TorchTransformerInput.from_numpy(
                inpt=inpt,
                context_size=self._config.context_size,
                device=self._device,
                observation_scaler=self._config.observation_scaler,
                action_scaler=self._config.action_scaler,
                reward_scaler=self._config.reward_scaler,
            )
            action = self._impl.predict(torch_inpt)
            # map the model output back to the original action scale
            if self._config.action_scaler:
                action = self._config.action_scaler.reverse_transform(action)
        return action.cpu().detach().numpy()  # type: ignore

    def fit(
        self,
        dataset: ReplayBuffer,
        n_steps: int,
        n_steps_per_epoch: int = 10000,
        experiment_name: Optional[str] = None,
        with_timestamp: bool = True,
        logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
        show_progress: bool = True,
        eval_env: Optional[GymEnv] = None,
        eval_target_return: Optional[float] = None,
        eval_action_sampler: Optional[TransformerActionSampler] = None,
        save_interval: int = 1,
        callback: Optional[Callable[[Self, int, int], None]] = None,
    ) -> None:
        """Trains with given dataset.

        Args:
            dataset: Offline dataset to train.
            n_steps: Number of steps to train.
            n_steps_per_epoch: Number of steps per epoch. Training runs
                ``n_steps // n_steps_per_epoch`` epochs, so any remainder
                steps are not executed.
            experiment_name: Experiment name for logging. If not passed,
                the directory name will be `{class name}_{timestamp}`.
            with_timestamp: Flag to add timestamp string to the last of
                directory name.
            logger_adapter: LoggerAdapterFactory object.
            show_progress: Flag to show progress bar for iterations.
            eval_env: Evaluation environment.
            eval_target_return: Evaluation return target. Required when
                ``eval_env`` is given.
            eval_action_sampler: Action sampler used in evaluation. If
                ``None``, the default of ``as_stateful_wrapper`` is used.
            save_interval: Interval (in epochs) to save parameters.
            callback: Callable function that takes
                ``(algo, epoch, total_step)``, which is called every step.
        """
        LOG.info("dataset info", dataset_info=dataset.dataset_info)

        # Validate evaluation arguments before any logging directory is
        # created or any training happens, instead of failing after the
        # first full epoch.
        if eval_env:
            assert (
                eval_target_return is not None
            ), "eval_target_return must be given when eval_env is set."

        # check action space
        assert_action_space_with_dataset(self, dataset.dataset_info)

        # initialize scalers
        build_scalers_with_trajectory_slicer(self, dataset)

        # setup logger
        if experiment_name is None:
            experiment_name = self.__class__.__name__
        logger = D3RLPyLogger(
            adapter_factory=logger_adapter,
            experiment_name=experiment_name,
            with_timestamp=with_timestamp,
        )

        # The logger is closed even if model building or training raises.
        try:
            # instantiate implementation
            if self._impl is None:
                LOG.debug("Building models...")
                action_size = dataset.dataset_info.action_size
                observation_shape = (
                    dataset.sample_transition().observation_signature.shape[0]
                )
                self.create_impl(observation_shape, action_size)
                LOG.debug("Models have been built.")
            else:
                LOG.warning("Skip building models since they're already built.")

            # save hyperparameters
            save_config(self, logger)

            # training loop
            n_epochs = n_steps // n_steps_per_epoch
            total_step = 0
            for epoch in range(1, n_epochs + 1):
                # dict to add incremental mean losses to epoch
                epoch_loss = defaultdict(list)

                range_gen = tqdm(
                    range(n_steps_per_epoch),
                    disable=not show_progress,
                    desc=f"Epoch {int(epoch)}/{n_epochs}",
                )

                for itr in range_gen:
                    with logger.measure_time("step"):
                        # pick transitions
                        with logger.measure_time("sample_batch"):
                            batch = dataset.sample_trajectory_batch(
                                self._config.batch_size,
                                length=self._config.context_size,
                            )

                        # update parameters
                        with logger.measure_time("algorithm_update"):
                            loss = self.update(batch)

                        # record metrics
                        for name, val in loss.items():
                            logger.add_metric(name, val)
                            epoch_loss[name].append(val)

                        # update progress postfix with losses
                        if itr % 10 == 0:
                            mean_loss = {
                                k: np.mean(v) for k, v in epoch_loss.items()
                            }
                            range_gen.set_postfix(mean_loss)

                    total_step += 1

                    # call callback if given
                    if callback:
                        callback(self, epoch, total_step)

                if eval_env:
                    # kept for type narrowing; validated above
                    assert eval_target_return is not None
                    eval_score = evaluate_transformer_with_environment(
                        algo=self.as_stateful_wrapper(
                            target_return=eval_target_return,
                            action_sampler=eval_action_sampler,
                        ),
                        env=eval_env,
                    )
                    logger.add_metric("environment", eval_score)

                # save metrics
                logger.commit(epoch, total_step)

                # save model parameters
                if epoch % save_interval == 0:
                    logger.save_model(total_step, self)
        finally:
            logger.close()

    def update(self, batch: TrajectoryMiniBatch) -> Dict[str, float]:
        """Update parameters with mini-batch of data.

        Args:
            batch: Mini-batch data.

        Returns:
            Dictionary of metrics.
        """
        assert self._impl, IMPL_NOT_INITIALIZED_ERROR
        torch_batch = TorchTrajectoryMiniBatch.from_batch(
            batch=batch,
            device=self._device,
            observation_scaler=self._config.observation_scaler,
            action_scaler=self._config.action_scaler,
            reward_scaler=self._config.reward_scaler,
        )
        # Go through the impl's ``@train_api``-decorated ``update`` rather than
        # ``inner_update`` directly. Calling ``inner_update`` bypasses the
        # decorator, so modules that evaluation put into inference mode would
        # presumably stay there during training.
        loss = self._impl.update(torch_batch, self._grad_step)
        self._grad_step += 1
        return loss

    def as_stateful_wrapper(
        self,
        target_return: float,
        action_sampler: Optional[TransformerActionSampler] = None,
    ) -> StatefulTransformerWrapper[TTransformerImpl, TTransformerConfig]:
        """Returns a wrapped Transformer algorithm for stateful decision making.

        Args:
            target_return: Target environment return.
            action_sampler: Action sampler. If ``None``,
                ``IdentityTransformerActionSampler`` is used for continuous
                action spaces and ``SoftmaxTransformerActionSampler``
                otherwise.

        Returns:
            StatefulTransformerWrapper object.
        """
        if action_sampler is None:
            if self.get_action_type() == ActionSpace.CONTINUOUS:
                action_sampler = IdentityTransformerActionSampler()
            else:
                action_sampler = SoftmaxTransformerActionSampler()
        return StatefulTransformerWrapper(self, target_return, action_sampler)