# Source code for d3rlpy.algos.transformer.base

import dataclasses
from abc import abstractmethod
from collections import defaultdict, deque
from typing import (
    Callable,
    Deque,
    Dict,
    Generic,
    Optional,
    Sequence,
    TypeVar,
    Union,
)

import numpy as np
import torch
from tqdm.auto import tqdm
from typing_extensions import Self

from ...base import ImplBase, LearnableBase, LearnableConfig, save_config
from ...constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from ...dataset import ReplayBuffer, TrajectoryMiniBatch, is_tuple_shape
from ...logging import (
    LOG,
    D3RLPyLogger,
    FileAdapterFactory,
    LoggerAdapterFactory,
)
from ...metrics import evaluate_transformer_with_environment
from ...torch_utility import TorchTrajectoryMiniBatch, eval_api, train_api
from ...types import GymEnv, NDArray, Observation, TorchObservation
from ..utility import (
    assert_action_space_with_dataset,
    build_scalers_with_trajectory_slicer,
)
from .action_samplers import (
    IdentityTransformerActionSampler,
    SoftmaxTransformerActionSampler,
    TransformerActionSampler,
)
from .inputs import TorchTransformerInput, TransformerInput

__all__ = [
    "TransformerAlgoImplBase",
    "StatefulTransformerWrapper",
    "TransformerConfig",
    "TransformerAlgoBase",
]


class TransformerAlgoImplBase(ImplBase):
    @eval_api
    def predict(self, inpt: TorchTransformerInput) -> torch.Tensor:
        return self.inner_predict(inpt)

    @abstractmethod
    def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor:
        raise NotImplementedError

    @train_api
    def update(
        self, batch: TorchTrajectoryMiniBatch, grad_step: int
    ) -> Dict[str, float]:
        return self.inner_update(batch, grad_step)

    @abstractmethod
    def inner_update(
        self, batch: TorchTrajectoryMiniBatch, grad_step: int
    ) -> Dict[str, float]:
        raise NotImplementedError


@dataclasses.dataclass()
class TransformerConfig(LearnableConfig):
    context_size: int = 20
    max_timestep: int = 1000


TTransformerImpl = TypeVar("TTransformerImpl", bound=TransformerAlgoImplBase)
TTransformerConfig = TypeVar("TTransformerConfig", bound=TransformerConfig)


class StatefulTransformerWrapper(Generic[TTransformerImpl, TTransformerConfig]):
    r"""A stateful wrapper for inference of Transformer-based algorithms.

    This wrapper class provides an interface similar to that of Q-learning-based
    algorithms, which is especially useful when you evaluate Transformer-based
    algorithms such as Decision Transformer.

    .. code-block:: python

        from d3rlpy.algos import DecisionTransformerConfig
        from d3rlpy.algos import StatefulTransformerWrapper

        dt = DecisionTransformerConfig().create()
        dt.create_impl(<observation_shape>, <action_size>)
        # initialize wrapper with a target return of 1000
        wrapper = StatefulTransformerWrapper(dt, target_return=1000)
        # shortcut is also available
        wrapper = dt.as_stateful_wrapper(target_return=1000)

        # predict the next action to achieve a return of 1000 by the end of the episode
        action = wrapper.predict(<observation>, <reward>)

        # clear stateful information
        wrapper.reset()
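
        # putting it together: an illustrative episode rollout (a sketch that
        # assumes a classic Gym-style env returning 4-tuple step results)
        observation, reward, done = env.reset(), 0.0, False
        while not done:
            action = wrapper.predict(observation, reward)
            observation, reward, done, _ = env.step(action)
        wrapper.reset()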

    Args:
        algo (TransformerAlgoBase): Transformer-based algorithm.
        target_return (float): Target return.
        action_sampler (d3rlpy.algos.TransformerActionSampler): Action sampler.
    """

    _algo: "TransformerAlgoBase[TTransformerImpl, TTransformerConfig]"
    _target_return: float
    _action_sampler: TransformerActionSampler
    _return_rest: float
    _observations: Deque[Observation]
    _actions: Deque[Union[NDArray, int]]
    _rewards: Deque[float]
    _returns_to_go: Deque[float]
    _timesteps: Deque[int]
    _timestep: int

    def __init__(
        self,
        algo: "TransformerAlgoBase[TTransformerImpl, TTransformerConfig]",
        target_return: float,
        action_sampler: TransformerActionSampler,
    ):
        assert algo.impl, "algo must be built."
        self._algo = algo
        self._target_return = target_return
        self._action_sampler = action_sampler
        self._return_rest = target_return

        context_size = algo.config.context_size
        self._observations = deque([], maxlen=context_size)
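        # the action deque always keeps a padding action at its tail as a
        # placeholder for the action that has not been predicted yet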
        self._actions = deque([self._get_pad_action()], maxlen=context_size)
        self._rewards = deque([], maxlen=context_size)
        self._returns_to_go = deque([], maxlen=context_size)
        self._timesteps = deque([], maxlen=context_size)
        self._timestep = 1

    def predict(self, x: Observation, reward: float) -> Union[NDArray, int]:
        r"""Returns action.

        Args:
            x: Observation.
            reward: Last reward.

        Returns:
            Action.
        """
        self._observations.append(x)
        self._rewards.append(reward)
        self._returns_to_go.append(self._return_rest - reward)
        self._timesteps.append(self._timestep)

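        # stack the context window into arrays; tuple observations are stacked
        # per component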
        numpy_observations: Observation
        if isinstance(x, np.ndarray):
            numpy_observations = np.array(self._observations)
        else:
            numpy_observations = [
                np.array([o[i] for o in self._observations])
                for i in range(len(x))
            ]

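        # assemble a context-length transformer input; rewards and returns-to-go
        # are shaped (L, 1) and the trailing action is still the padding placeholder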
        inpt = TransformerInput(
            observations=numpy_observations,
            actions=np.array(self._actions),
            rewards=np.array(self._rewards).reshape((-1, 1)),
            returns_to_go=np.array(self._returns_to_go).reshape((-1, 1)),
            timesteps=np.array(self._timesteps),
        )
        action = self._action_sampler(self._algo.predict(inpt))
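        # overwrite the padding slot with the sampled action and append a new
        # padding slot for the next step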
        self._actions[-1] = action
        self._actions.append(self._get_pad_action())
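        # advance the timestep (clipped at max_timestep) and consume the
        # received reward from the remaining target return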
        self._timestep = min(self._timestep + 1, self._algo.config.max_timestep)
        self._return_rest -= reward
        return action

    def reset(self) -> None:
        """Clears stateful information."""
        self._observations.clear()
        self._actions.clear()
        self._rewards.clear()
        self._returns_to_go.clear()
        self._timesteps.clear()
        self._actions.append(self._get_pad_action())
        self._timestep = 1
        self._return_rest = self._target_return

    @property
    def algo(
        self,
    ) -> "TransformerAlgoBase[TTransformerImpl, TTransformerConfig]":
        return self._algo

    def _get_pad_action(self) -> Union[int, NDArray]:
        assert self._algo.impl
        pad_action: Union[int, NDArray]
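        # continuous actions are padded with zero vectors, discrete actions with index 0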
        if self._algo.get_action_type() == ActionSpace.CONTINUOUS:
            pad_action = np.zeros(self._algo.impl.action_size, dtype=np.float32)
        else:
            pad_action = 0
        return pad_action


class TransformerAlgoBase(
    Generic[TTransformerImpl, TTransformerConfig],
    LearnableBase[TTransformerImpl, TTransformerConfig],
):
    def save_policy(self, fname: str) -> None:
        """Save the greedy-policy computational graph as TorchScript or ONNX.

        The format will be automatically detected by the file name.

        .. code-block:: python

            # save as TorchScript
            algo.save_policy('policy.pt')

            # save as ONNX
            algo.save_policy('policy.onnx')

        The artifacts saved with this method will work without d3rlpy.
        This method is especially useful for deploying the learned policy to
        production environments or embedded systems.

        See also

            * https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html (for Python).
            * https://pytorch.org/tutorials/advanced/cpp_export.html (for C++).
            * https://onnx.ai (for ONNX)

        Visit
        https://d3rlpy.readthedocs.io/en/stable/tutorials/after_training_policies.html#export-policies-as-torchscript
        for further usage.

        Args:
            fname: Destination file path.
        """
        assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR

        context_size = self._config.context_size

        dummy_x = []
        if is_tuple_shape(self._impl.observation_shape):
            dummy_x.extend(
                [
                    torch.rand(context_size, *shape, device=self._device)
                    for shape in self._impl.observation_shape
                ]
            )
            num_observations = len(self._impl.observation_shape)
        else:
            dummy_x.append(
                torch.rand(
                    context_size,
                    *self._impl.observation_shape,
                    device=self._device,
                )
            )
            num_observations = 1
        # action
        if self.get_action_type() == ActionSpace.CONTINUOUS:
            dummy_x.append(
                torch.rand(
                    context_size, self._impl.action_size, device=self._device
                )
            )
        else:
            dummy_x.append(torch.rand(context_size, 1, device=self._device))
        # return_to_go
        dummy_x.append(torch.rand(context_size, 1, device=self._device))
        # timesteps
        dummy_x.append(torch.arange(context_size, device=self._device))

        # workaround until version 1.6
        self._impl.modules.freeze()

        # local function to select best actions
        def _func(*x: Sequence[torch.Tensor]) -> torch.Tensor:
            assert self._impl

            # add batch dimension
            x = [v.view(1, *v.shape) for v in x]  # type: ignore

            observations: TorchObservation = x[:-3]
            actions = x[-3]
            returns_to_go = x[-2]
            timesteps = x[-1]

            if len(observations) == 1:
                observations = observations[0]

            if self._config.observation_scaler:
                observations = self._config.observation_scaler.transform(
                    observations
                )

            if self._config.action_scaler:
                actions = self._config.action_scaler.transform(actions)

            inpt = TorchTransformerInput(
                observations=observations,
                actions=actions,
                rewards=torch.zeros_like(returns_to_go),
                returns_to_go=returns_to_go,
                timesteps=timesteps,
                masks=torch.zeros_like(returns_to_go),
                length=self._config.context_size,
            )

            action = self._impl.predict(inpt)

            if self._config.action_scaler:
                action = self._config.action_scaler.reverse_transform(action)

            if self.get_action_type() == ActionSpace.DISCRETE:
                action = action.argmax()

            return action

        traced_script = torch.jit.trace(_func, dummy_x, check_trace=False)

        if fname.endswith(".onnx"):
            # currently, PyTorch cannot directly export a function as ONNX.
            torch.onnx.export(
                traced_script,
                dummy_x,
                fname,
                export_params=True,
                opset_version=11,
                input_names=[
                    f"observation_{i}" for i in range(num_observations)
                ]
                + ["action", "return_to_go", "timestep"],
                output_names=["output_0"],
            )
        elif fname.endswith(".pt"):
            traced_script.save(fname)
        else:
            raise ValueError(
                f"invalid format type: {fname}."
                " .pt and .onnx extensions are currently supported."
            )

        # workaround until version 1.6
        self._impl.modules.unfreeze()

    def predict(self, inpt: TransformerInput) -> NDArray:
        """Returns action.

        This is for internal use. For evaluation, use
        ``StatefulTransformerWrapper`` instead.

        Args:
            inpt: Sequence input.

        Returns:
            Action.
        """
        assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
        with torch.no_grad():
            torch_inpt = TorchTransformerInput.from_numpy(
                inpt=inpt,
                context_size=self._config.context_size,
                device=self._device,
                observation_scaler=self._config.observation_scaler,
                action_scaler=self._config.action_scaler,
                reward_scaler=self._config.reward_scaler,
            )
            action = self._impl.predict(torch_inpt)
            if self._config.action_scaler:
                action = self._config.action_scaler.reverse_transform(action)
        return action.cpu().detach().numpy()  # type: ignore

    def fit(
        self,
        dataset: ReplayBuffer,
        n_steps: int,
        n_steps_per_epoch: int = 10000,
        experiment_name: Optional[str] = None,
        with_timestamp: bool = True,
        logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
        show_progress: bool = True,
        eval_env: Optional[GymEnv] = None,
        eval_target_return: Optional[float] = None,
        eval_action_sampler: Optional[TransformerActionSampler] = None,
        save_interval: int = 1,
        callback: Optional[Callable[[Self, int, int], None]] = None,
        enable_ddp: bool = False,
    ) -> None:
        """Trains with a given dataset.

        Args:
            dataset: Offline dataset to train on.
            n_steps: Number of steps to train.
            n_steps_per_epoch: Number of steps per epoch. This value will be
                ignored when ``n_steps`` is ``None``.
            experiment_name: Experiment name for logging. If not passed,
                the directory name will be ``{class name}_{timestamp}``.
            with_timestamp: Flag to add a timestamp string to the end of the
                directory name.
            logger_adapter: LoggerAdapterFactory object.
            show_progress: Flag to show a progress bar for iterations.
            eval_env: Evaluation environment.
            eval_target_return: Evaluation return target.
            eval_action_sampler: Action sampler used in evaluation.
            save_interval: Interval to save parameters.
            callback: Callable function that takes ``(algo, epoch, total_step)``,
                which is called every step.
            enable_ddp: Flag to wrap models with DistributedDataParallel.
        """
        LOG.info("dataset info", dataset_info=dataset.dataset_info)

        # check action space
        assert_action_space_with_dataset(self, dataset.dataset_info)

        # initialize scalers
        build_scalers_with_trajectory_slicer(self, dataset)

        # setup logger
        if experiment_name is None:
            experiment_name = self.__class__.__name__
        logger = D3RLPyLogger(
            adapter_factory=logger_adapter,
            experiment_name=experiment_name,
            with_timestamp=with_timestamp,
        )

        # instantiate implementation
        if self._impl is None:
            LOG.debug("Building models...")
            action_size = dataset.dataset_info.action_size
            observation_shape = (
                dataset.sample_transition().observation_signature.shape
            )
            if len(observation_shape) == 1:
                observation_shape = observation_shape[0]  # type: ignore
            self.create_impl(observation_shape, action_size)
            LOG.debug("Models have been built.")
        else:
            LOG.warning("Skip building models since they're already built.")

        # wrap all PyTorch modules with DistributedDataParallel
        if enable_ddp:
            assert self._impl
            self._impl.wrap_models_by_ddp()

        # save hyperparameters
        save_config(self, logger)

        # training loop
        n_epochs = n_steps // n_steps_per_epoch
        total_step = 0
        for epoch in range(1, n_epochs + 1):
            # dict to add incremental mean losses to epoch
            epoch_loss = defaultdict(list)

            range_gen = tqdm(
                range(n_steps_per_epoch),
                disable=not show_progress,
                desc=f"Epoch {int(epoch)}/{n_epochs}",
            )

            for itr in range_gen:
                with logger.measure_time("step"):
                    # pick transitions
                    with logger.measure_time("sample_batch"):
                        batch = dataset.sample_trajectory_batch(
                            self._config.batch_size,
                            length=self._config.context_size,
                        )

                    # update parameters
                    with logger.measure_time("algorithm_update"):
                        loss = self.update(batch)

                    # record metrics
                    for name, val in loss.items():
                        logger.add_metric(name, val)
                        epoch_loss[name].append(val)

                    # update progress postfix with losses
                    if itr % 10 == 0:
                        mean_loss = {
                            k: np.mean(v) for k, v in epoch_loss.items()
                        }
                        range_gen.set_postfix(mean_loss)

                total_step += 1

                # call callback if given
                if callback:
                    callback(self, epoch, total_step)

            if eval_env:
                assert eval_target_return is not None
                eval_score = evaluate_transformer_with_environment(
                    algo=self.as_stateful_wrapper(
                        target_return=eval_target_return,
                        action_sampler=eval_action_sampler,
                    ),
                    env=eval_env,
                )
                logger.add_metric("environment", eval_score)

            # save metrics
            logger.commit(epoch, total_step)

            # save model parameters
            if epoch % save_interval == 0:
                logger.save_model(total_step, self)

        logger.close()

    def update(self, batch: TrajectoryMiniBatch) -> Dict[str, float]:
        """Updates parameters with a mini-batch of data.

        Args:
            batch: Mini-batch data.

        Returns:
            Dictionary of metrics.
        """
        assert self._impl, IMPL_NOT_INITIALIZED_ERROR
        torch_batch = TorchTrajectoryMiniBatch.from_batch(
            batch=batch,
            device=self._device,
            observation_scaler=self._config.observation_scaler,
            action_scaler=self._config.action_scaler,
            reward_scaler=self._config.reward_scaler,
        )
        loss = self._impl.update(torch_batch, self._grad_step)
        self._grad_step += 1
        return loss

    def as_stateful_wrapper(
        self,
        target_return: float,
        action_sampler: Optional[TransformerActionSampler] = None,
    ) -> StatefulTransformerWrapper[TTransformerImpl, TTransformerConfig]:
        """Returns a wrapped Transformer algorithm for stateful decision making.

        Args:
            target_return: Target environment return.
            action_sampler: Action sampler.

        Returns:
            StatefulTransformerWrapper object.
        """
        if action_sampler is None:
            if self.get_action_type() == ActionSpace.CONTINUOUS:
                action_sampler = IdentityTransformerActionSampler()
            else:
                action_sampler = SoftmaxTransformerActionSampler()
        return StatefulTransformerWrapper(self, target_return, action_sampler)
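
# Illustrative end-to-end usage (a minimal sketch kept in comments so it is not
# executed on import). The dataset helper and hyperparameters below are
# assumptions for illustration, not requirements of this module:
#
#     import d3rlpy
#
#     # trajectory dataset and environment assumed for illustration
#     dataset, env = d3rlpy.datasets.get_pendulum()
#
#     dt = d3rlpy.algos.DecisionTransformerConfig(context_size=20).create()
#
#     # offline training with periodic environment evaluation
#     dt.fit(
#         dataset,
#         n_steps=10000,
#         n_steps_per_epoch=1000,
#         eval_env=env,
#         eval_target_return=0.0,
#     )
#
#     # stateful inference and deployment
#     wrapper = dt.as_stateful_wrapper(target_return=0.0)
#     dt.save_policy("policy.pt")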