Source code for d3rlpy.models.q_functions

from typing import Any, ClassVar, Dict, Type

from ..decorators import pretty_repr
from .torch import (
    ContinuousFQFQFunction,
    ContinuousIQNQFunction,
    ContinuousMeanQFunction,
    ContinuousQFunction,
    ContinuousQRQFunction,
    DiscreteFQFQFunction,
    DiscreteIQNQFunction,
    DiscreteMeanQFunction,
    DiscreteQFunction,
    DiscreteQRQFunction,
    Encoder,
    EncoderWithAction,
)


@pretty_repr
class QFunctionFactory:
    TYPE: ClassVar[str] = "none"

    _bootstrap: bool
    _share_encoder: bool

    def __init__(self, bootstrap: bool, share_encoder: bool):
        self._bootstrap = bootstrap
        self._share_encoder = share_encoder

    def create_discrete(
        self, encoder: Encoder, action_size: int
    ) -> DiscreteQFunction:
        """Returns PyTorch's Q function module.

        Args:
            encoder: an encoder module that processes the observation to
                obtain feature representations.
            action_size: dimension of discrete action-space.

        Returns:
            discrete Q function object.

        """
        raise NotImplementedError

    def create_continuous(
        self, encoder: EncoderWithAction
    ) -> ContinuousQFunction:
        """Returns PyTorch's Q function module.

        Args:
            encoder: an encoder module that processes the observation and
                action to obtain feature representations.

        Returns:
            continuous Q function object.

        """
        raise NotImplementedError

    def get_type(self) -> str:
        """Returns Q function type.

        Returns:
            Q function type.

        """
        return self.TYPE

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Returns Q function parameters.

        Returns:
            Q function parameters.

        """
        raise NotImplementedError

    @property
    def bootstrap(self) -> bool:
        return self._bootstrap

    @property
    def share_encoder(self) -> bool:
        return self._share_encoder
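
# The concrete factories below override ``create_discrete`` and
# ``create_continuous`` to build the corresponding PyTorch modules imported
# from ``.torch``.  Each factory is identified by its ``TYPE`` string and is
# registered with ``register_q_func_factory`` at the bottom of this module.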


class MeanQFunctionFactory(QFunctionFactory):
    """Standard Q function factory class.

    References:
        * `Mnih et al., Human-level control through deep reinforcement
          learning. <https://www.nature.com/articles/nature14236>`_
        * `Lillicrap et al., Continuous control with deep reinforcement
          learning. <https://arxiv.org/abs/1509.02971>`_

    Args:
        bootstrap (bool): flag to bootstrap Q functions.
        share_encoder (bool): flag to share encoder over multiple Q functions.

    """

    TYPE: ClassVar[str] = "mean"

    def __init__(self, bootstrap: bool = False, share_encoder: bool = False):
        super().__init__(bootstrap, share_encoder)

    def create_discrete(
        self,
        encoder: Encoder,
        action_size: int,
    ) -> DiscreteMeanQFunction:
        return DiscreteMeanQFunction(encoder, action_size)

    def create_continuous(
        self,
        encoder: EncoderWithAction,
    ) -> ContinuousMeanQFunction:
        return ContinuousMeanQFunction(encoder)

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        return {
            "bootstrap": self._bootstrap,
            "share_encoder": self._share_encoder,
        }
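
# Usage sketch (illustrative, commented out): factories are plain
# configuration objects, so they can be constructed directly and inspected
# before being handed to the code that calls
# ``create_discrete``/``create_continuous``.
#
#   factory = MeanQFunctionFactory(bootstrap=True, share_encoder=True)
#   factory.get_type()    # "mean"
#   factory.get_params()  # {"bootstrap": True, "share_encoder": True}
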
class QRQFunctionFactory(QFunctionFactory):
    """Quantile Regression Q function factory class.

    References:
        * `Dabney et al., Distributional reinforcement learning with quantile
          regression. <https://arxiv.org/abs/1710.10044>`_

    Args:
        bootstrap (bool): flag to bootstrap Q functions.
        share_encoder (bool): flag to share encoder over multiple Q functions.
        n_quantiles (int): the number of quantiles.

    """

    TYPE: ClassVar[str] = "qr"

    _n_quantiles: int

    def __init__(
        self,
        bootstrap: bool = False,
        share_encoder: bool = False,
        n_quantiles: int = 32,
    ):
        super().__init__(bootstrap, share_encoder)
        self._n_quantiles = n_quantiles

    def create_discrete(
        self, encoder: Encoder, action_size: int
    ) -> DiscreteQRQFunction:
        return DiscreteQRQFunction(encoder, action_size, self._n_quantiles)

    def create_continuous(
        self,
        encoder: EncoderWithAction,
    ) -> ContinuousQRQFunction:
        return ContinuousQRQFunction(encoder, self._n_quantiles)

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        return {
            "bootstrap": self._bootstrap,
            "share_encoder": self._share_encoder,
            "n_quantiles": self._n_quantiles,
        }

    @property
    def n_quantiles(self) -> int:
        return self._n_quantiles
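
# Example (illustrative, commented out): the quantile count is exposed both
# through ``get_params`` and the ``n_quantiles`` property.
#
#   factory = QRQFunctionFactory(n_quantiles=200)
#   factory.n_quantiles                  # 200
#   factory.get_params()["n_quantiles"]  # 200
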
class IQNQFunctionFactory(QFunctionFactory):
    """Implicit Quantile Network Q function factory class.

    References:
        * `Dabney et al., Implicit quantile networks for distributional
          reinforcement learning. <https://arxiv.org/abs/1806.06923>`_

    Args:
        bootstrap (bool): flag to bootstrap Q functions.
        share_encoder (bool): flag to share encoder over multiple Q functions.
        n_quantiles (int): the number of quantiles.
        n_greedy_quantiles (int): the number of quantiles for inference.
        embed_size (int): the embedding size.

    """

    TYPE: ClassVar[str] = "iqn"

    _n_quantiles: int
    _n_greedy_quantiles: int
    _embed_size: int

    def __init__(
        self,
        bootstrap: bool = False,
        share_encoder: bool = False,
        n_quantiles: int = 64,
        n_greedy_quantiles: int = 32,
        embed_size: int = 64,
    ):
        super().__init__(bootstrap, share_encoder)
        self._n_quantiles = n_quantiles
        self._n_greedy_quantiles = n_greedy_quantiles
        self._embed_size = embed_size

    def create_discrete(
        self,
        encoder: Encoder,
        action_size: int,
    ) -> DiscreteIQNQFunction:
        return DiscreteIQNQFunction(
            encoder=encoder,
            action_size=action_size,
            n_quantiles=self._n_quantiles,
            n_greedy_quantiles=self._n_greedy_quantiles,
            embed_size=self._embed_size,
        )

    def create_continuous(
        self,
        encoder: EncoderWithAction,
    ) -> ContinuousIQNQFunction:
        return ContinuousIQNQFunction(
            encoder=encoder,
            n_quantiles=self._n_quantiles,
            n_greedy_quantiles=self._n_greedy_quantiles,
            embed_size=self._embed_size,
        )

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        return {
            "bootstrap": self._bootstrap,
            "share_encoder": self._share_encoder,
            "n_quantiles": self._n_quantiles,
            "n_greedy_quantiles": self._n_greedy_quantiles,
            "embed_size": self._embed_size,
        }

    @property
    def n_quantiles(self) -> int:
        return self._n_quantiles

    @property
    def n_greedy_quantiles(self) -> int:
        return self._n_greedy_quantiles

    @property
    def embed_size(self) -> int:
        return self._embed_size
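
# Example (illustrative, commented out): with the defaults above, IQN samples
# more quantiles for training (``n_quantiles=64``) than for greedy inference
# (``n_greedy_quantiles=32``); both counts are configurable.
#
#   factory = IQNQFunctionFactory(n_quantiles=64, n_greedy_quantiles=8)
#   factory.n_greedy_quantiles          # 8
#   factory.get_params()["embed_size"]  # 64 (the default)
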
class FQFQFunctionFactory(QFunctionFactory):
    """Fully parameterized Quantile Function Q function factory.

    References:
        * `Yang et al., Fully parameterized quantile function for
          distributional reinforcement learning.
          <https://arxiv.org/abs/1911.02140>`_

    Args:
        bootstrap (bool): flag to bootstrap Q functions.
        share_encoder (bool): flag to share encoder over multiple Q functions.
        n_quantiles (int): the number of quantiles.
        embed_size (int): the embedding size.
        entropy_coeff (float): the coefficient of the entropy penalty term.

    """

    TYPE: ClassVar[str] = "fqf"

    _n_quantiles: int
    _embed_size: int
    _entropy_coeff: float

    def __init__(
        self,
        bootstrap: bool = False,
        share_encoder: bool = False,
        n_quantiles: int = 32,
        embed_size: int = 64,
        entropy_coeff: float = 0.0,
    ):
        super().__init__(bootstrap, share_encoder)
        self._n_quantiles = n_quantiles
        self._embed_size = embed_size
        self._entropy_coeff = entropy_coeff

    def create_discrete(
        self,
        encoder: Encoder,
        action_size: int,
    ) -> DiscreteFQFQFunction:
        return DiscreteFQFQFunction(
            encoder=encoder,
            action_size=action_size,
            n_quantiles=self._n_quantiles,
            embed_size=self._embed_size,
            entropy_coeff=self._entropy_coeff,
        )

    def create_continuous(
        self,
        encoder: EncoderWithAction,
    ) -> ContinuousFQFQFunction:
        return ContinuousFQFQFunction(
            encoder=encoder,
            n_quantiles=self._n_quantiles,
            embed_size=self._embed_size,
            entropy_coeff=self._entropy_coeff,
        )

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        return {
            "bootstrap": self._bootstrap,
            "share_encoder": self._share_encoder,
            "n_quantiles": self._n_quantiles,
            "embed_size": self._embed_size,
            "entropy_coeff": self._entropy_coeff,
        }

    @property
    def n_quantiles(self) -> int:
        return self._n_quantiles

    @property
    def embed_size(self) -> int:
        return self._embed_size

    @property
    def entropy_coeff(self) -> float:
        return self._entropy_coeff
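
# Example (illustrative, commented out): FQF adds an entropy penalty term
# whose coefficient defaults to 0.0, i.e. the penalty is disabled unless
# ``entropy_coeff`` is set.
#
#   factory = FQFQFunctionFactory(entropy_coeff=1e-2)
#   factory.entropy_coeff                  # 0.01
#   factory.get_params()["entropy_coeff"]  # 0.01
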

Q_FUNC_LIST: Dict[str, Type[QFunctionFactory]] = {}


def register_q_func_factory(cls: Type[QFunctionFactory]) -> None:
    """Registers Q function factory class.

    Args:
        cls: Q function factory class inheriting ``QFunctionFactory``.

    """
    is_registered = cls.TYPE in Q_FUNC_LIST
    assert not is_registered, "%s seems to be already registered" % cls.TYPE
    Q_FUNC_LIST[cls.TYPE] = cls


def create_q_func_factory(name: str, **kwargs: Any) -> QFunctionFactory:
    """Returns registered Q function factory object.

    Args:
        name: registered Q function factory type name.
        kwargs: Q function arguments.

    Returns:
        Q function factory object.

    """
    assert name in Q_FUNC_LIST, "%s seems not to be registered." % name
    factory = Q_FUNC_LIST[name](**kwargs)
    assert isinstance(factory, QFunctionFactory)
    return factory


register_q_func_factory(MeanQFunctionFactory)
register_q_func_factory(QRQFunctionFactory)
register_q_func_factory(IQNQFunctionFactory)
register_q_func_factory(FQFQFunctionFactory)
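
# Usage sketch (illustrative, commented out): a registered factory is normally
# obtained by its ``TYPE`` name through ``create_q_func_factory``, which looks
# the class up in ``Q_FUNC_LIST`` and instantiates it with the given keyword
# arguments.
#
#   factory = create_q_func_factory("qr", n_quantiles=64)
#   factory.get_type()                   # "qr"
#   factory.get_params()["n_quantiles"]  # 64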