Source code for d3rlpy.models.q_functions

import dataclasses
from typing import Tuple

from ..serializable_config import DynamicConfig, generate_config_registration
from .torch import (
    ContinuousIQNQFunction,
    ContinuousIQNQFunctionForwarder,
    ContinuousMeanQFunction,
    ContinuousMeanQFunctionForwarder,
    ContinuousQFunction,
    ContinuousQFunctionForwarder,
    ContinuousQRQFunction,
    ContinuousQRQFunctionForwarder,
    DiscreteIQNQFunction,
    DiscreteIQNQFunctionForwarder,
    DiscreteMeanQFunction,
    DiscreteMeanQFunctionForwarder,
    DiscreteQFunction,
    DiscreteQFunctionForwarder,
    DiscreteQRQFunction,
    DiscreteQRQFunctionForwarder,
    Encoder,
    EncoderWithAction,
)

__all__ = [
    "QFunctionFactory",
    "MeanQFunctionFactory",
    "QRQFunctionFactory",
    "IQNQFunctionFactory",
    "make_q_func_field",
]


@dataclasses.dataclass()
class QFunctionFactory(DynamicConfig):
    share_encoder: bool = False

    def create_discrete(
        self, encoder: Encoder, hidden_size: int, action_size: int
    ) -> Tuple[DiscreteQFunction, DiscreteQFunctionForwarder]:
        """Returns PyTorch's Q function module.

        Args:
            encoder: Encoder that processes the observation to
                obtain feature representations.
            hidden_size: Dimension of encoder output.
            action_size: Dimension of discrete action-space.

        Returns:
            Tuple of discrete Q function and its forwarder.
        """
        raise NotImplementedError

    def create_continuous(
        self, encoder: EncoderWithAction, hidden_size: int
    ) -> Tuple[ContinuousQFunction, ContinuousQFunctionForwarder]:
        """Returns PyTorch's Q function module.

        Args:
            encoder: Encoder module that processes the observation and
                action to obtain feature representations.
            hidden_size: Dimension of encoder output.

        Returns:
            Tuple of continuous Q function and its forwarder.
        """
        raise NotImplementedError

    @staticmethod
    def get_type() -> str:
        """Returns Q function type.

        Returns:
            Q function type.
        """
        raise NotImplementedError


[docs]@dataclasses.dataclass() class MeanQFunctionFactory(QFunctionFactory): """Standard Q function factory class. This is the standard Q function factory class. References: * `Mnih et al., Human-level control through deep reinforcement learning. <https://www.nature.com/articles/nature14236>`_ * `Lillicrap et al., Continuous control with deep reinforcement learning. <https://arxiv.org/abs/1509.02971>`_ Args: share_encoder (bool): flag to share encoder over multiple Q functions. """
[docs] def create_discrete( self, encoder: Encoder, hidden_size: int, action_size: int, ) -> Tuple[DiscreteMeanQFunction, DiscreteMeanQFunctionForwarder]: q_func = DiscreteMeanQFunction(encoder, hidden_size, action_size) forwarder = DiscreteMeanQFunctionForwarder(q_func, action_size) return q_func, forwarder
[docs] def create_continuous( self, encoder: EncoderWithAction, hidden_size: int, ) -> Tuple[ContinuousMeanQFunction, ContinuousMeanQFunctionForwarder]: q_func = ContinuousMeanQFunction(encoder, hidden_size) forwarder = ContinuousMeanQFunctionForwarder(q_func) return q_func, forwarder
[docs] @staticmethod def get_type() -> str: return "mean"
[docs]@dataclasses.dataclass() class QRQFunctionFactory(QFunctionFactory): """Quantile Regression Q function factory class. References: * `Dabney et al., Distributional reinforcement learning with quantile regression. <https://arxiv.org/abs/1710.10044>`_ Args: share_encoder (bool): flag to share encoder over multiple Q functions. n_quantiles: the number of quantiles. """ n_quantiles: int = 32
[docs] def create_discrete( self, encoder: Encoder, hidden_size: int, action_size: int ) -> Tuple[DiscreteQRQFunction, DiscreteQRQFunctionForwarder]: q_func = DiscreteQRQFunction( encoder=encoder, hidden_size=hidden_size, action_size=action_size, n_quantiles=self.n_quantiles, ) forwarder = DiscreteQRQFunctionForwarder(q_func, self.n_quantiles) return q_func, forwarder
[docs] def create_continuous( self, encoder: EncoderWithAction, hidden_size: int, ) -> Tuple[ContinuousQRQFunction, ContinuousQRQFunctionForwarder]: q_func = ContinuousQRQFunction( encoder=encoder, hidden_size=hidden_size, n_quantiles=self.n_quantiles, ) forwarder = ContinuousQRQFunctionForwarder(q_func, self.n_quantiles) return q_func, forwarder
[docs] @staticmethod def get_type() -> str: return "qr"
[docs]@dataclasses.dataclass() class IQNQFunctionFactory(QFunctionFactory): """Implicit Quantile Network Q function factory class. References: * `Dabney et al., Implicit quantile networks for distributional reinforcement learning. <https://arxiv.org/abs/1806.06923>`_ Args: share_encoder (bool): flag to share encoder over multiple Q functions. n_quantiles: the number of quantiles. n_greedy_quantiles: the number of quantiles for inference. embed_size: the embedding size. """ n_quantiles: int = 64 n_greedy_quantiles: int = 32 embed_size: int = 64
[docs] def create_discrete( self, encoder: Encoder, hidden_size: int, action_size: int, ) -> Tuple[DiscreteIQNQFunction, DiscreteIQNQFunctionForwarder]: q_func = DiscreteIQNQFunction( encoder=encoder, hidden_size=hidden_size, action_size=action_size, n_quantiles=self.n_quantiles, n_greedy_quantiles=self.n_greedy_quantiles, embed_size=self.embed_size, ) forwarder = DiscreteIQNQFunctionForwarder(q_func, self.n_quantiles) return q_func, forwarder
[docs] def create_continuous( self, encoder: EncoderWithAction, hidden_size: int ) -> Tuple[ContinuousIQNQFunction, ContinuousIQNQFunctionForwarder]: q_func = ContinuousIQNQFunction( encoder=encoder, hidden_size=hidden_size, n_quantiles=self.n_quantiles, n_greedy_quantiles=self.n_greedy_quantiles, embed_size=self.embed_size, ) forwarder = ContinuousIQNQFunctionForwarder( q_func, self.n_greedy_quantiles ) return q_func, forwarder
[docs] @staticmethod def get_type() -> str: return "iqn"
register_q_func_factory, make_q_func_field = generate_config_registration( QFunctionFactory, lambda: MeanQFunctionFactory() ) register_q_func_factory(MeanQFunctionFactory) register_q_func_factory(QRQFunctionFactory) register_q_func_factory(IQNQFunctionFactory)