Source code for d3rlpy.encoders

import copy
import torch

from abc import ABCMeta, abstractmethod
from d3rlpy.models.torch.encoders import PixelEncoder
from d3rlpy.models.torch.encoders import PixelEncoderWithAction
from d3rlpy.models.torch.encoders import VectorEncoder
from d3rlpy.models.torch.encoders import VectorEncoderWithAction


def _create_activation(activation_type):
    """ Returns an activation function for the given name.

    Args:
        activation_type (str): activation function name. One of ``'relu'``,
            ``'tanh'``, ``'elu'`` or ``'swish'``.

    Returns:
        callable: function applied to a tensor to produce the activated
            tensor.

    Raises:
        ValueError: if ``activation_type`` is not one of the supported names.

    """
    if activation_type == 'relu':
        return torch.relu
    if activation_type == 'tanh':
        return torch.tanh
    if activation_type == 'elu':
        # ELU is documented under ``torch.nn.functional``; the functional
        # form is used instead of relying on a top-level ``torch.elu`` name.
        return torch.nn.functional.elu
    if activation_type == 'swish':
        # swish(x) = x * sigmoid(x)
        return lambda x: x * torch.sigmoid(x)
    raise ValueError('invalid activation_type.')


class EncoderFactory(metaclass=ABCMeta):
    """ Abstract base class of encoder factories.

    Subclasses override ``TYPE`` with their own type name and implement
    ``create`` and ``get_params``.

    """

    TYPE = 'none'

    @abstractmethod
    def create(
            self,
            observation_shape,
            action_size=None,
            discrete_action=False):
        """ Builds and returns a PyTorch encoder module.

        Args:
            observation_shape (tuple): observation shape.
            action_size (int): action size. When ``None``, the encoder does
                not take action as input.
            discrete_action (bool): flag if action-space is discrete.

        Returns:
            torch.nn.Module: an encoder object.

        """

    def get_type(self):
        """ Returns the type name of this encoder factory.

        Returns:
            str: encoder type.

        """
        return self.TYPE

    @abstractmethod
    def get_params(self, deep=False):
        """ Returns the parameters of this encoder factory.

        Args:
            deep (bool): flag to deeply copy the parameters.

        Returns:
            dict: encoder parameters.

        """


class PixelEncoderFactory(EncoderFactory):
    """ Pixel encoder factory class.

    This is the default encoder factory for image observation.

    Args:
        filters (list): list of tuples consisting with
            ``(filter_size, kernel_size, stride)``. If None,
            ``Nature DQN``-based architecture is used.
        feature_size (int): the last linear layer size.
        activation (str): activation function name.
        use_batch_norm (bool): flag to insert batch normalization layers.

    Attributes:
        filters (list): list of tuples consisting with
            ``(filter_size, kernel_size, stride)``.
        feature_size (int): the last linear layer size.
        activation (str): activation function name.
        use_batch_norm (bool): flag to insert batch normalization layers.

    """

    TYPE = 'pixel'

    def __init__(self,
                 filters=None,
                 feature_size=512,
                 activation='relu',
                 use_batch_norm=False):
        if filters is None:
            # default convolution stack used when no filters are given
            self.filters = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
        else:
            self.filters = filters
        self.feature_size = feature_size
        self.activation = activation
        self.use_batch_norm = use_batch_norm

    def create(self,
               observation_shape,
               action_size=None,
               discrete_action=False):
        """ Returns a pixel encoder module.

        Args:
            observation_shape (tuple): observation shape. Its length must
                be 3.
            action_size (int): action size. If None, the encoder does not
                take action as input.
            discrete_action (bool): flag if action-space is discrete.

        Returns:
            torch.nn.Module: ``PixelEncoderWithAction`` if ``action_size``
                is given, otherwise ``PixelEncoder``.

        """
        assert len(observation_shape) == 3
        activation_fn = _create_activation(self.activation)

        # arguments shared by both encoder variants
        common = {
            'observation_shape': observation_shape,
            'filters': self.filters,
            'feature_size': self.feature_size,
            'use_batch_norm': self.use_batch_norm,
            'activation': activation_fn,
        }
        if action_size is None:
            return PixelEncoder(**common)
        return PixelEncoderWithAction(action_size=action_size,
                                      discrete_action=discrete_action,
                                      **common)

    def get_params(self, deep=False):
        """ Returns encoder parameters.

        Args:
            deep (bool): flag to deeply copy ``filters``.

        Returns:
            dict: encoder parameters.

        """
        filters = copy.deepcopy(self.filters) if deep else self.filters
        return {
            'filters': filters,
            'feature_size': self.feature_size,
            'activation': self.activation,
            'use_batch_norm': self.use_batch_norm,
        }
class VectorEncoderFactory(EncoderFactory):
    """ Vector encoder factory class.

    This is the default encoder factory for vector observation.

    Args:
        hidden_units (list): list of hidden unit sizes. If ``None``, the
            standard architecture with ``[256, 256]`` is used.
        activation (str): activation function name.
        use_batch_norm (bool): flag to insert batch normalization layers.
        use_dense (bool): flag to use DenseNet architecture.

    Attributes:
        hidden_units (list): list of hidden unit sizes.
        activation (str): activation function name.
        use_batch_norm (bool): flag to insert batch normalization layers.
        use_dense (bool): flag to use DenseNet architecture.

    """

    TYPE = 'vector'

    def __init__(self,
                 hidden_units=None,
                 activation='relu',
                 use_batch_norm=False,
                 use_dense=False):
        if hidden_units is None:
            # default layer sizes used when no hidden units are given
            self.hidden_units = [256, 256]
        else:
            self.hidden_units = hidden_units
        self.activation = activation
        self.use_batch_norm = use_batch_norm
        self.use_dense = use_dense

    def create(self,
               observation_shape,
               action_size=None,
               discrete_action=False):
        """ Returns a vector encoder module.

        Args:
            observation_shape (tuple): observation shape. Its length must
                be 1.
            action_size (int): action size. If None, the encoder does not
                take action as input.
            discrete_action (bool): flag if action-space is discrete.

        Returns:
            torch.nn.Module: ``VectorEncoderWithAction`` if ``action_size``
                is given, otherwise ``VectorEncoder``.

        """
        assert len(observation_shape) == 1
        activation_fn = _create_activation(self.activation)

        # arguments shared by both encoder variants
        common = {
            'observation_shape': observation_shape,
            'hidden_units': self.hidden_units,
            'use_batch_norm': self.use_batch_norm,
            'use_dense': self.use_dense,
            'activation': activation_fn,
        }
        if action_size is None:
            return VectorEncoder(**common)
        return VectorEncoderWithAction(action_size=action_size,
                                       discrete_action=discrete_action,
                                       **common)

    def get_params(self, deep=False):
        """ Returns encoder parameters.

        Args:
            deep (bool): flag to deeply copy ``hidden_units``.

        Returns:
            dict: encoder parameters.

        """
        if deep:
            hidden_units = copy.deepcopy(self.hidden_units)
        else:
            hidden_units = self.hidden_units
        return {
            'hidden_units': hidden_units,
            'activation': self.activation,
            'use_batch_norm': self.use_batch_norm,
            'use_dense': self.use_dense,
        }
class DefaultEncoderFactory(EncoderFactory):
    """ Default encoder factory class.

    This encoder factory returns an encoder based on observation shape.

    Args:
        activation (str): activation function name.
        use_batch_norm (bool): flag to insert batch normalization layers.

    Attributes:
        activation (str): activation function name.
        use_batch_norm (bool): flag to insert batch normalization layers.

    """

    TYPE = 'default'

    def __init__(self, activation='relu', use_batch_norm=False):
        self.activation = activation
        self.use_batch_norm = use_batch_norm

    def create(self,
               observation_shape,
               action_size=None,
               discrete_action=False):
        """ Returns an encoder module chosen from the observation shape.

        A 3-element ``observation_shape`` is delegated to
        ``PixelEncoderFactory``; any other shape is delegated to
        ``VectorEncoderFactory``.

        Args:
            observation_shape (tuple): observation shape.
            action_size (int): action size. If None, the encoder does not
                take action as input.
            discrete_action (bool): flag if action-space is discrete.

        Returns:
            torch.nn.Module: an encoder object.

        """
        if len(observation_shape) == 3:
            factory_cls = PixelEncoderFactory
        else:
            factory_cls = VectorEncoderFactory
        factory = factory_cls(activation=self.activation,
                              use_batch_norm=self.use_batch_norm)
        return factory.create(observation_shape, action_size, discrete_action)

    def get_params(self, deep=False):
        """ Returns encoder parameters.

        Args:
            deep (bool): accepted for interface compatibility; it is not
                used here.

        Returns:
            dict: encoder parameters.

        """
        return {
            'activation': self.activation,
            'use_batch_norm': self.use_batch_norm,
        }
class DenseEncoderFactory(EncoderFactory):
    """ DenseNet encoder factory class.

    This is an alias for DenseNet architecture proposed in D2RL.
    This class does exactly same as follows.

    .. code-block:: python

        from d3rlpy.encoders import VectorEncoderFactory

        factory = VectorEncoderFactory(hidden_units=[256, 256, 256, 256],
                                       use_dense=True)

    For now, this only supports vector observations.

    References:
        * `Sinha et al., D2RL: Deep Dense Architectures in Reinforcement
          Learning. <https://arxiv.org/abs/2010.09163>`_

    Args:
        activation (str): activation function name.
        use_batch_norm (bool): flag to insert batch normalization layers.

    Attributes:
        activation (str): activation function name.
        use_batch_norm (bool): flag to insert batch normalization layers.

    """

    TYPE = 'dense'

    def __init__(self, activation='relu', use_batch_norm=False):
        self.activation = activation
        self.use_batch_norm = use_batch_norm

    def create(self,
               observation_shape,
               action_size=None,
               discrete_action=False):
        """ Returns a dense vector encoder module.

        Args:
            observation_shape (tuple): observation shape.
            action_size (int): action size. If None, the encoder does not
                take action as input.
            discrete_action (bool): flag if action-space is discrete.

        Returns:
            torch.nn.Module: an encoder object created by
                ``VectorEncoderFactory`` with ``use_dense=True``.

        Raises:
            NotImplementedError: if ``observation_shape`` has 3 elements
                (pixel observation).

        """
        if len(observation_shape) == 3:
            raise NotImplementedError('pixel observation is not supported.')

        factory = VectorEncoderFactory(hidden_units=[256, 256, 256, 256],
                                       activation=self.activation,
                                       use_dense=True,
                                       use_batch_norm=self.use_batch_norm)
        return factory.create(observation_shape, action_size, discrete_action)

    def get_params(self, deep=False):
        """ Returns encoder parameters.

        Args:
            deep (bool): accepted for interface compatibility; it is not
                used here.

        Returns:
            dict: encoder parameters.

        """
        return {
            'activation': self.activation,
            'use_batch_norm': self.use_batch_norm,
        }
# registry mapping encoder factory type name -> encoder factory class
ENCODER_LIST = {}


def register_encoder_factory(cls):
    """ Registers encoder factory class.

    The class is stored in ``ENCODER_LIST`` under its ``TYPE`` name.
    Registering a type name twice fails an assertion.

    Args:
        cls (type): encoder factory class inheriting ``EncoderFactory``.

    """
    is_registered = cls.TYPE in ENCODER_LIST
    assert not is_registered, '%s seems to be already registered' % cls.TYPE
    ENCODER_LIST[cls.TYPE] = cls


def create_encoder_factory(name, **kwargs):
    """ Returns registered encoder factory object.

    Args:
        name (str): registered encoder factory type name.
        kwargs (any): encoder arguments passed to the factory constructor.

    Returns:
        d3rlpy.encoders.EncoderFactory: encoder factory object.

    """
    assert name in ENCODER_LIST, '%s seems not to be registered.' % name
    factory = ENCODER_LIST[name](**kwargs)
    assert isinstance(factory, EncoderFactory)
    return factory


# built-in encoder factories
register_encoder_factory(VectorEncoderFactory)
register_encoder_factory(PixelEncoderFactory)
register_encoder_factory(DefaultEncoderFactory)
register_encoder_factory(DenseEncoderFactory)