Source code for d3rlpy.models.encoders

import copy
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, Type
from typing import Sequence

import torch

from .torch import Encoder, EncoderWithAction
from .torch import PixelEncoder
from .torch import PixelEncoderWithAction
from .torch import VectorEncoder
from .torch import VectorEncoderWithAction


def _create_activation(
    activation_type: str,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Returns the activation function registered under ``activation_type``.

    Args:
        activation_type: one of ``"relu"``, ``"tanh"`` or ``"swish"``.

    Returns:
        a callable mapping a tensor to its element-wise activation.

    Raises:
        ValueError: if ``activation_type`` is not one of the names above.

    """
    if activation_type == "relu":
        return torch.relu  # type: ignore
    if activation_type == "tanh":
        return torch.tanh  # type: ignore
    if activation_type == "swish":
        # swish(x) = x * sigmoid(x); torch has no single builtin used here.
        return lambda x: x * torch.sigmoid(x)
    # Include the offending value so misconfigured factories are easy to spot.
    raise ValueError(
        f"invalid activation_type: {activation_type!r}. "
        "Expected one of 'relu', 'tanh' or 'swish'."
    )


class EncoderFactory:
    """Base class of encoder factories.

    Subclasses set ``TYPE`` to a name used as their registry key and
    override ``create``, ``create_with_action`` and ``get_params``; the base
    implementations of those three methods raise ``NotImplementedError``.

    """

    TYPE: ClassVar[str] = "none"

    def create(self, observation_shape: Sequence[int]) -> Encoder:
        """Returns PyTorch's state encoder module.

        Args:
            observation_shape: observation shape.

        Returns:
            an encoder object.

        Raises:
            NotImplementedError: always; subclasses must override this.

        """
        raise NotImplementedError

    def create_with_action(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        discrete_action: bool = False,
    ) -> EncoderWithAction:
        """Returns PyTorch's state-action encoder module.

        Args:
            observation_shape: observation shape.
            action_size: action size.
            discrete_action: flag if action-space is discrete.

        Returns:
            an encoder object.

        Raises:
            NotImplementedError: always; subclasses must override this.

        """
        raise NotImplementedError

    def get_type(self) -> str:
        """Returns encoder type.

        Returns:
            encoder type, i.e. the class-level ``TYPE`` value.

        """
        return self.TYPE

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Returns encoder parameters.

        Args:
            deep: flag to deeply copy the parameters.

        Returns:
            encoder parameters.

        Raises:
            NotImplementedError: always; subclasses must override this.

        """
        raise NotImplementedError


class PixelEncoderFactory(EncoderFactory):
    """Pixel encoder factory class.

    This is the default encoder factory for image observation.

    Args:
        filters (list): list of tuples consisting with
            ``(filter_size, kernel_size, stride)``. If None,
            ``Nature DQN``-based architecture is used.
        feature_size (int): the last linear layer size.
        activation (str): activation function name.
        use_batch_norm (bool): flag to insert batch normalization layers.

    """

    TYPE: ClassVar[str] = "pixel"
    _filters: List[Sequence[int]]
    _feature_size: int
    _activation: str
    _use_batch_norm: bool

    def __init__(
        self,
        filters: Optional[List[Sequence[int]]] = None,
        feature_size: int = 512,
        activation: str = "relu",
        use_batch_norm: bool = False,
    ):
        # A fresh default list per instance, so factories never share it.
        self._filters = (
            [(32, 8, 4), (64, 4, 2), (64, 3, 1)] if filters is None else filters
        )
        self._feature_size = feature_size
        self._activation = activation
        self._use_batch_norm = use_batch_norm

    def create(self, observation_shape: Sequence[int]) -> PixelEncoder:
        """Builds a ``PixelEncoder`` for a 3-dimensional observation shape."""
        assert len(observation_shape) == 3
        activation_fn = _create_activation(self._activation)
        return PixelEncoder(
            observation_shape=observation_shape,
            filters=self._filters,
            feature_size=self._feature_size,
            use_batch_norm=self._use_batch_norm,
            activation=activation_fn,
        )

    def create_with_action(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        discrete_action: bool = False,
    ) -> PixelEncoderWithAction:
        """Builds a ``PixelEncoderWithAction`` for a 3-dimensional shape."""
        assert len(observation_shape) == 3
        activation_fn = _create_activation(self._activation)
        return PixelEncoderWithAction(
            observation_shape=observation_shape,
            action_size=action_size,
            filters=self._filters,
            feature_size=self._feature_size,
            use_batch_norm=self._use_batch_norm,
            discrete_action=discrete_action,
            activation=activation_fn,
        )

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Returns the constructor arguments; ``filters`` is deep-copied
        when ``deep`` is true."""
        filter_spec = copy.deepcopy(self._filters) if deep else self._filters
        return {
            "filters": filter_spec,
            "feature_size": self._feature_size,
            "activation": self._activation,
            "use_batch_norm": self._use_batch_norm,
        }
class VectorEncoderFactory(EncoderFactory):
    """Vector encoder factory class.

    This is the default encoder factory for vector observation.

    Args:
        hidden_units (list): list of hidden unit sizes. If ``None``, the
            standard architecture with ``[256, 256]`` is used.
        activation (str): activation function name.
        use_batch_norm (bool): flag to insert batch normalization layers.
        use_dense (bool): flag to use DenseNet architecture.

    """

    TYPE: ClassVar[str] = "vector"
    _hidden_units: Sequence[int]
    _activation: str
    _use_batch_norm: bool
    _use_dense: bool

    def __init__(
        self,
        hidden_units: Optional[Sequence[int]] = None,
        activation: str = "relu",
        use_batch_norm: bool = False,
        use_dense: bool = False,
    ):
        # A fresh default list per instance, so factories never share it.
        self._hidden_units = (
            [256, 256] if hidden_units is None else hidden_units
        )
        self._activation = activation
        self._use_batch_norm = use_batch_norm
        self._use_dense = use_dense

    def create(self, observation_shape: Sequence[int]) -> VectorEncoder:
        """Builds a ``VectorEncoder`` for a 1-dimensional observation shape."""
        assert len(observation_shape) == 1
        activation_fn = _create_activation(self._activation)
        return VectorEncoder(
            observation_shape=observation_shape,
            hidden_units=self._hidden_units,
            use_batch_norm=self._use_batch_norm,
            use_dense=self._use_dense,
            activation=activation_fn,
        )

    def create_with_action(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        discrete_action: bool = False,
    ) -> VectorEncoderWithAction:
        """Builds a ``VectorEncoderWithAction`` for a 1-dimensional shape."""
        assert len(observation_shape) == 1
        activation_fn = _create_activation(self._activation)
        return VectorEncoderWithAction(
            observation_shape=observation_shape,
            action_size=action_size,
            hidden_units=self._hidden_units,
            use_batch_norm=self._use_batch_norm,
            use_dense=self._use_dense,
            discrete_action=discrete_action,
            activation=activation_fn,
        )

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Returns the constructor arguments; ``hidden_units`` is
        deep-copied when ``deep`` is true."""
        units = copy.deepcopy(self._hidden_units) if deep else self._hidden_units
        return {
            "hidden_units": units,
            "activation": self._activation,
            "use_batch_norm": self._use_batch_norm,
            "use_dense": self._use_dense,
        }
class DefaultEncoderFactory(EncoderFactory):
    """Default encoder factory class.

    This encoder factory returns an encoder based on observation shape:
    3-dimensional observation shapes are delegated to
    ``PixelEncoderFactory`` and all other shapes to ``VectorEncoderFactory``
    (which itself only accepts 1-dimensional shapes).

    Args:
        activation (str): activation function name.
        use_batch_norm (bool): flag to insert batch normalization layers.

    """

    TYPE: ClassVar[str] = "default"
    _activation: str
    _use_batch_norm: bool

    def __init__(self, activation: str = "relu", use_batch_norm: bool = False):
        self._activation = activation
        self._use_batch_norm = use_batch_norm

    def _create_factory(
        self, observation_shape: Sequence[int]
    ) -> Union[PixelEncoderFactory, VectorEncoderFactory]:
        """Selects the concrete factory for ``observation_shape``."""
        if len(observation_shape) == 3:
            return PixelEncoderFactory(
                activation=self._activation, use_batch_norm=self._use_batch_norm
            )
        return VectorEncoderFactory(
            activation=self._activation, use_batch_norm=self._use_batch_norm
        )

    def create(self, observation_shape: Sequence[int]) -> Encoder:
        """Returns a pixel or vector state encoder for ``observation_shape``.

        Args:
            observation_shape: observation shape.

        Returns:
            an encoder object.

        """
        return self._create_factory(observation_shape).create(observation_shape)

    def create_with_action(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        discrete_action: bool = False,
    ) -> EncoderWithAction:
        """Returns a pixel or vector state-action encoder.

        Args:
            observation_shape: observation shape.
            action_size: action size.
            discrete_action: flag if action-space is discrete.

        Returns:
            an encoder object.

        """
        factory = self._create_factory(observation_shape)
        return factory.create_with_action(
            observation_shape, action_size, discrete_action
        )

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Returns the constructor arguments.

        ``deep`` has no effect because every stored parameter is a scalar.

        """
        return {
            "activation": self._activation,
            "use_batch_norm": self._use_batch_norm,
        }
class DenseEncoderFactory(EncoderFactory):
    """DenseNet encoder factory class.

    This is an alias for DenseNet architecture proposed in D2RL.
    This class does exactly the same as the following, with ``activation``
    and ``use_batch_norm`` forwarded unchanged.

    .. code-block:: python

        from d3rlpy.models.encoders import VectorEncoderFactory

        factory = VectorEncoderFactory(hidden_units=[256, 256, 256, 256],
                                       use_dense=True)

    For now, this only supports vector observations.

    References:
        * `Sinha et al., D2RL: Deep Dense Architectures in Reinforcement
          Learning. <https://arxiv.org/abs/2010.09163>`_

    Args:
        activation (str): activation function name.
        use_batch_norm (bool): flag to insert batch normalization layers.

    """

    TYPE: ClassVar[str] = "dense"
    _activation: str
    _use_batch_norm: bool

    def __init__(self, activation: str = "relu", use_batch_norm: bool = False):
        self._activation = activation
        self._use_batch_norm = use_batch_norm

    def _create_factory(
        self, observation_shape: Sequence[int]
    ) -> VectorEncoderFactory:
        """Rejects pixel shapes, then builds the 4x256 dense vector factory."""
        if len(observation_shape) == 3:
            raise NotImplementedError("pixel observation is not supported.")
        return VectorEncoderFactory(
            hidden_units=[256, 256, 256, 256],
            activation=self._activation,
            use_dense=True,
            use_batch_norm=self._use_batch_norm,
        )

    def create(self, observation_shape: Sequence[int]) -> VectorEncoder:
        """Returns a dense vector state encoder.

        Args:
            observation_shape: observation shape.

        Returns:
            an encoder object.

        Raises:
            NotImplementedError: if ``observation_shape`` is 3-dimensional.

        """
        return self._create_factory(observation_shape).create(observation_shape)

    def create_with_action(
        self,
        observation_shape: Sequence[int],
        action_size: int,
        discrete_action: bool = False,
    ) -> VectorEncoderWithAction:
        """Returns a dense vector state-action encoder.

        Args:
            observation_shape: observation shape.
            action_size: action size.
            discrete_action: flag if action-space is discrete.

        Returns:
            an encoder object.

        Raises:
            NotImplementedError: if ``observation_shape`` is 3-dimensional.

        """
        factory = self._create_factory(observation_shape)
        return factory.create_with_action(
            observation_shape, action_size, discrete_action
        )

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Returns the constructor arguments.

        ``deep`` has no effect because every stored parameter is a scalar.

        """
        return {
            "activation": self._activation,
            "use_batch_norm": self._use_batch_norm,
        }
# Registry mapping each factory's ``TYPE`` string to its class.
ENCODER_LIST: Dict[str, Type[EncoderFactory]] = {}


def register_encoder_factory(cls: Type[EncoderFactory]) -> None:
    """Registers encoder factory class.

    Args:
        cls: encoder factory class inheriting ``EncoderFactory``.

    Raises:
        AssertionError: if a factory with the same ``TYPE`` is already
            registered (the check is an ``assert``, so it is skipped when
            Python runs with ``-O``).

    """
    is_registered = cls.TYPE in ENCODER_LIST
    assert not is_registered, "%s seems to be already registered" % cls.TYPE
    ENCODER_LIST[cls.TYPE] = cls


def create_encoder_factory(name: str, **kwargs: Any) -> EncoderFactory:
    """Returns registered encoder factory object.

    Args:
        name: registered encoder factory type name.
        kwargs: encoder arguments passed to the factory constructor.

    Returns:
        encoder factory object.

    Raises:
        AssertionError: if ``name`` is not registered (the check is an
            ``assert``, so it is skipped when Python runs with ``-O``).

    """
    assert name in ENCODER_LIST, "%s seems not to be registered." % name
    factory = ENCODER_LIST[name](**kwargs)  # type: ignore
    assert isinstance(factory, EncoderFactory)
    return factory


register_encoder_factory(VectorEncoderFactory)
register_encoder_factory(PixelEncoderFactory)
register_encoder_factory(DefaultEncoderFactory)
register_encoder_factory(DenseEncoderFactory)