from typing import Any, List, Optional, Sequence
from .base import AlgoBase, DataGenerator
from .torch.plas_impl import PLASImpl, PLASWithPerturbationImpl
from ..augmentation import AugmentationPipeline
from ..dataset import TransitionMiniBatch
from ..models.encoders import EncoderFactory
from ..models.optimizers import OptimizerFactory, AdamFactory
from ..models.q_functions import QFunctionFactory
from ..gpu import Device
from ..argument_utility import check_encoder, EncoderArg
from ..argument_utility import check_use_gpu, UseGPUArg
from ..argument_utility import check_augmentation, AugmentationArg
from ..argument_utility import check_q_func, QFuncArg
from ..argument_utility import ScalerArg, ActionScalerArg
from ..constants import IMPL_NOT_INITIALIZED_ERROR
class PLAS(AlgoBase):
    r"""Policy in Latent Action Space algorithm.

    PLAS is an offline deep reinforcement learning algorithm whose policy
    function is trained in latent space of Conditional VAE.
    Unlike other algorithms, PLAS can achieve good performance by using
    its less constrained policy function.

    .. math::

        a \sim p_\beta (a|s, z=\pi_\phi(s))

    where :math:`\beta` is a parameter of the decoder in Conditional VAE.

    References:
        * `Zhou et al., PLAS: latent action space for offline reinforcement
          learning. <https://arxiv.org/abs/2011.07213>`_

    Args:
        actor_learning_rate (float): learning rate for policy function.
        critic_learning_rate (float): learning rate for Q functions.
        imitator_learning_rate (float): learning rate for Conditional VAE.
        actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            optimizer factory for the actor.
        critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            optimizer factory for the critic.
        imitator_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            optimizer factory for the conditional VAE.
        actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory or str):
            encoder factory for the actor.
        critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory or str):
            encoder factory for the critic.
        imitator_encoder_factory (d3rlpy.models.encoders.EncoderFactory or str):
            encoder factory for the conditional VAE.
        q_func_factory (d3rlpy.models.q_functions.QFunctionFactory or str):
            Q function factory.
        batch_size (int): mini-batch size.
        n_frames (int): the number of frames to stack for image observation.
        n_steps (int): N-step TD calculation.
        gamma (float): discount factor.
        tau (float): target network synchronization coefficient.
        n_critics (int): the number of Q functions for ensemble.
        bootstrap (bool): flag to bootstrap Q functions.
        share_encoder (bool): flag to share encoder network.
        target_reduction_type (str): ensemble reduction method at target value
            estimation. The available options are
            ``['min', 'max', 'mean', 'mix', 'none']``.
        update_actor_interval (int): interval to update policy function.
        lam (float): weight factor for critic ensemble.
        rl_start_epoch (int): epoch to start to update policy function and Q
            functions. If this is large, RL training would be more stabilized.
        beta (float): KL regularization term for Conditional VAE.
        use_gpu (bool, int or d3rlpy.gpu.Device):
            flag to use GPU, device ID or device.
        scaler (d3rlpy.preprocessing.Scaler or str): preprocessor.
            The available options are ``['pixel', 'min_max', 'standard']``.
        action_scaler (d3rlpy.preprocessing.ActionScaler or str):
            action preprocessor. The available options are ``['min_max']``.
        augmentation (d3rlpy.augmentation.AugmentationPipeline or list(str)):
            augmentation pipeline.
        generator (d3rlpy.algos.base.DataGenerator): dynamic dataset generator
            (e.g. model-based RL).
        impl (d3rlpy.algos.torch.plas_impl.PLASImpl): algorithm implementation.
        kwargs (Any): extra keyword arguments. They are accepted but not used
            by this constructor.

    """

    _actor_learning_rate: float
    _critic_learning_rate: float
    _imitator_learning_rate: float
    _actor_optim_factory: OptimizerFactory
    _critic_optim_factory: OptimizerFactory
    _imitator_optim_factory: OptimizerFactory
    _actor_encoder_factory: EncoderFactory
    _critic_encoder_factory: EncoderFactory
    _imitator_encoder_factory: EncoderFactory
    _q_func_factory: QFunctionFactory
    _tau: float
    _bootstrap: bool
    _n_critics: int
    _share_encoder: bool
    _target_reduction_type: str
    _update_actor_interval: int
    _lam: float
    _rl_start_epoch: int
    _beta: float
    _augmentation: AugmentationPipeline
    _use_gpu: Optional[Device]
    _impl: Optional[PLASImpl]

    def __init__(
        self,
        *,
        actor_learning_rate: float = 3e-4,
        critic_learning_rate: float = 3e-4,
        imitator_learning_rate: float = 3e-4,
        actor_optim_factory: OptimizerFactory = AdamFactory(),
        critic_optim_factory: OptimizerFactory = AdamFactory(),
        imitator_optim_factory: OptimizerFactory = AdamFactory(),
        actor_encoder_factory: EncoderArg = "default",
        critic_encoder_factory: EncoderArg = "default",
        imitator_encoder_factory: EncoderArg = "default",
        q_func_factory: QFuncArg = "mean",
        batch_size: int = 256,
        n_frames: int = 1,
        n_steps: int = 1,
        gamma: float = 0.99,
        tau: float = 0.005,
        n_critics: int = 2,
        bootstrap: bool = False,
        share_encoder: bool = False,
        target_reduction_type: str = "mix",
        update_actor_interval: int = 1,
        lam: float = 0.75,
        rl_start_epoch: int = 10,
        beta: float = 0.5,
        use_gpu: UseGPUArg = False,
        scaler: ScalerArg = None,
        action_scaler: ActionScalerArg = None,
        augmentation: AugmentationArg = None,
        generator: Optional[DataGenerator] = None,
        impl: Optional[PLASImpl] = None,
        **kwargs: Any
    ):
        # Settings shared by every algorithm are stored by the base class.
        super().__init__(
            batch_size=batch_size,
            n_frames=n_frames,
            n_steps=n_steps,
            gamma=gamma,
            scaler=scaler,
            action_scaler=action_scaler,
            generator=generator,
        )
        self._actor_learning_rate = actor_learning_rate
        self._critic_learning_rate = critic_learning_rate
        self._imitator_learning_rate = imitator_learning_rate
        self._actor_optim_factory = actor_optim_factory
        self._critic_optim_factory = critic_optim_factory
        self._imitator_optim_factory = imitator_optim_factory
        # The ``check_*`` helpers receive the user-facing argument forms
        # (e.g. the string defaults above) and their results are what gets
        # stored on the instance.
        self._actor_encoder_factory = check_encoder(actor_encoder_factory)
        self._critic_encoder_factory = check_encoder(critic_encoder_factory)
        self._imitator_encoder_factory = check_encoder(imitator_encoder_factory)
        self._q_func_factory = check_q_func(q_func_factory)
        self._tau = tau
        self._bootstrap = bootstrap
        self._n_critics = n_critics
        self._share_encoder = share_encoder
        self._target_reduction_type = target_reduction_type
        self._update_actor_interval = update_actor_interval
        self._lam = lam
        self._rl_start_epoch = rl_start_epoch
        self._beta = beta
        self._augmentation = check_augmentation(augmentation)
        self._use_gpu = check_use_gpu(use_gpu)
        self._impl = impl

    def create_impl(
        self, observation_shape: Sequence[int], action_size: int
    ) -> None:
        """Instantiate and build the ``PLASImpl`` backing this algorithm.

        The implementation object is constructed from the hyperparameters
        stored on this instance, assigned to ``self._impl`` and then built by
        calling its ``build()`` method. ``update_actor_interval`` and
        ``rl_start_epoch`` are not forwarded; they are consumed by
        :meth:`update` instead.

        Args:
            observation_shape (Sequence[int]): observation shape.
            action_size (int): action size passed to the implementation.

        """
        self._impl = PLASImpl(
            observation_shape=observation_shape,
            action_size=action_size,
            actor_learning_rate=self._actor_learning_rate,
            critic_learning_rate=self._critic_learning_rate,
            imitator_learning_rate=self._imitator_learning_rate,
            actor_optim_factory=self._actor_optim_factory,
            critic_optim_factory=self._critic_optim_factory,
            imitator_optim_factory=self._imitator_optim_factory,
            actor_encoder_factory=self._actor_encoder_factory,
            critic_encoder_factory=self._critic_encoder_factory,
            imitator_encoder_factory=self._imitator_encoder_factory,
            q_func_factory=self._q_func_factory,
            gamma=self._gamma,
            tau=self._tau,
            n_critics=self._n_critics,
            bootstrap=self._bootstrap,
            share_encoder=self._share_encoder,
            target_reduction_type=self._target_reduction_type,
            lam=self._lam,
            beta=self._beta,
            use_gpu=self._use_gpu,
            scaler=self._scaler,
            action_scaler=self._action_scaler,
            augmentation=self._augmentation,
        )
        self._impl.build()

    def update(
        self, epoch: int, total_step: int, batch: TransitionMiniBatch
    ) -> List[Optional[float]]:
        """Run one training step on a mini-batch.

        While ``epoch`` is smaller than ``rl_start_epoch`` only the imitator
        is updated. From ``rl_start_epoch`` onwards the critic is updated on
        every call, and the actor together with both target networks is
        updated only when ``total_step`` is a multiple of
        ``update_actor_interval``.

        Args:
            epoch (int): current epoch.
            total_step (int): current total step count.
            batch (d3rlpy.dataset.TransitionMiniBatch): mini-batch of
                transitions.

        Returns:
            list: ``[critic_loss, actor_loss, imitator_loss]``. An entry is
            ``None`` when the corresponding component was not updated in
            this call.

        """
        assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR

        if epoch < self._rl_start_epoch:
            # Imitation-only phase: actor and critic are left untouched.
            imitator_loss = self._impl.update_imitator(
                batch.observations, batch.actions
            )
            critic_loss, actor_loss = None, None
        else:
            critic_loss = self._impl.update_critic(
                batch.observations,
                batch.actions,
                batch.next_rewards,
                batch.next_observations,
                batch.terminals,
                batch.n_steps,
                batch.masks,
            )
            if total_step % self._update_actor_interval == 0:
                actor_loss = self._impl.update_actor(batch.observations)
                # Target networks are synchronized only on the steps where
                # the actor is updated.
                self._impl.update_actor_target()
                self._impl.update_critic_target()
            else:
                actor_loss = None
            imitator_loss = None

        return [critic_loss, actor_loss, imitator_loss]

    def get_loss_labels(self) -> List[str]:
        """Return the loss names, in the order of the list from ``update``."""
        return ["critic_loss", "actor_loss", "imitator_loss"]
class PLASWithPerturbation(PLAS):
    r"""Policy in Latent Action Space algorithm with perturbation layer.

    PLAS with perturbation layer enables PLAS to output out-of-distribution
    action.

    References:
        * `Zhou et al., PLAS: latent action space for offline reinforcement
          learning. <https://arxiv.org/abs/2011.07213>`_

    Args:
        actor_learning_rate (float): learning rate for policy function.
        critic_learning_rate (float): learning rate for Q functions.
        imitator_learning_rate (float): learning rate for Conditional VAE.
        actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            optimizer factory for the actor.
        critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            optimizer factory for the critic.
        imitator_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
            optimizer factory for the conditional VAE.
        actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory or str):
            encoder factory for the actor.
        critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory or str):
            encoder factory for the critic.
        imitator_encoder_factory (d3rlpy.models.encoders.EncoderFactory or str):
            encoder factory for the conditional VAE.
        q_func_factory (d3rlpy.models.q_functions.QFunctionFactory or str):
            Q function factory.
        batch_size (int): mini-batch size.
        n_frames (int): the number of frames to stack for image observation.
        n_steps (int): N-step TD calculation.
        gamma (float): discount factor.
        tau (float): target network synchronization coefficient.
        n_critics (int): the number of Q functions for ensemble.
        bootstrap (bool): flag to bootstrap Q functions.
        share_encoder (bool): flag to share encoder network.
        target_reduction_type (str): ensemble reduction method at target value
            estimation. The available options are
            ``['min', 'max', 'mean', 'mix', 'none']``.
        update_actor_interval (int): interval to update policy function.
        lam (float): weight factor for critic ensemble.
        action_flexibility (float): output scale of perturbation layer.
        rl_start_epoch (int): epoch to start to update policy function and Q
            functions. If this is large, RL training would be more stabilized.
        beta (float): KL regularization term for Conditional VAE.
        use_gpu (bool, int or d3rlpy.gpu.Device):
            flag to use GPU, device ID or device.
        scaler (d3rlpy.preprocessing.Scaler or str): preprocessor.
            The available options are ``['pixel', 'min_max', 'standard']``.
        action_scaler (d3rlpy.preprocessing.ActionScaler or str):
            action preprocessor. The available options are ``['min_max']``.
        augmentation (d3rlpy.augmentation.AugmentationPipeline or list(str)):
            augmentation pipeline.
        generator (d3rlpy.algos.base.DataGenerator): dynamic dataset generator
            (e.g. model-based RL).
        impl (d3rlpy.algos.torch.plas_impl.PLASWithPerturbationImpl):
            algorithm implementation.
        kwargs (Any): extra keyword arguments. They are accepted but neither
            used here nor forwarded to the parent constructor.

    """

    _action_flexibility: float
    _impl: Optional[PLASWithPerturbationImpl]

    def __init__(
        self,
        *,
        actor_learning_rate: float = 3e-4,
        critic_learning_rate: float = 3e-4,
        imitator_learning_rate: float = 3e-4,
        actor_optim_factory: OptimizerFactory = AdamFactory(),
        critic_optim_factory: OptimizerFactory = AdamFactory(),
        imitator_optim_factory: OptimizerFactory = AdamFactory(),
        actor_encoder_factory: EncoderArg = "default",
        critic_encoder_factory: EncoderArg = "default",
        imitator_encoder_factory: EncoderArg = "default",
        q_func_factory: QFuncArg = "mean",
        batch_size: int = 256,
        n_frames: int = 1,
        n_steps: int = 1,
        gamma: float = 0.99,
        tau: float = 0.005,
        n_critics: int = 2,
        bootstrap: bool = False,
        share_encoder: bool = False,
        target_reduction_type: str = "mix",
        update_actor_interval: int = 1,
        lam: float = 0.75,
        action_flexibility: float = 0.05,
        rl_start_epoch: int = 10,
        beta: float = 0.5,
        use_gpu: UseGPUArg = False,
        scaler: ScalerArg = None,
        action_scaler: ActionScalerArg = None,
        augmentation: AugmentationArg = None,
        generator: Optional[DataGenerator] = None,
        impl: Optional[PLASWithPerturbationImpl] = None,
        **kwargs: Any
    ):
        # Everything except ``action_flexibility`` is handled by PLAS.
        super().__init__(
            actor_learning_rate=actor_learning_rate,
            critic_learning_rate=critic_learning_rate,
            imitator_learning_rate=imitator_learning_rate,
            actor_optim_factory=actor_optim_factory,
            critic_optim_factory=critic_optim_factory,
            imitator_optim_factory=imitator_optim_factory,
            actor_encoder_factory=actor_encoder_factory,
            critic_encoder_factory=critic_encoder_factory,
            imitator_encoder_factory=imitator_encoder_factory,
            q_func_factory=q_func_factory,
            batch_size=batch_size,
            n_frames=n_frames,
            n_steps=n_steps,
            gamma=gamma,
            tau=tau,
            n_critics=n_critics,
            bootstrap=bootstrap,
            share_encoder=share_encoder,
            target_reduction_type=target_reduction_type,
            update_actor_interval=update_actor_interval,
            lam=lam,
            rl_start_epoch=rl_start_epoch,
            beta=beta,
            use_gpu=use_gpu,
            scaler=scaler,
            action_scaler=action_scaler,
            augmentation=augmentation,
            generator=generator,
            impl=impl,
        )
        self._action_flexibility = action_flexibility

    def create_impl(
        self, observation_shape: Sequence[int], action_size: int
    ) -> None:
        """Instantiate and build the ``PLASWithPerturbationImpl``.

        The implementation object is constructed from the hyperparameters
        stored on this instance, including ``action_flexibility``, assigned
        to ``self._impl`` and then built by calling its ``build()`` method.

        Args:
            observation_shape (Sequence[int]): observation shape.
            action_size (int): action size passed to the implementation.

        """
        self._impl = PLASWithPerturbationImpl(
            observation_shape=observation_shape,
            action_size=action_size,
            actor_learning_rate=self._actor_learning_rate,
            critic_learning_rate=self._critic_learning_rate,
            imitator_learning_rate=self._imitator_learning_rate,
            actor_optim_factory=self._actor_optim_factory,
            critic_optim_factory=self._critic_optim_factory,
            imitator_optim_factory=self._imitator_optim_factory,
            actor_encoder_factory=self._actor_encoder_factory,
            critic_encoder_factory=self._critic_encoder_factory,
            imitator_encoder_factory=self._imitator_encoder_factory,
            q_func_factory=self._q_func_factory,
            gamma=self._gamma,
            tau=self._tau,
            n_critics=self._n_critics,
            bootstrap=self._bootstrap,
            share_encoder=self._share_encoder,
            target_reduction_type=self._target_reduction_type,
            lam=self._lam,
            beta=self._beta,
            action_flexibility=self._action_flexibility,
            use_gpu=self._use_gpu,
            scaler=self._scaler,
            action_scaler=self._action_scaler,
            augmentation=self._augmentation,
        )
        self._impl.build()