import time
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from typing import Any, DefaultDict, Dict, Iterator, List
import structlog
from typing_extensions import Protocol
# Names exported by ``from <this module> import *``.
__all__ = [
    "LOG",
    "set_log_context",
    "D3RLPyLogger",
    "LoggerAdapter",
    "LoggerAdapterFactory",
]

# Global structlog pipeline, configured once at import time. Processors run in
# order:
#   1. merge values bound via ``set_log_context``;
#   2. attach the log level;
#   3. render stack/exception info;
#   4. stamp local time;
#   5. pretty-print to the console.
structlog.configure(
    processors=[
        structlog.contextvars.merge_contextvars,
        structlog.processors.add_log_level,
        structlog.processors.StackInfoRenderer(),
        structlog.dev.set_exc_info,
        structlog.processors.format_exc_info,
        # Colon-separated clock time (HH:MM:SS); a "." between minutes and
        # seconds would render misleading stamps such as "12:34.56".
        structlog.processors.TimeStamper(fmt="%Y-%m-%d %H:%M:%S", utc=False),
        structlog.dev.ConsoleRenderer(),
    ],
)

# Shared module-level logger.
LOG: structlog.BoundLogger = structlog.get_logger(__name__)
def set_log_context(**kwargs: Any) -> None:
    r"""Binds key-value pairs to the context-local structlog context.

    Values bound here are picked up by the ``merge_contextvars`` processor in
    the module's structlog configuration and so appear on later log events.

    Args:
        kwargs: Key-value pairs to attach to log events.
    """
    bind = structlog.contextvars.bind_contextvars
    bind(**kwargs)
class SaveProtocol(Protocol):
    r"""Structural type for any object exposing a ``save(fname)`` method."""

    def save(self, fname: str) -> None:
        r"""Saves the object.

        Args:
            fname: File name to save to.
        """
class LoggerAdapter(Protocol):
    r"""Interface that a logging backend must satisfy.

    ``D3RLPyLogger`` uses an adapter through this interface:

    * hyperparameters via ``write_params``;
    * metrics via the ``before_write_metric`` / ``write_metric`` /
      ``after_write_metric`` sequence;
    * checkpoints via ``save_model``;
    * shutdown via ``close``.
    """

    def write_params(self, params: Dict[str, Any]) -> None:
        r"""Records experiment hyperparameters.

        Args:
            params: Mapping of hyperparameter names to values.
        """

    def before_write_metric(self, epoch: int, step: int) -> None:
        r"""Hook invoked before a group of ``write_metric`` calls.

        Args:
            epoch: Current epoch.
            step: Current training step.
        """

    def write_metric(self, epoch: int, step: int, name: str, value: float) -> None:
        r"""Records a single metric value.

        Args:
            epoch: Current epoch.
            step: Current training step.
            name: Name of the metric.
            value: Value of the metric.
        """

    def after_write_metric(self, epoch: int, step: int) -> None:
        r"""Hook invoked after a group of ``write_metric`` calls.

        Args:
            epoch: Current epoch.
            step: Current training step.
        """

    def save_model(self, epoch: int, algo: SaveProtocol) -> None:
        r"""Saves the model(s) held by ``algo``.

        Args:
            epoch: Current epoch.
            algo: Object exposing a ``save(fname)`` method.
        """

    def close(self) -> None:
        r"""Closes this adapter."""
class LoggerAdapterFactory(Protocol):
    r"""Interface for objects that build ``LoggerAdapter`` instances."""

    def create(self, experiment_name: str) -> LoggerAdapter:
        r"""Builds a ``LoggerAdapter`` for the given experiment.

        Implementations are usually called once, when training starts.

        Args:
            experiment_name: Name of the experiment to log.

        Returns:
            The adapter to log through.

        Raises:
            NotImplementedError: If an implementation does not override this.
        """
        raise NotImplementedError
class D3RLPyLogger:
    r"""Buffers training metrics and forwards them to a ``LoggerAdapter``.

    Values passed to ``add_metric`` are accumulated per metric name and
    averaged on each ``commit``. The averages are written through the adapter
    and echoed to ``LOG``.

    Args:
        adapter_factory: Factory used to create the underlying adapter.
        experiment_name: Base experiment name.
        with_timestamp: If ``True``, appends ``_YYYYmmddHHMMSS`` (local time)
            to ``experiment_name``.
    """

    _adapter: LoggerAdapter
    _experiment_name: str
    _metrics_buffer: DefaultDict[str, List[float]]

    def __init__(
        self,
        adapter_factory: LoggerAdapterFactory,
        experiment_name: str,
        with_timestamp: bool = True,
    ) -> None:
        if with_timestamp:
            date = datetime.now().strftime("%Y%m%d%H%M%S")
            self._experiment_name = experiment_name + "_" + date
        else:
            self._experiment_name = experiment_name
        self._adapter = adapter_factory.create(self._experiment_name)
        self._metrics_buffer = defaultdict(list)

    def add_params(self, params: Dict[str, Any]) -> None:
        r"""Writes hyperparameters to the adapter and logs them.

        Args:
            params: Dictionary of hyperparameters.
        """
        self._adapter.write_params(params)
        LOG.info("Parameters", params=params)

    def add_metric(self, name: str, value: float) -> None:
        r"""Buffers one observation of a metric until the next ``commit``.

        Args:
            name: Metric name.
            value: Observed value.
        """
        self._metrics_buffer[name].append(value)

    def commit(self, epoch: int, step: int) -> Dict[str, float]:
        r"""Averages buffered metrics, writes them out, and resets the buffer.

        Args:
            epoch: Epoch.
            step: Training step.

        Returns:
            Mapping from metric name to the mean of the values buffered since
            the previous commit.
        """
        self._adapter.before_write_metric(epoch, step)

        metrics = {}
        for name, buffer in self._metrics_buffer.items():
            # Never empty: entries are only created by add_metric's append.
            metric = sum(buffer) / len(buffer)
            self._adapter.write_metric(epoch, step, name, metric)
            metrics[name] = metric

        LOG.info(
            f"{self._experiment_name}: epoch={epoch} step={step}",
            epoch=epoch,
            step=step,
            metrics=metrics,
        )

        self._adapter.after_write_metric(epoch, step)

        # Start a fresh averaging window for the next commit.
        self._metrics_buffer.clear()
        return metrics

    def save_model(self, epoch: int, algo: SaveProtocol) -> None:
        r"""Delegates model saving to the adapter.

        Args:
            epoch: Epoch.
            algo: Object exposing a ``save(fname)`` method.
        """
        self._adapter.save_model(epoch, algo)

    def close(self) -> None:
        r"""Closes the underlying adapter."""
        self._adapter.close()

    @contextmanager
    def measure_time(self, name: str) -> Iterator[None]:
        r"""Records the duration of the ``with`` body as metric ``time_<name>``.

        The duration is recorded even if the body raises.

        Args:
            name: Suffix of the metric name.

        Yields:
            None.
        """
        name = "time_" + name
        # perf_counter is monotonic and high-resolution, so the measured
        # duration is not distorted by wall-clock adjustments (time.time is).
        start = time.perf_counter()
        try:
            yield
        finally:
            self.add_metric(name, time.perf_counter() - start)

    @property
    def adapter(self) -> LoggerAdapter:
        r"""The ``LoggerAdapter`` created by the factory."""
        return self._adapter