import copy
from typing import Any, Dict, Iterable, Tuple, Type, Union, cast
import torch.nn as nn
import torch.optim as optim
from torch.optim import SGD, Adam, Optimizer, RMSprop
from ..decorators import pretty_repr
@pretty_repr
class OptimizerFactory:
    """A factory class that creates an optimizer object in a lazy way.

    The optimizers in algorithms can be configured through this factory
    class. The optimizer class and its keyword arguments are stored at
    construction time; the optimizer itself is only instantiated when
    :meth:`create` is called with the parameters and learning rate.

    .. code-block:: python

        from torch.optim import Adam
        from d3rlpy.optimizers import OptimizerFactory
        from d3rlpy.algos import DQN

        factory = OptimizerFactory(Adam, eps=0.001)
        dqn = DQN(optim_factory=factory)

    Args:
        optim_cls: An optimizer class, or the name of an optimizer class
            looked up as an attribute of ``torch.optim`` (e.g. ``"Adam"``).
        kwargs: arbitrary keyword-arguments, forwarded to the optimizer
            constructor by :meth:`create`.

    Raises:
        AttributeError: if ``optim_cls`` is a string and ``torch.optim`` has
            no attribute with that name.
        TypeError: if ``optim_cls`` is a string naming a ``torch.optim``
            attribute that is not an ``Optimizer`` subclass.

    """

    _optim_cls: Type[Optimizer]
    _optim_kwargs: Dict[str, Any]

    def __init__(self, optim_cls: Union[Type[Optimizer], str], **kwargs: Any):
        if isinstance(optim_cls, str):
            # Unknown names still raise AttributeError from getattr.
            resolved = getattr(optim, optim_cls)
            # torch.optim also exposes non-optimizer attributes (for example
            # the ``lr_scheduler`` submodule); reject them here rather than
            # failing later inside create() with an unrelated error.
            if not (
                isinstance(resolved, type) and issubclass(resolved, Optimizer)
            ):
                raise TypeError(
                    f"torch.optim.{optim_cls} is not an Optimizer subclass: "
                    f"{resolved!r}"
                )
            self._optim_cls = cast(Type[Optimizer], resolved)
        else:
            self._optim_cls = optim_cls
        self._optim_kwargs = kwargs

    def create(self, params: Iterable[nn.Parameter], lr: float) -> Optimizer:
        """Returns an optimizer object.

        The stored optimizer class is called with ``params``, ``lr`` and the
        keyword arguments given to the factory constructor.

        Args:
            params (list): a list of PyTorch parameters.
            lr (float): learning rate.

        Returns:
            torch.optim.Optimizer: an optimizer object.

        """
        return self._optim_cls(params, lr=lr, **self._optim_kwargs)

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        """Returns optimizer parameters.

        The result is a new dictionary holding the optimizer class name
        under ``"optim_cls"`` followed by the stored keyword arguments.

        Args:
            deep: flag to deeply copy the parameters. When ``False`` the
                returned values are the same objects held by the factory.

        Returns:
            optimizer parameters.

        """
        if deep:
            params = copy.deepcopy(self._optim_kwargs)
        else:
            params = self._optim_kwargs
        return {"optim_cls": self._optim_cls.__name__, **params}
class SGDFactory(OptimizerFactory):
    """An alias for SGD optimizer.

    Builds an :class:`OptimizerFactory` bound to ``torch.optim.SGD`` with
    the options below stored as its keyword arguments.

    .. code-block:: python

        from d3rlpy.optimizers import SGDFactory

        factory = SGDFactory(weight_decay=1e-4)

    Args:
        momentum: momentum factor.
        dampening: dampening for momentum.
        weight_decay: weight decay (L2 penalty).
        nesterov: flag to enable Nesterov momentum.
        kwargs: accepted but ignored by this constructor.

    """

    def __init__(
        self,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov: bool = False,
        **kwargs: Any
    ):
        # NOTE(review): ``kwargs`` is intentionally not forwarded. Presumably
        # this lets a dict produced by ``get_params()`` (which carries an
        # ``optim_cls`` entry) be passed straight back in -- confirm against
        # callers before changing.
        sgd_options: Dict[str, Any] = {
            "momentum": momentum,
            "dampening": dampening,
            "weight_decay": weight_decay,
            "nesterov": nesterov,
        }
        super().__init__(SGD, **sgd_options)
class AdamFactory(OptimizerFactory):
    """An alias for Adam optimizer.

    Builds an :class:`OptimizerFactory` bound to ``torch.optim.Adam`` with
    the options below stored as its keyword arguments.

    .. code-block:: python

        from d3rlpy.optimizers import AdamFactory

        factory = AdamFactory(weight_decay=1e-4)

    Args:
        betas: coefficients used for computing running averages of
            gradient and its square.
        eps: term added to the denominator to improve numerical stability.
        weight_decay: weight decay (L2 penalty).
        amsgrad: flag to use the AMSGrad variant of this algorithm.
        kwargs: accepted but ignored by this constructor.

    """

    def __init__(
        self,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        amsgrad: bool = False,
        **kwargs: Any
    ):
        # NOTE(review): ``kwargs`` is intentionally not forwarded. Presumably
        # this lets a dict produced by ``get_params()`` (which carries an
        # ``optim_cls`` entry) be passed straight back in -- confirm against
        # callers before changing.
        adam_options: Dict[str, Any] = {
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "amsgrad": amsgrad,
        }
        super().__init__(Adam, **adam_options)
class RMSpropFactory(OptimizerFactory):
    """An alias for RMSprop optimizer.

    Builds an :class:`OptimizerFactory` bound to ``torch.optim.RMSprop``
    with the options below stored as its keyword arguments.

    .. code-block:: python

        from d3rlpy.optimizers import RMSpropFactory

        factory = RMSpropFactory(weight_decay=1e-4)

    Args:
        alpha: smoothing constant.
        eps: term added to the denominator to improve numerical stability.
        weight_decay: weight decay (L2 penalty).
        momentum: momentum factor.
        centered: flag to compute the centered RMSProp, the gradient is
            normalized by an estimation of its variance.
        kwargs: accepted but ignored by this constructor.

    """

    def __init__(
        self,
        alpha: float = 0.95,
        eps: float = 1e-2,
        weight_decay: float = 0,
        momentum: float = 0,
        centered: bool = True,
        **kwargs: Any
    ):
        # NOTE(review): ``kwargs`` is intentionally not forwarded. Presumably
        # this lets a dict produced by ``get_params()`` (which carries an
        # ``optim_cls`` entry) be passed straight back in -- confirm against
        # callers before changing.
        rmsprop_options: Dict[str, Any] = {
            "alpha": alpha,
            "eps": eps,
            "weight_decay": weight_decay,
            "momentum": momentum,
            "centered": centered,
        }
        super().__init__(RMSprop, **rmsprop_options)