Source code for d3rlpy.algos.transformer.action_samplers

from typing import Union

import numpy as np
from typing_extensions import Protocol

from ...types import NDArray

__all__ = [
    "TransformerActionSampler",
    "IdentityTransformerActionSampler",
    "SoftmaxTransformerActionSampler",
    "GreedyTransformerActionSampler",
]


[docs]class TransformerActionSampler(Protocol): r"""Interface of TransformerActionSampler."""
[docs] def __call__(self, transformer_output: NDArray) -> Union[NDArray, int]: r"""Returns sampled action from Transformer output. Args: transformer_output: Output of Transformer algorithms. Returns: Sampled action. """ raise NotImplementedError
[docs]class IdentityTransformerActionSampler(TransformerActionSampler): r"""Identity action-sampler. This class implements identity function to process Transformer output. Sampled action is the exactly same as ``transformer_output``. """
[docs] def __call__(self, transformer_output: NDArray) -> Union[NDArray, int]: return transformer_output
[docs]class SoftmaxTransformerActionSampler(TransformerActionSampler): r"""Softmax action-sampler. This class implements softmax function to sample action from discrete probability distribution. Args: temperature (int): Softmax temperature. """ _temperature: float def __init__(self, temperature: float = 1.0): self._temperature = temperature
[docs] def __call__(self, transformer_output: NDArray) -> Union[NDArray, int]: assert transformer_output.ndim == 1 logits = transformer_output / self._temperature x = np.exp(logits - np.max(logits)) probs = x / np.sum(x) action = np.random.choice(probs.shape[0], p=probs) return int(action)
[docs]class GreedyTransformerActionSampler(TransformerActionSampler): r"""Greedy action-sampler. This class implements greedy function to determine action from discrte probability distribution. """
[docs] def __call__(self, transformer_output: NDArray) -> Union[NDArray, int]: assert transformer_output.ndim == 1 return int(np.argmax(transformer_output))