# Source code for d3rlpy.logging.logger

import time
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Iterator, Optional, Protocol

import structlog
from torch import nn

from ..types import Float32NDArray

__all__ = [
    "LOG",
    "set_log_context",
    "D3RLPyLogger",
    "LoggerAdapter",
    "LoggerAdapterFactory",
]

# structlog is configured once, at import time. The processors run in order:
# merge values bound through ``set_log_context`` (contextvars), add the log
# level, render stack info, attach/format exception info, stamp local time,
# and finally render the event for the console.
structlog.configure(
    processors=[
        structlog.contextvars.merge_contextvars,
        structlog.processors.add_log_level,
        structlog.processors.StackInfoRenderer(),
        structlog.dev.set_exc_info,
        structlog.processors.format_exc_info,
        # ':' between minutes and seconds; a '.' there reads like a
        # fractional-minute value rather than a clock time.
        structlog.processors.TimeStamper(fmt="%Y-%m-%d %H:%M:%S", utc=False),
        structlog.dev.ConsoleRenderer(),
    ],
)


# Module-level logger; exported via ``__all__`` and used by ``D3RLPyLogger``.
LOG: structlog.BoundLogger = structlog.get_logger(__name__)


def set_log_context(**context: Any) -> None:
    """Bind key/value pairs into structlog's context variables.

    Bound values are picked up by the ``merge_contextvars`` processor that
    this module installs, so they appear on log events emitted within the
    current context.

    Args:
        **context: Context entries to bind.
    """
    bind_context = structlog.contextvars.bind_contextvars
    bind_context(**context)


class SaveProtocol(Protocol):
    """Structural type for objects exposing a ``save(fname)`` method."""

    def save(self, fname: str) -> None:
        """Save to the file name ``fname``."""


class ModuleProtocol(Protocol):
    """Structural type for a container of torch modules and their gradients."""

    def get_torch_modules(self) -> dict[str, nn.Module]:
        """Return a mapping from string keys to ``nn.Module`` objects."""

    def get_gradients(self) -> Iterator[tuple[str, Float32NDArray]]:
        """Yield ``(str, Float32NDArray)`` pairs.

        Presumably each pair is a parameter name and its gradient; confirm
        against the implementations.
        """


class ImplProtocol(Protocol):
    """Structural type for an implementation object exposing ``modules``."""

    @property
    def modules(self) -> ModuleProtocol:
        """Module container of this implementation; abstract in the protocol."""
        raise NotImplementedError()


class AlgProtocol(Protocol):
    """Structural type for an algorithm exposing an optional ``impl``."""

    @property
    def impl(self) -> Optional[ImplProtocol]:
        """Implementation object, or ``None``; abstract in the protocol."""
        raise NotImplementedError()


class LoggerAdapter(Protocol):
    r"""Interface of LoggerAdapter.

    Every method body here is documentation only; concrete adapters supply
    the behavior.
    """

    def write_params(self, params: dict[str, Any]) -> None:
        r"""Writes hyperparameters.

        Args:
            params: Dictionary of hyperparameters.
        """

    def before_write_metric(self, epoch: int, step: int) -> None:
        r"""Callback executed before write_metric method.

        Args:
            epoch: Epoch.
            step: Training step.
        """

    def write_metric(
        self, epoch: int, step: int, name: str, value: float
    ) -> None:
        r"""Writes metric.

        Args:
            epoch: Epoch.
            step: Training step.
            name: Metric name.
            value: Metric value.
        """

    def after_write_metric(self, epoch: int, step: int) -> None:
        r"""Callback executed after write_metric method.

        Args:
            epoch: Epoch.
            step: Training step.
        """

    def save_model(self, epoch: int, algo: SaveProtocol) -> None:
        r"""Saves models.

        Args:
            epoch: Epoch.
            algo: Algorithm that provides ``save`` method.
        """

    def close(self) -> None:
        r"""Closes this LoggerAdapter."""

    def watch_model(
        self,
        epoch: int,
        step: int,
    ) -> None:
        r"""Watch model parameters / gradients during training.

        Args:
            epoch: Epoch.
            step: Training step.
        """
class LoggerAdapterFactory(Protocol):
    r"""Interface of LoggerAdapterFactory."""

    def create(
        self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int
    ) -> LoggerAdapter:
        r"""Creates LoggerAdapter.

        This method instantiates ``LoggerAdapter`` with a given
        ``experiment_name``.
        This method is usually called at the beginning of training.

        Args:
            algo: Algorithm.
            experiment_name: Experiment name.
            n_steps_per_epoch: Number of steps per epoch.

        Returns:
            The created ``LoggerAdapter``.
        """
        raise NotImplementedError
class D3RLPyLogger:
    """Buffers training metrics and forwards them to a ``LoggerAdapter``.

    Values passed to :meth:`add_metric` accumulate per metric name until
    :meth:`commit`, which averages each buffer, writes the averages through
    the adapter, logs them via ``LOG``, and clears the buffer.

    Args:
        algo: Algorithm passed to ``adapter_factory.create``.
        adapter_factory: Factory that creates the underlying adapter.
        experiment_name: Base experiment name.
        n_steps_per_epoch: Number of steps per epoch, passed to the factory.
        with_timestamp: If ``True``, append ``"_" + %Y%m%d%H%M%S`` (local
            time) to ``experiment_name``.
    """

    _algo: AlgProtocol
    _adapter: LoggerAdapter
    _experiment_name: str
    _metrics_buffer: defaultdict[str, list[float]]

    def __init__(
        self,
        algo: AlgProtocol,
        adapter_factory: LoggerAdapterFactory,
        experiment_name: str,
        n_steps_per_epoch: int,
        with_timestamp: bool = True,
    ):
        if with_timestamp:
            date = datetime.now().strftime("%Y%m%d%H%M%S")
            self._experiment_name = experiment_name + "_" + date
        else:
            self._experiment_name = experiment_name
        self._algo = algo
        self._adapter = adapter_factory.create(
            algo, self._experiment_name, n_steps_per_epoch
        )
        self._metrics_buffer = defaultdict(list)

    def add_params(self, params: dict[str, Any]) -> None:
        """Write hyperparameters through the adapter and log them."""
        self._adapter.write_params(params)
        LOG.info("Parameters", params=params)

    def add_metric(self, name: str, value: float) -> None:
        """Buffer ``value`` under ``name`` until the next :meth:`commit`."""
        self._metrics_buffer[name].append(value)

    def commit(self, epoch: int, step: int) -> dict[str, float]:
        """Average, write and log every buffered metric, then reset buffers.

        Adapter hooks are called in this order: ``before_write_metric``,
        ``write_metric`` once per metric, ``after_write_metric``, then
        ``watch_model``.

        Args:
            epoch: Epoch.
            step: Training step.

        Returns:
            Mapping from metric name to the mean of its buffered values.
        """
        self._adapter.before_write_metric(epoch, step)

        metrics = {}
        for name, buffer in self._metrics_buffer.items():
            # Within this class only add_metric creates entries, and it always
            # appends, so len(buffer) >= 1 here.
            metric = sum(buffer) / len(buffer)
            self._adapter.write_metric(epoch, step, name, metric)
            metrics[name] = metric

        LOG.info(
            f"{self._experiment_name}: epoch={epoch} step={step}",
            epoch=epoch,
            step=step,
            metrics=metrics,
        )

        self._adapter.after_write_metric(epoch, step)

        # save model parameter metrics
        self._adapter.watch_model(epoch, step)

        # initialize metrics buffer
        self._metrics_buffer.clear()

        return metrics

    def save_model(self, epoch: int, algo: SaveProtocol) -> None:
        """Delegate model saving to the adapter."""
        self._adapter.save_model(epoch, algo)

    def close(self) -> None:
        """Close the underlying adapter."""
        self._adapter.close()

    @contextmanager
    def measure_time(self, name: str) -> Iterator[None]:
        """Record the elapsed seconds of the ``with`` body as ``time_<name>``.

        The duration is recorded even if the body raises. It is measured with
        ``time.perf_counter`` (monotonic, high resolution), so system clock
        adjustments cannot yield negative or skewed durations, as they could
        with ``time.time``.
        """
        name = "time_" + name
        start = time.perf_counter()
        try:
            yield
        finally:
            self.add_metric(name, time.perf_counter() - start)

    @property
    def adapter(self) -> LoggerAdapter:
        """The ``LoggerAdapter`` created by the factory in ``__init__``."""
        return self._adapter