"""Source code for d3rlpy.optimizers.lr_schedulers: learning rate scheduler factories."""

import dataclasses

from torch.optim import Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LRScheduler

from ..serializable_config import (
    DynamicConfig,
    generate_optional_config_generation,
)

__all__ = [
    "LRSchedulerFactory",
    "WarmupSchedulerFactory",
    "CosineAnnealingLRFactory",
    "make_lr_scheduler_field",
]


@dataclasses.dataclass()
class LRSchedulerFactory(DynamicConfig):
    """Base config for building learning rate schedulers lazily.

    Concrete subclasses hold the scheduler hyperparameters as dataclass
    fields and postpone constructing the actual PyTorch scheduler until an
    optimizer is handed to :meth:`create`.
    """

    def create(self, optim: Optimizer) -> LRScheduler:
        """Builds a learning rate scheduler that drives ``optim``.

        This base implementation is abstract and must be overridden.

        Args:
            optim: PyTorch optimizer.

        Returns:
            Learning rate scheduler.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
@dataclasses.dataclass()
class WarmupSchedulerFactory(LRSchedulerFactory):
    r"""A linear warmup learning rate scheduler.

    Each parameter group's learning rate is its initial learning rate
    multiplied by a factor that grows linearly from ``1 / warmup_steps`` at
    ``t = 0`` up to 1 at ``t = warmup_steps - 1``, and stays at 1 afterwards
    (``t`` is the number of scheduler steps taken so far).

    .. math::

        lr = lr_{init} \cdot \min((t + 1) / warmup\_steps, 1)

    Args:
        warmup_steps: Warmup steps. Must be a positive integer.
    """

    warmup_steps: int

    def create(self, optim: Optimizer) -> LRScheduler:
        """Returns a ``LambdaLR`` scheduler implementing linear warmup.

        Args:
            optim: PyTorch optimizer.

        Returns:
            Learning rate scheduler.

        Raises:
            ValueError: If ``warmup_steps`` is not positive.
        """
        # Zero would otherwise surface as a ZeroDivisionError deep inside
        # LambdaLR, and a negative value would silently yield negative
        # learning rates.
        if self.warmup_steps <= 0:
            raise ValueError(
                f"warmup_steps must be positive, but got {self.warmup_steps}."
            )
        # Capture the validated value so later mutation of this config
        # cannot change the schedule of an already-created scheduler.
        warmup_steps = self.warmup_steps
        return LambdaLR(
            optim,
            lambda steps: min((steps + 1) / warmup_steps, 1),
        )

    @staticmethod
    def get_type() -> str:
        return "warmup"
@dataclasses.dataclass()
class CosineAnnealingLRFactory(LRSchedulerFactory):
    """Factory for PyTorch's ``CosineAnnealingLR`` scheduler.

    Every field is forwarded unchanged, by keyword, to
    ``torch.optim.lr_scheduler.CosineAnnealingLR``.

    Args:
        T_max: Maximum number of iterations (half period of the cosine).
        eta_min: Minimum learning rate.
        last_epoch: Index of the last epoch. Per PyTorch's ``LRScheduler``,
            a value other than -1 expects each param group of the optimizer
            to already carry an ``initial_lr`` entry.
    """

    T_max: int
    eta_min: float = 0.0
    last_epoch: int = -1

    def create(self, optim: Optimizer) -> LRScheduler:
        """Returns a cosine annealing scheduler bound to ``optim``.

        Args:
            optim: PyTorch optimizer.

        Returns:
            Learning rate scheduler.
        """
        t_max, min_lr, start_epoch = self.T_max, self.eta_min, self.last_epoch
        scheduler = CosineAnnealingLR(
            optim,
            T_max=t_max,
            eta_min=min_lr,
            last_epoch=start_epoch,
        )
        return scheduler

    @staticmethod
    def get_type() -> str:
        type_name = "cosine_annealing"
        return type_name
# Create the registration hook and field helper for LRSchedulerFactory
# configs. `make_lr_scheduler_field` is exported via __all__; the exact
# semantics of both callables live in serializable_config (not visible here).
register_lr_scheduler_factory, make_lr_scheduler_field = (
    generate_optional_config_generation(
        LRSchedulerFactory,
    )
)
# Register the built-in factories. Presumably each is looked up by its
# get_type() string ("warmup", "cosine_annealing") -- TODO confirm against
# generate_optional_config_generation.
register_lr_scheduler_factory(WarmupSchedulerFactory)
register_lr_scheduler_factory(CosineAnnealingLRFactory)