from contextlib import ExitStack
from typing import Any, Dict, List, Sequence

from .logger import LoggerAdapter, LoggerAdapterFactory, SaveProtocol
# Public names re-exported by ``from <this module> import *``.
__all__ = [
    "CombineAdapter",
    "CombineAdapterFactory",
]
class CombineAdapter(LoggerAdapter):
    r"""CombineAdapter class.

    This class combines multiple LoggerAdapter to write metrics through
    different adapters at the same time. Every call is forwarded to each
    wrapped adapter, in the order the adapters were given.

    Args:
        adapters (Sequence[LoggerAdapter]): List of LoggerAdapter.
    """

    def __init__(self, adapters: Sequence[LoggerAdapter]):
        # Stored as-is (not copied) and iterated on every call, so it must be
        # re-iterable; a one-shot iterator would be exhausted after one call.
        self._adapters = adapters

    def write_params(self, params: Dict[str, Any]) -> None:
        """Forward ``write_params(params)`` to every wrapped adapter."""
        for adapter in self._adapters:
            adapter.write_params(params)

    def before_write_metric(self, epoch: int, step: int) -> None:
        """Forward ``before_write_metric(epoch, step)`` to every adapter."""
        for adapter in self._adapters:
            adapter.before_write_metric(epoch, step)

    def write_metric(
        self, epoch: int, step: int, name: str, value: float
    ) -> None:
        """Forward ``write_metric(epoch, step, name, value)`` to every adapter."""
        for adapter in self._adapters:
            adapter.write_metric(epoch, step, name, value)

    def after_write_metric(self, epoch: int, step: int) -> None:
        """Forward ``after_write_metric(epoch, step)`` to every adapter."""
        for adapter in self._adapters:
            adapter.after_write_metric(epoch, step)

    def save_model(self, epoch: int, algo: SaveProtocol) -> None:
        """Forward ``save_model(epoch, algo)`` to every wrapped adapter."""
        for adapter in self._adapters:
            adapter.save_model(epoch, algo)

    def close(self) -> None:
        """Close every wrapped adapter, in the order they were given.

        A failure in one adapter's ``close`` no longer prevents the remaining
        adapters from being closed. After all of them have been attempted,
        the most recent failure propagates, and any earlier failures remain
        reachable through its ``__context__`` chain.
        """
        # ExitStack runs every registered callback even when an earlier one
        # raises. It unwinds LIFO, so registering in reverse keeps the
        # closing order identical to iteration order.
        with ExitStack() as stack:
            for adapter in reversed(self._adapters):
                stack.callback(adapter.close)
class CombineAdapterFactory(LoggerAdapterFactory):
    r"""CombineAdapterFactory class.

    This class instantiates ``CombineAdapter`` object.

    Args:
        adapter_factories (Sequence[LoggerAdapterFactory]):
            List of LoggerAdapterFactory.
    """

    _adapter_factories: Sequence[LoggerAdapterFactory]

    def __init__(self, adapter_factories: Sequence[LoggerAdapterFactory]):
        self._adapter_factories = adapter_factories

    def create(self, experiment_name: str) -> CombineAdapter:
        r"""Create one adapter per factory and combine them.

        If any factory raises, the adapters already created by earlier
        factories are closed before the exception propagates, so a partial
        failure does not leak them.

        Args:
            experiment_name (str): Passed unchanged to every factory's
                ``create``.

        Returns:
            CombineAdapter: Wraps the created adapters, in factory order.
        """
        adapters: List[LoggerAdapter] = []
        with ExitStack() as cleanup:
            for factory in self._adapter_factories:
                adapter = factory.create(experiment_name)
                # Runs only if a later factory fails (see pop_all below).
                cleanup.callback(adapter.close)
                adapters.append(adapter)
            # Every factory succeeded: cancel the pending closes so that
            # ownership of the adapters passes to the CombineAdapter.
            cleanup.pop_all()
        return CombineAdapter(adapters)