Source code for d3rlpy.algos.qlearning.plas

import dataclasses

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...models.builders import (
    create_continuous_q_function,
    create_deterministic_policy,
    create_deterministic_residual_policy,
    create_vae_decoder,
    create_vae_encoder,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
from ...models.q_functions import QFunctionFactory, make_q_func_field
from ...types import Shape
from .base import QLearningAlgoBase
from .torch.plas_impl import (
    PLASImpl,
    PLASModules,
    PLASWithPerturbationImpl,
    PLASWithPerturbationModules,
)

__all__ = [
    "PLASConfig",
    "PLAS",
    "PLASWithPerturbationConfig",
    "PLASWithPerturbation",
]


@dataclasses.dataclass()
class PLASConfig(LearnableConfig):
    r"""Config of Policy in Latent Action Space algorithm.

    PLAS is an offline deep reinforcement learning algorithm whose policy
    function is trained in the latent space of a Conditional VAE. Unlike other
    algorithms, PLAS can achieve good performance with its less constrained
    policy function.

    .. math::

        a \sim p_\beta (a|s, z=\pi_\phi(s))

    where :math:`\beta` is a parameter of the decoder in the Conditional VAE.

    References:
        * `Zhou et al., PLAS: latent action space for offline reinforcement
          learning. <https://arxiv.org/abs/2011.07213>`_

    Args:
        observation_scaler (d3rlpy.preprocessing.ObservationScaler):
            Observation preprocessor.
        action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
        reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
        actor_learning_rate (float): Learning rate for policy function.
        critic_learning_rate (float): Learning rate for Q functions.
        imitator_learning_rate (float): Learning rate for Conditional VAE.
        actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            Optimizer factory for the actor.
        critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            Optimizer factory for the critic.
        imitator_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            Optimizer factory for the conditional VAE.
        actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the actor.
        critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the critic.
        imitator_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the conditional VAE.
        q_func_factory (d3rlpy.models.q_functions.QFunctionFactory):
            Q function factory.
        batch_size (int): Mini-batch size.
        gamma (float): Discount factor.
        tau (float): Target network synchronization coefficient.
        n_critics (int): Number of Q functions for ensemble.
        lam (float): Weight factor for critic ensemble.
        warmup_steps (int): Number of steps to warm up the VAE.
        beta (float): KL regularization term for the Conditional VAE.
    """

    actor_learning_rate: float = 1e-4
    critic_learning_rate: float = 1e-3
    imitator_learning_rate: float = 1e-4
    actor_optim_factory: OptimizerFactory = make_optimizer_field()
    critic_optim_factory: OptimizerFactory = make_optimizer_field()
    imitator_optim_factory: OptimizerFactory = make_optimizer_field()
    actor_encoder_factory: EncoderFactory = make_encoder_field()
    critic_encoder_factory: EncoderFactory = make_encoder_field()
    imitator_encoder_factory: EncoderFactory = make_encoder_field()
    q_func_factory: QFunctionFactory = make_q_func_field()
    batch_size: int = 100
    gamma: float = 0.99
    tau: float = 0.005
    n_critics: int = 2
    lam: float = 0.75
    warmup_steps: int = 500000
    beta: float = 0.5

    def create(self, device: DeviceArg = False) -> "PLAS":
        return PLAS(self, device)

    @staticmethod
    def get_type() -> str:
        return "plas"

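# A minimal usage sketch (not part of the original module), assuming an offline
# dataset from d3rlpy.datasets (the pendulum dataset here) and a CUDA device:
#
#     import d3rlpy
#
#     dataset, env = d3rlpy.datasets.get_pendulum()
#     plas = PLASConfig(batch_size=100, warmup_steps=500000).create(device="cuda:0")
#     plas.fit(dataset, n_steps=500000)
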
class PLAS(QLearningAlgoBase[PLASImpl, PLASConfig]):
    def inner_create_impl(
        self, observation_shape: Shape, action_size: int
    ) -> None:
        policy = create_deterministic_policy(
            observation_shape,
            2 * action_size,
            self._config.actor_encoder_factory,
            device=self._device,
        )
        targ_policy = create_deterministic_policy(
            observation_shape,
            2 * action_size,
            self._config.actor_encoder_factory,
            device=self._device,
        )
        q_funcs, q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
        )
        targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
        )
        vae_encoder = create_vae_encoder(
            observation_shape=observation_shape,
            action_size=action_size,
            latent_size=2 * action_size,
            min_logstd=-4.0,
            max_logstd=15.0,
            encoder_factory=self._config.imitator_encoder_factory,
            device=self._device,
        )
        vae_decoder = create_vae_decoder(
            observation_shape=observation_shape,
            action_size=action_size,
            latent_size=2 * action_size,
            encoder_factory=self._config.imitator_encoder_factory,
            device=self._device,
        )
        actor_optim = self._config.actor_optim_factory.create(
            policy.named_modules(), lr=self._config.actor_learning_rate
        )
        critic_optim = self._config.critic_optim_factory.create(
            q_funcs.named_modules(), lr=self._config.critic_learning_rate
        )
        # Optimizer for the conditional VAE (encoder + decoder), built with the
        # dedicated imitator optimizer factory.
        vae_optim = self._config.imitator_optim_factory.create(
            list(vae_encoder.named_modules())
            + list(vae_decoder.named_modules()),
            lr=self._config.imitator_learning_rate,
        )
        modules = PLASModules(
            policy=policy,
            targ_policy=targ_policy,
            q_funcs=q_funcs,
            targ_q_funcs=targ_q_funcs,
            vae_encoder=vae_encoder,
            vae_decoder=vae_decoder,
            actor_optim=actor_optim,
            critic_optim=critic_optim,
            vae_optim=vae_optim,
        )
        self._impl = PLASImpl(
            observation_shape=observation_shape,
            action_size=action_size,
            modules=modules,
            q_func_forwarder=q_func_forwarder,
            targ_q_func_forwarder=targ_q_func_forwarder,
            gamma=self._config.gamma,
            tau=self._config.tau,
            lam=self._config.lam,
            beta=self._config.beta,
            warmup_steps=self._config.warmup_steps,
            device=self._device,
        )

    def get_action_type(self) -> ActionSpace:
        return ActionSpace.CONTINUOUS

@dataclasses.dataclass()
class PLASWithPerturbationConfig(PLASConfig):
    r"""Config of Policy in Latent Action Space algorithm with perturbation
    layer.

    The perturbation layer enables PLAS to output out-of-distribution actions.

    References:
        * `Zhou et al., PLAS: latent action space for offline reinforcement
          learning. <https://arxiv.org/abs/2011.07213>`_

    Args:
        observation_scaler (d3rlpy.preprocessing.ObservationScaler):
            Observation preprocessor.
        action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
        reward_scaler (d3rlpy.preprocessing.RewardScaler): Reward preprocessor.
        actor_learning_rate (float): Learning rate for policy function.
        critic_learning_rate (float): Learning rate for Q functions.
        imitator_learning_rate (float): Learning rate for Conditional VAE.
        actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            Optimizer factory for the actor.
        critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            Optimizer factory for the critic.
        imitator_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            Optimizer factory for the conditional VAE.
        actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the actor.
        critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the critic.
        imitator_encoder_factory (d3rlpy.models.encoders.EncoderFactory):
            Encoder factory for the conditional VAE.
        q_func_factory (d3rlpy.models.q_functions.QFunctionFactory):
            Q function factory.
        batch_size (int): Mini-batch size.
        gamma (float): Discount factor.
        tau (float): Target network synchronization coefficient.
        n_critics (int): Number of Q functions for ensemble.
        lam (float): Weight factor for critic ensemble.
        action_flexibility (float): Output scale of perturbation layer.
        warmup_steps (int): Number of steps to warm up the VAE.
        beta (float): KL regularization term for the Conditional VAE.
    """

    action_flexibility: float = 0.05

    def create(self, device: DeviceArg = False) -> "PLASWithPerturbation":
        return PLASWithPerturbation(self, device)

    @staticmethod
    def get_type() -> str:
        return "plas_with_perturbation"

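# A minimal usage sketch (not part of the original module): the perturbation
# variant is configured the same way, with ``action_flexibility`` bounding the
# residual correction added on top of the decoded action. The dataset and
# device are illustrative assumptions, as above:
#
#     plas_p = PLASWithPerturbationConfig(action_flexibility=0.05).create(device="cuda:0")
#     plas_p.fit(dataset, n_steps=500000)
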
class PLASWithPerturbation(PLAS):
    _config: PLASWithPerturbationConfig

    def inner_create_impl(
        self, observation_shape: Shape, action_size: int
    ) -> None:
        policy = create_deterministic_policy(
            observation_shape,
            2 * action_size,
            self._config.actor_encoder_factory,
            device=self._device,
        )
        targ_policy = create_deterministic_policy(
            observation_shape,
            2 * action_size,
            self._config.actor_encoder_factory,
            device=self._device,
        )
        q_funcs, q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
        )
        targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
            observation_shape,
            action_size,
            self._config.critic_encoder_factory,
            self._config.q_func_factory,
            n_ensembles=self._config.n_critics,
            device=self._device,
        )
        vae_encoder = create_vae_encoder(
            observation_shape=observation_shape,
            action_size=action_size,
            latent_size=2 * action_size,
            min_logstd=-4.0,
            max_logstd=15.0,
            encoder_factory=self._config.imitator_encoder_factory,
            device=self._device,
        )
        vae_decoder = create_vae_decoder(
            observation_shape=observation_shape,
            action_size=action_size,
            latent_size=2 * action_size,
            encoder_factory=self._config.imitator_encoder_factory,
            device=self._device,
        )
        perturbation = create_deterministic_residual_policy(
            observation_shape=observation_shape,
            action_size=action_size,
            scale=self._config.action_flexibility,
            encoder_factory=self._config.actor_encoder_factory,
            device=self._device,
        )
        targ_perturbation = create_deterministic_residual_policy(
            observation_shape=observation_shape,
            action_size=action_size,
            scale=self._config.action_flexibility,
            encoder_factory=self._config.actor_encoder_factory,
            device=self._device,
        )
        # The actor optimizer updates both the latent policy and the
        # perturbation layer.
        named_modules = list(policy.named_modules())
        named_modules += list(perturbation.named_modules())
        actor_optim = self._config.actor_optim_factory.create(
            named_modules, lr=self._config.actor_learning_rate
        )
        critic_optim = self._config.critic_optim_factory.create(
            q_funcs.named_modules(), lr=self._config.critic_learning_rate
        )
        # Optimizer for the conditional VAE (encoder + decoder), built with the
        # dedicated imitator optimizer factory.
        vae_optim = self._config.imitator_optim_factory.create(
            list(vae_encoder.named_modules())
            + list(vae_decoder.named_modules()),
            lr=self._config.imitator_learning_rate,
        )
        modules = PLASWithPerturbationModules(
            policy=policy,
            targ_policy=targ_policy,
            q_funcs=q_funcs,
            targ_q_funcs=targ_q_funcs,
            vae_encoder=vae_encoder,
            vae_decoder=vae_decoder,
            perturbation=perturbation,
            targ_perturbation=targ_perturbation,
            actor_optim=actor_optim,
            critic_optim=critic_optim,
            vae_optim=vae_optim,
        )
        self._impl = PLASWithPerturbationImpl(
            observation_shape=observation_shape,
            action_size=action_size,
            modules=modules,
            q_func_forwarder=q_func_forwarder,
            targ_q_func_forwarder=targ_q_func_forwarder,
            gamma=self._config.gamma,
            tau=self._config.tau,
            lam=self._config.lam,
            beta=self._config.beta,
            warmup_steps=self._config.warmup_steps,
            device=self._device,
        )

register_learnable(PLASConfig)
register_learnable(PLASWithPerturbationConfig)