Source code for d3rlpy.optimizers

import copy
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

from torch.optim import SGD, Adam, RMSprop


class OptimizerFactory:
    """
    Factory that builds optimizer objects lazily.

    The factory stores an optimizer class together with the keyword
    arguments given at construction time, and instantiates the optimizer
    only when :meth:`create` is called. Algorithms can receive their
    optimizer configuration through this class.

    .. code-block:: python

        from torch.optim import Adam
        from d3rlpy.optimizers import OptimizerFactory
        from d3rlpy.algos import DQN

        factory = OptimizerFactory(Adam, eps=0.001)
        dqn = DQN(optim_factory=factory)

    Args:
        optim_cls (type or str): an optimizer class, or the name of an
            attribute of ``torch.optim`` to be looked up.
        kwargs (any): arbitrary keyword arguments kept for the optimizer.

    Attributes:
        optim_cls (type): an optimizer class.
        optim_kwargs (dict): keyword arguments given to the factory.

    """

    def __init__(self, optim_cls, **kwargs):
        # A string is resolved as an attribute of the ``torch.optim`` module.
        given_by_name = isinstance(optim_cls, str)
        self.optim_cls = (
            getattr(optim, optim_cls) if given_by_name else optim_cls
        )
        self.optim_kwargs = kwargs

    def create(self, params, lr):
        """
        Instantiates the optimizer.

        Args:
            params (list): a list of PyTorch parameters.
            lr (float): learning rate.

        Returns:
            torch.optim.Optimizer: the object produced by calling
            ``optim_cls`` with ``params``, ``lr`` and the stored keyword
            arguments.

        """
        build = self.optim_cls
        return build(params, lr=lr, **self.optim_kwargs)

    def get_params(self, deep=False):
        """
        Returns the factory configuration as a dictionary.

        Args:
            deep (bool): when true, the stored keyword arguments are
                deep-copied before being placed in the result.

        Returns:
            dict: the optimizer class name under ``'optim_cls'`` followed
            by the stored keyword arguments.

        """
        stored = self.optim_kwargs
        if deep:
            stored = copy.deepcopy(stored)
        result = {'optim_cls': self.optim_cls.__name__}
        result.update(stored)
        return result
class SGDFactory(OptimizerFactory):
    """
    Shortcut factory for the SGD optimizer.

    .. code-block:: python

        from d3rlpy.optimizers import SGDFactory

        factory = SGDFactory(weight_decay=1e-4)

    Args:
        momentum (float): momentum factor.
        dampening (float): dampening for momentum.
        weight_decay (float): weight decay (L2 penalty).
        nesterov (bool): flag to enable Nesterov momentum.
        kwargs (any): accepted but ignored; extra keyword arguments are
            not passed on to the optimizer.

    Attributes:
        optim_cls (type): ``torch.optim.SGD`` class.
        optim_kwargs (dict): given parameters for an optimizer.

    """

    def __init__(
        self,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        **kwargs
    ):
        # Only the named options are handed to the base class; anything in
        # ``kwargs`` is not passed on.
        sgd_options = {
            'momentum': momentum,
            'dampening': dampening,
            'weight_decay': weight_decay,
            'nesterov': nesterov,
        }
        super().__init__(SGD, **sgd_options)
class AdamFactory(OptimizerFactory):
    """
    Shortcut factory for the Adam optimizer.

    .. code-block:: python

        from d3rlpy.optimizers import AdamFactory

        factory = AdamFactory(weight_decay=1e-4)

    Args:
        betas (tuple): coefficients used for computing running averages of
            gradient and its square.
        eps (float): term added to the denominator to improve numerical
            stability.
        weight_decay (float): weight decay (L2 penalty).
        amsgrad (bool): flag to use the AMSGrad variant of this algorithm.
        kwargs (any): accepted but ignored; extra keyword arguments are
            not passed on to the optimizer.

    Attributes:
        optim_cls (type): ``torch.optim.Adam`` class.
        optim_kwargs (dict): given parameters for an optimizer.

    """

    def __init__(
        self,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        **kwargs
    ):
        # Only the named options are handed to the base class; anything in
        # ``kwargs`` is not passed on.
        adam_options = {
            'betas': betas,
            'eps': eps,
            'weight_decay': weight_decay,
            'amsgrad': amsgrad,
        }
        super().__init__(Adam, **adam_options)
class RMSpropFactory(OptimizerFactory):
    """
    Shortcut factory for the RMSprop optimizer.

    .. code-block:: python

        from d3rlpy.optimizers import RMSpropFactory

        factory = RMSpropFactory(weight_decay=1e-4)

    Args:
        alpha (float): smoothing constant.
        eps (float): term added to the denominator to improve numerical
            stability.
        weight_decay (float): weight decay (L2 penalty).
        momentum (float): momentum factor.
        centered (bool): flag to compute the centered RMSProp, the gradient
            is normalized by an estimation of its variance.
        kwargs (any): accepted but ignored; extra keyword arguments are
            not passed on to the optimizer.

    Attributes:
        optim_cls (type): ``torch.optim.RMSprop`` class.
        optim_kwargs (dict): given parameters for an optimizer.

    """

    def __init__(
        self,
        alpha=0.95,
        eps=1e-2,
        weight_decay=0,
        momentum=0,
        centered=True,
        **kwargs
    ):
        # Only the named options are handed to the base class; anything in
        # ``kwargs`` is not passed on.
        rmsprop_options = {
            'alpha': alpha,
            'eps': eps,
            'weight_decay': weight_decay,
            'momentum': momentum,
            'centered': centered,
        }
        super().__init__(RMSprop, **rmsprop_options)