# Source code for d3rlpy.logging.utils

from typing import Any, Dict, Optional, Sequence

from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol

__all__ = ["CombineAdapter", "CombineAdapterFactory"]


class CombineAdapter(LoggerAdapter):
    r"""CombineAdapter class.

    This class combines multiple LoggerAdapter objects so that metrics are
    written through different adapters at the same time. Every call is
    forwarded to each wrapped adapter in the order given at construction.

    Args:
        adapters (Sequence[LoggerAdapter]): List of LoggerAdapter.
    """

    _adapters: Sequence[LoggerAdapter]

    def __init__(self, adapters: Sequence[LoggerAdapter]):
        self._adapters = adapters

    def write_params(self, params: Dict[str, Any]) -> None:
        """Forwards ``params`` to every wrapped adapter's ``write_params``."""
        for adapter in self._adapters:
            adapter.write_params(params)

    def before_write_metric(self, epoch: int, step: int) -> None:
        """Forwards ``before_write_metric(epoch, step)`` to every adapter."""
        for adapter in self._adapters:
            adapter.before_write_metric(epoch, step)

    def write_metric(
        self, epoch: int, step: int, name: str, value: float
    ) -> None:
        """Forwards a single metric value to every wrapped adapter."""
        for adapter in self._adapters:
            adapter.write_metric(epoch, step, name, value)

    def after_write_metric(self, epoch: int, step: int) -> None:
        """Forwards ``after_write_metric(epoch, step)`` to every adapter."""
        for adapter in self._adapters:
            adapter.after_write_metric(epoch, step)

    def save_model(self, epoch: int, algo: SaveProtocol) -> None:
        """Forwards ``save_model(epoch, algo)`` to every wrapped adapter."""
        for adapter in self._adapters:
            adapter.save_model(epoch, algo)

    def close(self) -> None:
        """Closes every wrapped adapter.

        All adapters are visited even if one of them raises, so a failure in
        one backend does not prevent the others from releasing their
        resources.

        Raises:
            Exception: the first exception raised by any adapter's
                ``close()``, re-raised after all adapters were closed.
        """
        first_error: Optional[Exception] = None
        for adapter in self._adapters:
            try:
                adapter.close()
            except Exception as e:
                # Remember the failure but keep closing the remaining
                # adapters; only the first error is propagated.
                if first_error is None:
                    first_error = e
        if first_error is not None:
            raise first_error
class CombineAdapterFactory(LoggerAdapterFactory):
    r"""CombineAdapterFactory class.

    Builds a ``CombineAdapter`` whose children are produced by the given
    adapter factories, one child per factory, in the same order.

    Args:
        adapter_factories (Sequence[LoggerAdapterFactory]):
            List of LoggerAdapterFactory.
    """

    _adapter_factories: Sequence[LoggerAdapterFactory]

    def __init__(self, adapter_factories: Sequence[LoggerAdapterFactory]):
        self._adapter_factories = adapter_factories

    def create(self, experiment_name: str) -> CombineAdapter:
        """Creates one child adapter per factory and wraps them together.

        Args:
            experiment_name: name passed unchanged to every factory's
                ``create``.

        Returns:
            A ``CombineAdapter`` over the freshly created adapters.
        """
        children = [
            child_factory.create(experiment_name)
            for child_factory in self._adapter_factories
        ]
        return CombineAdapter(children)