# Source code for d3rlpy.algos.transformer.decision_transformer
import dataclasses
from typing import Dict
import torch
from ...base import DeviceArg, register_learnable
from ...constants import ActionSpace
from ...dataset import Shape
from ...models import (
EncoderFactory,
OptimizerFactory,
make_encoder_field,
make_optimizer_field,
)
from ...models.builders import create_continuous_decision_transformer
from ...torch_utility import TorchTrajectoryMiniBatch
from .base import TransformerAlgoBase, TransformerConfig
from .torch.decision_transformer_impl import DecisionTransformerImpl
# Public names of this module (what ``from ... import *`` exposes).
__all__ = ["DecisionTransformerConfig", "DecisionTransformer"]
@dataclasses.dataclass()
class DecisionTransformerConfig(TransformerConfig):
    """Config of Decision Transformer.

    Decision Transformer solves decision-making problems as a sequence modeling
    problem.

    References:
        * `Chen et al., Decision Transformer: Reinforcement Learning via
          Sequence Modeling. <https://arxiv.org/abs/2106.01345>`_

    Args:
        observation_scaler (d3rlpy.preprocessing.ObservationScaler):
            Observation preprocessor.
        action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
        reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
        context_size (int): Prior sequence length.
        batch_size (int): Mini-batch size.
        learning_rate (float): Learning rate.
        encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory.
        optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            Optimizer factory.
        num_heads (int): Number of attention heads.
        max_timestep (int): Maximum environmental timestep.
        num_layers (int): Number of attention blocks.
        attn_dropout (float): Dropout probability for attentions.
        resid_dropout (float): Dropout probability for residual connection.
        embed_dropout (float): Dropout probability for embeddings.
        activation_type (str): Type of activation function.
        position_encoding_type (str): Type of positional encoding
            (``simple`` or ``global``).
        warmup_steps (int): Warmup steps for learning rate scheduler.
        clip_grad_norm (float): Norm of gradient clipping.
    """

    # NOTE(review): observation_scaler, action_scaler, reward_scaler and
    # context_size are documented above but not declared here; presumably
    # they are inherited from TransformerConfig -- confirm against the base.
    batch_size: int = 64
    learning_rate: float = 1e-4
    encoder_factory: EncoderFactory = make_encoder_field()
    optim_factory: OptimizerFactory = make_optimizer_field()
    num_heads: int = 1
    max_timestep: int = 1000
    num_layers: int = 3
    attn_dropout: float = 0.1
    resid_dropout: float = 0.1
    embed_dropout: float = 0.1
    activation_type: str = "relu"
    position_encoding_type: str = "simple"
    warmup_steps: int = 10000
    clip_grad_norm: float = 0.25

    def create(self, device: DeviceArg = False) -> "DecisionTransformer":
        """Instantiate the algorithm described by this config.

        Args:
            device: Device argument forwarded unchanged to
                ``DecisionTransformer``.

        Returns:
            A ``DecisionTransformer`` constructed from this config.
        """
        return DecisionTransformer(self, device)

    @staticmethod
    def get_type() -> str:
        """Return the type identifier string of this config."""
        return "decision_transformer"
class DecisionTransformer(
    TransformerAlgoBase[DecisionTransformerImpl, DecisionTransformerConfig]
):
    """Decision Transformer algorithm.

    Builds a continuous-action decision transformer, its optimizer and a
    linear-warmup learning rate scheduler from ``DecisionTransformerConfig``,
    and delegates updates to ``DecisionTransformerImpl``.
    """

    def inner_create_impl(
        self, observation_shape: Shape, action_size: int
    ) -> None:
        """Build the model, optimizer, scheduler and implementation object.

        The result is stored in ``self._impl``.

        Args:
            observation_shape: Observation shape passed to the model builder
                and to the implementation.
            action_size: Action size passed to the model builder and to the
                implementation.
        """
        transformer = create_continuous_decision_transformer(
            observation_shape=observation_shape,
            action_size=action_size,
            encoder_factory=self._config.encoder_factory,
            num_heads=self._config.num_heads,
            max_timestep=self._config.max_timestep,
            num_layers=self._config.num_layers,
            context_size=self._config.context_size,
            attn_dropout=self._config.attn_dropout,
            resid_dropout=self._config.resid_dropout,
            embed_dropout=self._config.embed_dropout,
            activation_type=self._config.activation_type,
            position_encoding_type=self._config.position_encoding_type,
            device=self._device,
        )
        optim = self._config.optim_factory.create(
            transformer.parameters(), lr=self._config.learning_rate
        )

        # Linear warmup: the learning rate multiplier grows from
        # 1 / warmup_steps up to 1 and then stays at 1.
        # The divisor is clamped to at least 1 so that ``warmup_steps=0``
        # means "no warmup" instead of raising ZeroDivisionError when the
        # scheduler evaluates the lambda. Reading the value into a local also
        # keeps the scheduler closure from holding a reference to ``self``.
        warmup_steps = max(self._config.warmup_steps, 1)
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optim, lambda steps: min((steps + 1) / warmup_steps, 1)
        )

        self._impl = DecisionTransformerImpl(
            observation_shape=observation_shape,
            action_size=action_size,
            transformer=transformer,
            optim=optim,
            scheduler=scheduler,
            clip_grad_norm=self._config.clip_grad_norm,
            device=self._device,
        )

    def inner_update(
        self, batch: TorchTrajectoryMiniBatch
    ) -> Dict[str, float]:
        """Run one update step on a mini-batch.

        Args:
            batch: Trajectory mini-batch handed to the implementation.

        Returns:
            Dictionary with a single ``"loss"`` entry holding the value
            returned by the implementation's ``update``.
        """
        # The implementation must have been created before updating.
        assert self._impl
        loss = self._impl.update(batch)
        return {"loss": loss}
# Register the config class at import time. Presumably this makes it
# resolvable by its ``get_type()`` name -- confirm in ``...base``.
register_learnable(DecisionTransformerConfig)