# Source code for d3rlpy.logging.file_adapter

import json
import os
from enum import Enum, IntEnum
from typing import Any, Dict

import numpy as np

from .logger import LOG, LoggerAdapter, LoggerAdapterFactory, SaveProtocol

# Public API of this module: the file-based logger adapter and its factory.
__all__ = ["FileAdapter", "FileAdapterFactory"]


# default json encoder for numpy and enum objects
def default_json_encoder(obj: Any) -> Any:
    """Convert objects the stdlib ``json`` module cannot encode natively.

    Intended to be passed as ``default=`` to :func:`json.dumps`. ``json``
    calls it only for objects it does not know how to serialize, then
    encodes whatever this function returns.

    Args:
        obj: The object ``json`` failed to serialize.

    Returns:
        A JSON-encodable equivalent:
        - a Python ``bool`` for numpy boolean scalars;
        - a Python ``int`` for numpy integer scalars;
        - a Python ``float`` for numpy floating scalars;
        - a (nested) list for numpy arrays;
        - ``obj.value`` for enum members.

    Raises:
        ValueError: If ``obj`` is of an unsupported type. ``ValueError`` is
            kept (rather than the ``TypeError`` that ``json`` conventionally
            expects) so existing callers catching it keep working.
    """
    # np.bool_ is neither an np.integer nor an np.floating. Without this
    # branch, boolean numpy scalars (e.g. results of array comparisons) would
    # fall through to the ValueError below.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # IntEnum subclasses Enum, so this single check covers both.
    if isinstance(obj, Enum):
        return obj.value
    raise ValueError(f"invalid object type: {type(obj)}")


class FileAdapter(LoggerAdapter):
    r"""FileAdapter class.

    This class saves metrics as CSV files, hyperparameters as json file and
    models as d3 files.

    Args:
        logdir (str): Log directory.
    """

    _logdir: str

    def __init__(self, logdir: str):
        self._logdir = logdir
        # Attempt creation directly instead of check-then-create. With an
        # ``exists()`` check followed by ``makedirs()``, another process
        # creating the same directory in between would make ``makedirs``
        # raise. If the path already exists (as before), nothing happens and
        # nothing is logged.
        try:
            os.makedirs(self._logdir)
        except FileExistsError:
            pass
        else:
            LOG.info(f"Directory is created at {self._logdir}")

    def write_params(self, params: Dict[str, Any]) -> None:
        """Save ``params`` as indented JSON to ``<logdir>/params.json``.

        Any existing ``params.json`` is overwritten. Numpy scalars/arrays and
        enum members are converted by ``default_json_encoder``.

        Args:
            params: Hyperparameters to record.

        Raises:
            ValueError: If ``params`` contains a value ``default_json_encoder``
                cannot convert. In that case the existing file is left intact.
        """
        params_path = os.path.join(self._logdir, "params.json")
        # Serialize before opening: opening with "w" truncates the file, so an
        # encoding failure would otherwise wipe a previously saved params.json.
        json_str = json.dumps(params, default=default_json_encoder, indent=2)
        with open(params_path, "w", encoding="utf-8") as f:
            f.write(json_str)

    def before_write_metric(self, epoch: int, step: int) -> None:
        """No-op. Each metric is written immediately in ``write_metric``."""

    def write_metric(
        self, epoch: int, step: int, name: str, value: float
    ) -> None:
        """Append one ``epoch,step,value`` row to ``<logdir>/<name>.csv``.

        The file is reopened in append mode for every call, so no handle is
        held between writes.
        """
        path = os.path.join(self._logdir, f"{name}.csv")
        with open(path, "a", encoding="utf-8") as f:
            print(f"{epoch},{step},{value}", file=f)

    def after_write_metric(self, epoch: int, step: int) -> None:
        """No-op. Nothing is buffered between metric writes."""

    def save_model(self, epoch: int, algo: SaveProtocol) -> None:
        """Save the entire model to ``<logdir>/model_<epoch>.d3``.

        Args:
            epoch: Epoch number, used in the file name.
            algo: Object whose ``save(path)`` method writes the model.
        """
        model_path = os.path.join(self._logdir, f"model_{epoch}.d3")
        algo.save(model_path)
        LOG.info(f"Model parameters are saved to {model_path}")

    def close(self) -> None:
        """No-op. Every write opens and closes its own file."""

    @property
    def logdir(self) -> str:
        """Directory where params, metric CSVs and models are written."""
        return self._logdir
class FileAdapterFactory(LoggerAdapterFactory):
    r"""FileAdapterFactory class.

    Factory that builds a ``FileAdapter`` per experiment. Each adapter logs
    into the directory ``<root_dir>/<experiment_name>``.

    Args:
        root_dir (str): Top-level log directory.
    """

    _root_dir: str

    def __init__(self, root_dir: str = "d3rlpy_logs"):
        self._root_dir = root_dir

    def create(self, experiment_name: str) -> FileAdapter:
        """Return a ``FileAdapter`` whose log directory is
        ``<root_dir>/<experiment_name>``.

        Args:
            experiment_name: Name of the experiment sub-directory.
        """
        return FileAdapter(os.path.join(self._root_dir, experiment_name))