# Source code for d3rlpy.q_functions

from abc import ABCMeta, abstractmethod
from d3rlpy.models.torch.q_functions import DiscreteMeanQFunction
from d3rlpy.models.torch.q_functions import DiscreteQRQFunction
from d3rlpy.models.torch.q_functions import DiscreteIQNQFunction
from d3rlpy.models.torch.q_functions import DiscreteFQFQFunction
from d3rlpy.models.torch.q_functions import ContinuousMeanQFunction
from d3rlpy.models.torch.q_functions import ContinuousQRQFunction
from d3rlpy.models.torch.q_functions import ContinuousIQNQFunction
from d3rlpy.models.torch.q_functions import ContinuousFQFQFunction


class QFunctionFactory(metaclass=ABCMeta):
    """ Abstract base class of Q function factories.

    Subclasses set ``TYPE`` to their own identifier and implement
    :meth:`create` and :meth:`get_params`.

    Attributes:
        TYPE (str): identifier reported by :meth:`get_type`.

    """

    TYPE = 'none'

    def get_type(self):
        """ Returns the identifier of this Q function type.

        Returns:
            str: value of the ``TYPE`` attribute.

        """
        return self.TYPE

    @abstractmethod
    def get_params(self, deep=False):
        """ Returns the parameters of this Q function factory.

        Args:
            deep (bool): not used by the base class.
                NOTE(review): presumably mirrors scikit-learn's
                ``get_params(deep=...)`` convention -- confirm.

        Returns:
            dict: Q function parameters.

        """

    @abstractmethod
    def create(self, encoder, action_size=None):
        """ Builds the PyTorch Q function module.

        Args:
            encoder (torch.nn.Module): an encoder module that processes
                the observation (and the action in continuous action-space)
                to obtain feature representations.
            action_size (int): dimension of the discrete action-space.
                ``None`` is passed when the action-space is continuous.

        Returns:
            torch.nn.Module: Q function object.

        """


class MeanQFunctionFactory(QFunctionFactory):
    """ Standard Q function factory class.

    This is the standard Q function factory class.

    References:
        * `Mnih et al., Human-level control through deep reinforcement
          learning. <https://www.nature.com/articles/nature14236>`_
        * `Lillicrap et al., Continuous control with deep reinforcement
          learning. <https://arxiv.org/abs/1509.02971>`_

    """

    TYPE = 'mean'

    def __init__(self):
        # This factory has no configurable parameters.
        pass

    def create(self, encoder, action_size=None):
        """ Returns the mean Q function module built on ``encoder``.

        Args:
            encoder (torch.nn.Module): encoder module handed to the
                Q function.
            action_size (int): dimension of the discrete action-space, or
                ``None`` to build the continuous variant.

        Returns:
            ``ContinuousMeanQFunction`` when ``action_size`` is ``None``,
            otherwise ``DiscreteMeanQFunction``.

        """
        if action_size is None:
            q_func = ContinuousMeanQFunction(encoder)
        else:
            q_func = DiscreteMeanQFunction(encoder, action_size)
        return q_func

    def get_params(self, deep=False):
        """ Returns Q function parameters.

        Args:
            deep (bool): accepted for interface compatibility; unused.

        Returns:
            dict: always empty, since this factory has no parameters.

        """
        return {}
class QRQFunctionFactory(QFunctionFactory):
    """ Quantile Regression Q function factory class.

    References:
        * `Dabney et al., Distributional reinforcement learning with
          quantile regression. <https://arxiv.org/abs/1710.10044>`_

    Args:
        n_quantiles (int): the number of quantiles.

    Attributes:
        n_quantiles (int): the number of quantiles.

    """

    TYPE = 'qr'

    def __init__(self, n_quantiles=200):
        self.n_quantiles = n_quantiles

    def create(self, encoder, action_size=None):
        """ Returns the quantile regression Q function module.

        Args:
            encoder (torch.nn.Module): encoder module handed to the
                Q function.
            action_size (int): dimension of the discrete action-space, or
                ``None`` to build the continuous variant.

        Returns:
            ``ContinuousQRQFunction`` when ``action_size`` is ``None``,
            otherwise ``DiscreteQRQFunction``; both receive
            ``self.n_quantiles``.

        """
        if action_size is None:
            q_func = ContinuousQRQFunction(encoder, self.n_quantiles)
        else:
            q_func = DiscreteQRQFunction(encoder, action_size,
                                         self.n_quantiles)
        return q_func

    def get_params(self, deep=False):
        """ Returns Q function parameters.

        Args:
            deep (bool): accepted for interface compatibility; unused.

        Returns:
            dict: ``{'n_quantiles': ...}``.

        """
        return {'n_quantiles': self.n_quantiles}
class IQNQFunctionFactory(QFunctionFactory):
    """ Implicit Quantile Network Q function factory class.

    References:
        * `Dabney et al., Implicit quantile networks for distributional
          reinforcement learning. <https://arxiv.org/abs/1806.06923>`_

    Args:
        n_quantiles (int): the number of quantiles.
        embed_size (int): the embedding size.

    Attributes:
        n_quantiles (int): the number of quantiles.
        embed_size (int): the embedding size.

    """

    TYPE = 'iqn'

    def __init__(self, n_quantiles=32, embed_size=64):
        self.n_quantiles = n_quantiles
        self.embed_size = embed_size

    def create(self, encoder, action_size=None):
        """ Returns the implicit quantile network Q function module.

        Args:
            encoder (torch.nn.Module): encoder module handed to the
                Q function.
            action_size (int): dimension of the discrete action-space, or
                ``None`` to build the continuous variant.

        Returns:
            ``ContinuousIQNQFunction`` when ``action_size`` is ``None``,
            otherwise ``DiscreteIQNQFunction``; both receive
            ``self.n_quantiles`` and ``self.embed_size``.

        """
        if action_size is None:
            q_func = ContinuousIQNQFunction(encoder, self.n_quantiles,
                                            self.embed_size)
        else:
            q_func = DiscreteIQNQFunction(encoder, action_size,
                                          self.n_quantiles, self.embed_size)
        return q_func

    def get_params(self, deep=False):
        """ Returns Q function parameters.

        Args:
            deep (bool): accepted for interface compatibility; unused.

        Returns:
            dict: ``n_quantiles`` and ``embed_size`` of this factory.

        """
        return {
            'n_quantiles': self.n_quantiles,
            'embed_size': self.embed_size
        }
class FQFQFunctionFactory(QFunctionFactory):
    """ Fully parameterized Quantile Function Q function factory.

    References:
        * `Yang et al., Fully parameterized quantile function for
          distributional reinforcement learning.
          <https://arxiv.org/abs/1911.02140>`_

    Args:
        n_quantiles (int): the number of quantiles.
        embed_size (int): the embedding size.
        entropy_coeff (float): the coefficient of the entropy penalty term.

    Attributes:
        n_quantiles (int): the number of quantiles.
        embed_size (int): the embedding size.
        entropy_coeff (float): the coefficient of the entropy penalty term.

    """

    TYPE = 'fqf'

    def __init__(self, n_quantiles=32, embed_size=64, entropy_coeff=0.0):
        self.n_quantiles = n_quantiles
        self.embed_size = embed_size
        self.entropy_coeff = entropy_coeff

    def create(self, encoder, action_size=None):
        """ Returns the fully parameterized quantile Q function module.

        Args:
            encoder (torch.nn.Module): encoder module handed to the
                Q function.
            action_size (int): dimension of the discrete action-space, or
                ``None`` to build the continuous variant.

        Returns:
            ``ContinuousFQFQFunction`` when ``action_size`` is ``None``,
            otherwise ``DiscreteFQFQFunction``; both receive this factory's
            ``n_quantiles``, ``embed_size`` and ``entropy_coeff`` as
            keyword arguments.

        """
        if action_size is None:
            q_func = ContinuousFQFQFunction(encoder=encoder,
                                            n_quantiles=self.n_quantiles,
                                            embed_size=self.embed_size,
                                            entropy_coeff=self.entropy_coeff)
        else:
            q_func = DiscreteFQFQFunction(encoder=encoder,
                                          action_size=action_size,
                                          n_quantiles=self.n_quantiles,
                                          embed_size=self.embed_size,
                                          entropy_coeff=self.entropy_coeff)
        return q_func

    def get_params(self, deep=False):
        """ Returns Q function parameters.

        Args:
            deep (bool): accepted for interface compatibility; unused.

        Returns:
            dict: ``n_quantiles``, ``embed_size`` and ``entropy_coeff`` of
            this factory.

        """
        return {
            'n_quantiles': self.n_quantiles,
            'embed_size': self.embed_size,
            'entropy_coeff': self.entropy_coeff
        }
# Registry mapping each factory's ``TYPE`` string to its factory class.
Q_FUNC_LIST = {}


def register_q_func_factory(cls):
    """ Registers Q function factory class.

    The class is stored in ``Q_FUNC_LIST`` under its ``TYPE`` attribute.

    Args:
        cls (type): Q function factory class inheriting
            ``QFunctionFactory``.

    Raises:
        AssertionError: if ``cls.TYPE`` is already registered.

    """
    # NOTE: ``assert`` is stripped under ``python -O``; in that mode a
    # duplicate registration silently replaces the previous entry.
    is_registered = cls.TYPE in Q_FUNC_LIST
    assert not is_registered, '%s seems to be already registered' % cls.TYPE
    Q_FUNC_LIST[cls.TYPE] = cls


def create_q_func_factory(name, **kwargs):
    """ Returns registered Q function factory object.

    Args:
        name (str): registered Q function factory type name.
        kwargs (any): Q function arguments, forwarded to the factory
            constructor.

    Returns:
        d3rlpy.q_functions.QFunctionFactory: Q function factory object.

    Raises:
        AssertionError: if ``name`` is not registered, or the registered
            class does not produce a ``QFunctionFactory`` instance.

    """
    assert name in Q_FUNC_LIST, '%s seems not to be registered.' % name
    factory = Q_FUNC_LIST[name](**kwargs)
    assert isinstance(factory, QFunctionFactory)
    return factory


# Register the factories defined in this module.
register_q_func_factory(MeanQFunctionFactory)
register_q_func_factory(QRQFunctionFactory)
register_q_func_factory(IQNQFunctionFactory)
register_q_func_factory(FQFQFunctionFactory)