Source code for d3rlpy.models.encoders

from dataclasses import dataclass, field
from typing import List, Optional, Union

from ..dataset import cast_flat_shape
from ..serializable_config import DynamicConfig, generate_config_registration
from ..types import Shape
from .torch import (
    Encoder,
    EncoderWithAction,
    PixelEncoder,
    PixelEncoderWithAction,
    VectorEncoder,
    VectorEncoderWithAction,
)
from .utility import create_activation

__all__ = [
    "EncoderFactory",
    "PixelEncoderFactory",
    "VectorEncoderFactory",
    "DefaultEncoderFactory",
    "register_encoder_factory",
    "make_encoder_field",
]


class EncoderFactory(DynamicConfig):
    def create(self, observation_shape: Shape) -> Encoder:
        """Returns PyTorch's state enocder module.

        Args:
            observation_shape: observation shape.

        Returns:
            an encoder object.
        """
        raise NotImplementedError

    def create_with_action(
        self,
        observation_shape: Shape,
        action_size: int,
        discrete_action: bool = False,
    ) -> EncoderWithAction:
        """Returns PyTorch's state-action enocder module.

        Args:
            observation_shape: observation shape.
            action_size: action size.
            discrete_action: flag if action-space is discrete.

        Returns:
            an encoder object.
        """
        raise NotImplementedError

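# Illustrative sketch (not part of this module): a custom architecture can be
# plugged in by subclassing EncoderFactory. ``CustomEncoder``, its
# ``feature_size``, and the ``"custom"`` type name are hypothetical examples.
#
#     import dataclasses
#
#     import torch
#     import torch.nn as nn
#
#     class CustomEncoder(nn.Module):
#         def __init__(self, observation_shape, feature_size):
#             super().__init__()
#             self.fc1 = nn.Linear(observation_shape[0], 64)
#             self.fc2 = nn.Linear(64, feature_size)
#
#         def forward(self, x):
#             h = torch.relu(self.fc1(x))
#             return torch.relu(self.fc2(h))
#
#     @dataclasses.dataclass()
#     class CustomEncoderFactory(EncoderFactory):
#         feature_size: int = 64
#
#         def create(self, observation_shape):
#             return CustomEncoder(observation_shape, self.feature_size)
#
#         @staticmethod
#         def get_type() -> str:
#             return "custom"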


@dataclass()
class PixelEncoderFactory(EncoderFactory):
    """Pixel encoder factory class.

    This is the default encoder factory for image observation.

    Args:
        filters (list): List of ``(filter_size, kernel_size, stride)`` tuples.
            Defaults to the ``Nature DQN``-based architecture.
        feature_size (int): Last linear layer size.
        activation (str): Activation function name.
        use_batch_norm (bool): Flag to insert batch normalization layers.
        dropout_rate (float): Dropout probability.
        exclude_last_activation (bool): Flag to exclude activation function at
            the last layer.
        last_activation (str): Activation function name for the last layer.
    """

    filters: List[List[int]] = field(
        default_factory=lambda: [[32, 8, 4], [64, 4, 2], [64, 3, 1]]
    )
    feature_size: int = 512
    activation: str = "relu"
    use_batch_norm: bool = False
    dropout_rate: Optional[float] = None
    exclude_last_activation: bool = False
    last_activation: Optional[str] = None

    def create(self, observation_shape: Shape) -> PixelEncoder:
        assert len(observation_shape) == 3
        return PixelEncoder(
            observation_shape=cast_flat_shape(observation_shape),
            filters=self.filters,
            feature_size=self.feature_size,
            use_batch_norm=self.use_batch_norm,
            dropout_rate=self.dropout_rate,
            activation=create_activation(self.activation),
            exclude_last_activation=self.exclude_last_activation,
            last_activation=(
                create_activation(self.last_activation)
                if self.last_activation
                else None
            ),
        )

    def create_with_action(
        self,
        observation_shape: Shape,
        action_size: int,
        discrete_action: bool = False,
    ) -> PixelEncoderWithAction:
        assert len(observation_shape) == 3
        return PixelEncoderWithAction(
            observation_shape=cast_flat_shape(observation_shape),
            action_size=action_size,
            filters=self.filters,
            feature_size=self.feature_size,
            use_batch_norm=self.use_batch_norm,
            dropout_rate=self.dropout_rate,
            discrete_action=discrete_action,
            activation=create_activation(self.activation),
            exclude_last_activation=self.exclude_last_activation,
            last_activation=(
                create_activation(self.last_activation)
                if self.last_activation
                else None
            ),
        )

    @staticmethod
    def get_type() -> str:
        return "pixel"
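
# Illustrative usage sketch (not part of this module); the shapes and sizes
# below are arbitrary examples:
#
#     factory = PixelEncoderFactory(feature_size=512)
#     # Image observations use a (channels, height, width) shape.
#     encoder = factory.create((4, 84, 84))
#     encoder_with_action = factory.create_with_action((4, 84, 84), action_size=6)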


@dataclass()
class VectorEncoderFactory(EncoderFactory):
    """Vector encoder factory class.

    This is the default encoder factory for vector observation.

    Args:
        hidden_units (list): List of hidden unit sizes. Defaults to the
            standard architecture ``[256, 256]``.
        activation (str): Activation function name.
        use_batch_norm (bool): Flag to insert batch normalization layers.
        dropout_rate (float): Dropout probability.
        exclude_last_activation (bool): Flag to exclude activation function at
            the last layer.
        last_activation (str): Activation function name for the last layer.
    """

    hidden_units: List[int] = field(default_factory=lambda: [256, 256])
    activation: str = "relu"
    use_batch_norm: bool = False
    dropout_rate: Optional[float] = None
    exclude_last_activation: bool = False
    last_activation: Optional[str] = None

    def create(self, observation_shape: Shape) -> VectorEncoder:
        assert len(observation_shape) == 1
        return VectorEncoder(
            observation_shape=cast_flat_shape(observation_shape),
            hidden_units=self.hidden_units,
            use_batch_norm=self.use_batch_norm,
            dropout_rate=self.dropout_rate,
            activation=create_activation(self.activation),
            exclude_last_activation=self.exclude_last_activation,
            last_activation=(
                create_activation(self.last_activation)
                if self.last_activation
                else None
            ),
        )

    def create_with_action(
        self,
        observation_shape: Shape,
        action_size: int,
        discrete_action: bool = False,
    ) -> VectorEncoderWithAction:
        assert len(observation_shape) == 1
        return VectorEncoderWithAction(
            observation_shape=cast_flat_shape(observation_shape),
            action_size=action_size,
            hidden_units=self.hidden_units,
            use_batch_norm=self.use_batch_norm,
            dropout_rate=self.dropout_rate,
            discrete_action=discrete_action,
            activation=create_activation(self.activation),
            exclude_last_activation=self.exclude_last_activation,
            last_activation=(
                create_activation(self.last_activation)
                if self.last_activation
                else None
            ),
        )

    @staticmethod
    def get_type() -> str:
        return "vector"
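
# Illustrative usage sketch (not part of this module); the shapes and sizes
# below are arbitrary examples:
#
#     factory = VectorEncoderFactory(hidden_units=[256, 256])
#     # Vector observations use a flat (obs_dim,) shape.
#     encoder = factory.create((17,))
#     encoder_with_action = factory.create_with_action((17,), action_size=6)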


@dataclass()
class DefaultEncoderFactory(EncoderFactory):
    """Default encoder factory class.

    This encoder factory returns an encoder based on observation shape.

    Args:
        activation (str): Activation function name.
        use_batch_norm (bool): Flag to insert batch normalization layers.
        dropout_rate (float): Dropout probability.
    """

    activation: str = "relu"
    use_batch_norm: bool = False
    dropout_rate: Optional[float] = None

    def create(self, observation_shape: Shape) -> Encoder:
        factory: Union[PixelEncoderFactory, VectorEncoderFactory]
        if len(observation_shape) == 3:
            factory = PixelEncoderFactory(
                activation=self.activation,
                use_batch_norm=self.use_batch_norm,
                dropout_rate=self.dropout_rate,
            )
        else:
            factory = VectorEncoderFactory(
                activation=self.activation,
                use_batch_norm=self.use_batch_norm,
                dropout_rate=self.dropout_rate,
            )
        return factory.create(observation_shape)

    def create_with_action(
        self,
        observation_shape: Shape,
        action_size: int,
        discrete_action: bool = False,
    ) -> EncoderWithAction:
        factory: Union[PixelEncoderFactory, VectorEncoderFactory]
        if len(observation_shape) == 3:
            factory = PixelEncoderFactory(
                activation=self.activation,
                use_batch_norm=self.use_batch_norm,
                dropout_rate=self.dropout_rate,
            )
        else:
            factory = VectorEncoderFactory(
                activation=self.activation,
                use_batch_norm=self.use_batch_norm,
                dropout_rate=self.dropout_rate,
            )
        return factory.create_with_action(
            observation_shape, action_size, discrete_action
        )

    @staticmethod
    def get_type() -> str:
        return "default"
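
# Illustrative sketch of the shape-based dispatch (not part of this module);
# the shapes are arbitrary examples:
#
#     factory = DefaultEncoderFactory()
#     pixel_encoder = factory.create((4, 84, 84))   # 3-dim shape -> PixelEncoder
#     vector_encoder = factory.create((17,))        # otherwise -> VectorEncoder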


register_encoder_factory, make_encoder_field = generate_config_registration(
    EncoderFactory, lambda: DefaultEncoderFactory()
)

register_encoder_factory(VectorEncoderFactory)
register_encoder_factory(PixelEncoderFactory)
register_encoder_factory(DefaultEncoderFactory)
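
# Illustrative sketch (not part of this module): how the generated helpers are
# typically consumed. ``CustomEncoderFactory`` refers to the hypothetical
# sketch above, and ``MyConfig`` is a hypothetical dataclass; this assumes
# ``make_encoder_field()`` returns a dataclass field whose default is
# ``DefaultEncoderFactory``.
#
#     register_encoder_factory(CustomEncoderFactory)
#
#     @dataclass()
#     class MyConfig:
#         encoder_factory: EncoderFactory = make_encoder_field()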