# Source code for d3rlpy.algos.qlearning.random_policy

import dataclasses
from typing import Dict

import numpy as np

from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...torch_utility import TorchMiniBatch
from ...types import NDArray, Observation, Shape
from .base import QLearningAlgoBase

# Public API of this module: the two random-policy algorithms and their
# configs. NOTE(review): the original list body was lost in extraction; it is
# reconstructed here from the public classes defined in this file — confirm.
__all__ = [
    "RandomPolicyConfig",
    "RandomPolicy",
    "DiscreteRandomPolicyConfig",
    "DiscreteRandomPolicy",
]

@dataclasses.dataclass()
class RandomPolicyConfig(LearnableConfig):
    r"""Random Policy for continuous control algorithm.

    This is designed for data collection and lightweight interaction tests.
    ``fit`` and ``fit_online`` methods will raise exceptions.

    Args:
        action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
        distribution (str): Random distribution. Available options are
            ``['uniform', 'normal']``.
        normal_std (float): Standard deviation of the normal distribution.
            This is only used when ``distribution='normal'``.
    """

    distribution: str = "uniform"
    normal_std: float = 1.0

    def create(self, device: DeviceArg = False) -> "RandomPolicy":  # type: ignore
        """Returns a ``RandomPolicy`` built from this config.

        Args:
            device: Accepted for interface compatibility; it is not
                forwarded to the created policy.

        Returns:
            A new ``RandomPolicy`` instance holding this config.
        """
        return RandomPolicy(self)

    @staticmethod
    def get_type() -> str:
        """Returns the type name of this config, ``"random_policy"``."""
        return "random_policy"
class RandomPolicy(QLearningAlgoBase[None, RandomPolicyConfig]):  # type: ignore
    r"""Policy that returns random continuous actions.

    Actions are drawn from the distribution selected in the config, clipped
    to ``[-1, 1]`` and, when an action scaler is configured, mapped back
    through ``reverse_transform_numpy``. Value prediction and updates are
    not supported and raise ``NotImplementedError``.
    """

    _action_size: int

    def __init__(self, config: RandomPolicyConfig):
        # NOTE(review): the positional ``False, None`` are presumably the
        # device and impl arguments of the base class -- confirm against
        # ``QLearningAlgoBase.__init__``.
        super().__init__(config, False, None)
        # Placeholder until ``inner_create_impl`` supplies the real size.
        self._action_size = 1

    def inner_create_impl(
        self, observation_shape: Shape, action_size: int
    ) -> None:
        """Records the action size; ``observation_shape`` is not used."""
        self._action_size = action_size

    def predict(self, x: Observation) -> NDArray:
        """Returns random actions; identical to :meth:`sample_action`."""
        return self.sample_action(x)

    def sample_action(self, x: Observation) -> NDArray:
        """Samples one random action per leading-axis entry of ``x``.

        Args:
            x: Observations. Only the size of the first axis is used.

        Returns:
            Array of shape ``(x.shape[0], action_size)``.

        Raises:
            ValueError: If ``distribution`` is neither ``'uniform'`` nor
                ``'normal'``.
        """
        # NOTE(review): assumes ``x`` is a single array-like batch; for a
        # sequence of arrays ``shape[0]`` would be the sequence length
        # rather than the batch size -- confirm expected inputs.
        x = np.asarray(x)
        action_shape = (x.shape[0], self._action_size)

        if self._config.distribution == "uniform":
            action = np.random.uniform(-1.0, 1.0, size=action_shape)
        elif self._config.distribution == "normal":
            action = np.random.normal(
                0.0, self._config.normal_std, size=action_shape
            )
        else:
            raise ValueError(
                f"invalid distribution type: {self._config.distribution}"
            )

        # Normal samples are unbounded, so clip to the [-1, 1] range that
        # the uniform branch already satisfies.
        action = np.clip(action, -1.0, 1.0)

        # Explicit None check so the decision does not depend on the
        # scaler object's truthiness.
        if self._config.action_scaler is not None:
            action = self._config.action_scaler.reverse_transform_numpy(action)

        return action

    def predict_value(self, x: Observation, action: NDArray) -> NDArray:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def get_action_type(self) -> ActionSpace:
        """Returns ``ActionSpace.CONTINUOUS``."""
        return ActionSpace.CONTINUOUS
@dataclasses.dataclass()
class DiscreteRandomPolicyConfig(LearnableConfig):
    r"""Random Policy for discrete control algorithm.

    This is designed for data collection and lightweight interaction tests.
    ``fit`` and ``fit_online`` methods will raise exceptions.
    """

    def create(self, device: DeviceArg = False) -> "DiscreteRandomPolicy":  # type: ignore
        """Returns a ``DiscreteRandomPolicy`` built from this config.

        Args:
            device: Accepted for interface compatibility; it is not
                forwarded to the created policy.

        Returns:
            A new ``DiscreteRandomPolicy`` instance holding this config.
        """
        return DiscreteRandomPolicy(self)

    @staticmethod
    def get_type() -> str:
        """Returns the type name of this config, ``"discrete_random_policy"``."""
        return "discrete_random_policy"
class DiscreteRandomPolicy(QLearningAlgoBase[None, DiscreteRandomPolicyConfig]):  # type: ignore
    r"""Policy that returns random discrete actions.

    Each action is an integer drawn with ``np.random.randint`` from
    ``[0, action_size)``. Value prediction and updates are not supported and
    raise ``NotImplementedError``.
    """

    _action_size: int

    def __init__(self, config: DiscreteRandomPolicyConfig):
        # NOTE(review): the positional ``False, None`` are presumably the
        # device and impl arguments of the base class -- confirm against
        # ``QLearningAlgoBase.__init__``.
        super().__init__(config, False, None)
        # Placeholder until ``inner_create_impl`` supplies the real size.
        self._action_size = 1

    def inner_create_impl(
        self, observation_shape: Shape, action_size: int
    ) -> None:
        """Records the action size; ``observation_shape`` is not used."""
        self._action_size = action_size

    def predict(self, x: Observation) -> NDArray:
        """Returns random actions; identical to :meth:`sample_action`."""
        return self.sample_action(x)

    def sample_action(self, x: Observation) -> NDArray:
        """Samples one random action index per leading-axis entry of ``x``.

        Args:
            x: Observations. Only the size of the first axis is used.

        Returns:
            Integer array of shape ``(x.shape[0],)`` with values in
            ``[0, action_size)``.
        """
        # NOTE(review): assumes ``x`` is a single array-like batch; for a
        # sequence of arrays ``shape[0]`` would be the sequence length
        # rather than the batch size -- confirm expected inputs.
        x = np.asarray(x)
        return np.random.randint(self._action_size, size=x.shape[0])

    def predict_value(self, x: Observation, action: NDArray) -> NDArray:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def inner_update(self, batch: TorchMiniBatch) -> Dict[str, float]:
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def get_action_type(self) -> ActionSpace:
        """Returns ``ActionSpace.DISCRETE``."""
        return ActionSpace.DISCRETE
# Register both config classes at import time.
# NOTE(review): presumably this lets them be looked up by their
# ``get_type()`` name -- confirm in ``base.register_learnable``.
register_learnable(RandomPolicyConfig)
register_learnable(DiscreteRandomPolicyConfig)