Source code for d3rlpy.base

import dataclasses
import io
import pickle
from abc import ABCMeta, abstractmethod
from typing import BinaryIO, Generic, Optional, Type, TypeVar, Union

from gym.spaces import Box
from gymnasium.spaces import Box as GymnasiumBox
from typing_extensions import Self

from ._version import __version__
from .constants import IMPL_NOT_INITIALIZED_ERROR, ActionSpace
from .dataset import ReplayBuffer, detect_action_size_from_env
from .logging import LOG, D3RLPyLogger
from .preprocessing import (
    ActionScaler,
    ObservationScaler,
    RewardScaler,
    make_action_scaler_field,
    make_observation_scaler_field,
    make_reward_scaler_field,
)
from .serializable_config import DynamicConfig, generate_config_registration
from .torch_utility import Checkpointer, Modules
from .types import GymEnv, Shape

__all__ = [
    "DeviceArg",
    "ImplBase",
    "save_config",
    "dump_learnable",
    "load_learnable",
    "LearnableBase",
    "LearnableConfig",
    "LearnableConfigWithShape",
    "TImpl_co",
    "TConfig_co",
    "register_learnable",
]


DeviceArg = Optional[Union[bool, int, str]]
TImpl_co = TypeVar("TImpl_co", bound="ImplBase", covariant=True)
TConfig_co = TypeVar("TConfig_co", bound="LearnableConfig", covariant=True)


class ImplBase(metaclass=ABCMeta):
    _observation_shape: Shape
    _action_size: int
    _modules: Modules
    _checkpointer: Checkpointer
    _device: str

    def __init__(
        self,
        observation_shape: Shape,
        action_size: int,
        modules: Modules,
        device: str,
    ):
        self._observation_shape = observation_shape
        self._action_size = action_size
        self._modules = modules
        self._checkpointer = modules.create_checkpointer(device)
        self._device = device

    def save_model(self, f: BinaryIO) -> None:
        self._checkpointer.save(f)

    def load_model(self, f: BinaryIO) -> None:
        self._checkpointer.load(f)

    @property
    def observation_shape(self) -> Shape:
        return self._observation_shape

    @property
    def action_size(self) -> int:
        return self._action_size

    @property
    def device(self) -> str:
        return self._device

    @property
    def modules(self) -> Modules:
        return self._modules

    def wrap_models_by_ddp(self) -> None:
        self._modules = self._modules.wrap_models_by_ddp()

    def unwrap_models_by_ddp(self) -> None:
        self._modules = self._modules.unwrap_models_by_ddp()


@dataclasses.dataclass()
class LearnableConfig(DynamicConfig):
    batch_size: int = 256
    gamma: float = 0.99
    observation_scaler: Optional[ObservationScaler] = (
        make_observation_scaler_field()
    )
    action_scaler: Optional[ActionScaler] = make_action_scaler_field()
    reward_scaler: Optional[RewardScaler] = make_reward_scaler_field()

    def create(
        self, device: DeviceArg = False
    ) -> "LearnableBase[ImplBase, LearnableConfig]":
        r"""Returns algorithm object.

        Args:
            device (Union[int, str, bool]): device option. If the value is
                boolean and True, ``cuda:0`` will be used. If the value is
                integer, ``cuda:<device>`` will be used. If the value is string
                in torch device style, the specified device will be used.

        Returns:
            algorithm object.
        """
        raise NotImplementedError
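
# Usage sketch (illustrative): concrete configs override ``create``. The
# ``DQNConfig`` name below stands in for any registered LearnableConfig
# subclass from d3rlpy.algos and is an assumption, not part of this module.
#
#   from d3rlpy.algos import DQNConfig
#
#   dqn = DQNConfig(batch_size=32, gamma=0.99).create(device="cuda:0")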


register_learnable, make_learnable_field = generate_config_registration(
    LearnableConfig
)


@dataclasses.dataclass()
class LearnableConfigWithShape(DynamicConfig):
    observation_shape: Shape
    action_size: int
    config: LearnableConfig = make_learnable_field()

    def create(
        self, device: DeviceArg = False
    ) -> "LearnableBase[ImplBase, LearnableConfig]":
        algo = self.config.create(device)
        algo.create_impl(self.observation_shape, self.action_size)
        return algo
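
# Reconstruction sketch (illustrative; ``DQNConfig`` as above): unlike
# ``LearnableConfig.create``, this variant also builds the implementation,
# so the returned algorithm is immediately usable.
#
#   config = LearnableConfigWithShape(
#       observation_shape=(4,), action_size=2, config=DQNConfig()
#   )
#   algo = config.create(device=False)  # resolves to "cpu:0"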


def save_config(
    algo: "LearnableBase[ImplBase, LearnableConfig]", logger: D3RLPyLogger
) -> None:
    assert algo.impl
    config = LearnableConfigWithShape(
        observation_shape=algo.impl.observation_shape,
        action_size=algo.impl.action_size,
        config=algo.config,
    )
    logger.add_params(config.serialize_to_dict())


def _process_device(value: DeviceArg) -> str:
    """Checks value and returns PyTorch target device.

    Returns:
        str: target device.
    """
    # bool must be checked before int because bool is a subclass of int
    if isinstance(value, bool):
        return "cuda:0" if value else "cpu:0"
    if isinstance(value, int):
        return f"cuda:{value}"
    if isinstance(value, str):
        return value
    if value is None:
        return "cpu:0"
    raise ValueError("This argument must be bool, int or str.")
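
# Resolution table (follows directly from the branches above):
#   _process_device(True)      -> "cuda:0"
#   _process_device(False)     -> "cpu:0"
#   _process_device(1)         -> "cuda:1"
#   _process_device("cuda:2")  -> "cuda:2"
#   _process_device(None)      -> "cpu:0"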


def dump_learnable(
    algo: "LearnableBase[ImplBase, LearnableConfig]", fname: str
) -> None:
    assert algo.impl
    with open(fname, "wb") as f:
        torch_bytes = io.BytesIO()
        algo.impl.save_model(torch_bytes)
        config = LearnableConfigWithShape(
            observation_shape=algo.impl.observation_shape,
            action_size=algo.impl.action_size,
            config=algo.config,
        )
        obj = {
            "torch": torch_bytes.getvalue(),
            "config": config.serialize(),
            "version": __version__,
        }
        pickle.dump(obj, f)


def load_learnable(
    fname: str, device: DeviceArg = None
) -> "LearnableBase[ImplBase, LearnableConfig]":
    with open(fname, "rb") as f:
        obj = pickle.load(f)
        if obj["version"] != __version__:
            LOG.warning(
                "There might be incompatibility because of version mismatch.",
                current_version=__version__,
                saved_version=obj["version"],
            )
        config = LearnableConfigWithShape.deserialize(obj["config"])
        algo = config.create(device)
        assert algo.impl
        algo.impl.load_model(io.BytesIO(obj["torch"]))
    return algo
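
# Round-trip sketch (illustrative; ``algo`` is any LearnableBase whose
# implementation has been built):
#
#   dump_learnable(algo, "model.d3")
#   algo2 = load_learnable("model.d3", device="cuda:0")
#   # algo2 carries the same config, shapes and network parameters as algo.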


class LearnableBase(Generic[TImpl_co, TConfig_co], metaclass=ABCMeta):
    _config: TConfig_co
    _device: str
    _impl: Optional[TImpl_co]
    _grad_step: int

    def __init__(
        self,
        config: TConfig_co,
        device: DeviceArg,
        impl: Optional[TImpl_co] = None,
    ):
        if self.get_action_type() == ActionSpace.DISCRETE:
            assert (
                config.action_scaler is None
            ), "action_scaler cannot be used with discrete action-space algorithms."
        self._config = config
        self._device = _process_device(device)
        self._impl = impl
        self._grad_step = 0

    def save(self, fname: str) -> None:
        """Saves paired data of neural network parameters and serialized config.

        .. code-block:: python

            algo.save('model.d3')

            # reconstruct everything
            algo2 = d3rlpy.load_learnable("model.d3", device="cuda:0")

        Args:
            fname: destination file path.
        """
        assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
        dump_learnable(self, fname)

    def save_model(self, fname: str) -> None:
        """Saves neural network parameters.

        .. code-block:: python

            algo.save_model('model.pt')

        Args:
            fname: destination file path.
        """
        assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
        with open(fname, "wb") as f:
            self._impl.save_model(f)

    def load_model(self, fname: str) -> None:
        """Loads neural network parameters.

        .. code-block:: python

            algo.load_model('model.pt')

        Args:
            fname: source file path.
        """
        assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
        with open(fname, "rb") as f:
            self._impl.load_model(f)

    @classmethod
    def from_json(
        cls: Type[Self], fname: str, device: DeviceArg = False
    ) -> Self:
        r"""Constructs algorithm from params.json file.

        .. code-block:: python

            from d3rlpy.algos import CQL

            cql = CQL.from_json("<path-to-json>", device='cuda:0')

        Args:
            fname: path to params.json.
            device (Union[int, str, bool]): device option. If the value is
                boolean and True, ``cuda:0`` will be used. If the value is
                integer, ``cuda:<device>`` will be used. If the value is
                string in torch device style, the specified device will be
                used.

        Returns:
            algorithm object.
        """
        config = LearnableConfigWithShape.deserialize_from_file(fname)
        return config.create(device)  # type: ignore

    def create_impl(self, observation_shape: Shape, action_size: int) -> None:
        """Instantiates implementation object with the dataset shapes.

        This method is used internally when the ``fit`` method is called.

        Args:
            observation_shape: observation shape.
            action_size: dimension of action-space.
        """
        if self._impl:
            LOG.warning("Parameters will be reinitialized.")
        self.inner_create_impl(observation_shape, action_size)

    @abstractmethod
    def inner_create_impl(
        self, observation_shape: Shape, action_size: int
    ) -> None:
        pass

    def build_with_dataset(self, dataset: ReplayBuffer) -> None:
        """Instantiates implementation object with ReplayBuffer object.

        Args:
            dataset: dataset.
        """
        observation_shape = (
            dataset.sample_transition().observation_signature.shape[0]
        )
        self.create_impl(observation_shape, dataset.dataset_info.action_size)
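
    # Build sketch (illustrative): the observation shape is taken from a
    # single sampled transition and the action size from dataset_info.
    #
    #   algo.build_with_dataset(dataset)  # dataset: d3rlpy ReplayBuffer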

    def build_with_env(self, env: GymEnv) -> None:
        """Instantiates implementation object with OpenAI Gym object.

        Args:
            env: gym-like environment.
        """
        assert isinstance(
            env.observation_space, (Box, GymnasiumBox)
        ), f"Unsupported observation space: {type(env.observation_space)}"
        observation_shape = env.observation_space.shape
        action_size = detect_action_size_from_env(env)
        self.create_impl(observation_shape, action_size)
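
    # Build sketch (illustrative; ``Pendulum-v1`` is only an example of an
    # environment with a Box observation space):
    #
    #   import gym
    #
    #   algo.build_with_env(gym.make("Pendulum-v1"))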

    def get_action_type(self) -> ActionSpace:
        """Returns action type (continuous or discrete).

        Returns:
            action type.
        """
        raise NotImplementedError

    @property
    def config(self) -> TConfig_co:
        """Config.

        Returns:
            LearnableConfig: config.
        """
        return self._config

    @property
    def batch_size(self) -> int:
        """Batch size to train.

        Returns:
            int: batch size.
        """
        return self._config.batch_size

    @property
    def gamma(self) -> float:
        """Discount factor.

        Returns:
            float: discount factor.
        """
        return self._config.gamma

    @property
    def observation_scaler(self) -> Optional[ObservationScaler]:
        """Preprocessing observation scaler.

        Returns:
            Optional[ObservationScaler]: preprocessing observation scaler.
        """
        return self._config.observation_scaler

    @property
    def action_scaler(self) -> Optional[ActionScaler]:
        """Preprocessing action scaler.

        Returns:
            Optional[ActionScaler]: preprocessing action scaler.
        """
        return self._config.action_scaler

    @property
    def reward_scaler(self) -> Optional[RewardScaler]:
        """Preprocessing reward scaler.

        Returns:
            Optional[RewardScaler]: preprocessing reward scaler.
        """
        return self._config.reward_scaler

    @property
    def impl(self) -> Optional[TImpl_co]:
        """Implementation object.

        Returns:
            Optional[ImplBase]: implementation object.
        """
        return self._impl

    @property
    def observation_shape(self) -> Optional[Shape]:
        """Observation shape.

        Returns:
            Optional[Sequence[int]]: observation shape.
        """
        if self._impl:
            return self._impl.observation_shape
        return None

    @property
    def action_size(self) -> Optional[int]:
        """Action size.

        Returns:
            Optional[int]: action size.
        """
        if self._impl:
            return self._impl.action_size
        return None

    @property
    def grad_step(self) -> int:
        """Total gradient step counter.

        This value will keep counting after ``fit`` and ``fit_online``
        methods finish.

        Returns:
            total gradient step counter.
        """
        return self._grad_step

    def set_grad_step(self, grad_step: int) -> None:
        """Sets total gradient step counter.

        This method can be used to restart training from the middle with an
        arbitrary gradient step counter, which affects periodic functions
        such as the target update.

        Args:
            grad_step: total gradient step counter.
        """
        self._grad_step = grad_step
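
    # Resume sketch (illustrative): restore parameters first, then align the
    # counter so periodic schedules such as target updates stay consistent.
    #
    #   algo.load_model("model.pt")
    #   algo.set_grad_step(100000)  # hypothetical step reached before saving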