Source code for d3rlpy.logging.wandb_adapter

from typing import Any, Optional

from .logger import (
    AlgProtocol,
    LoggerAdapter,
    LoggerAdapterFactory,
    SaveProtocol,
)

# Public names exported by this module (used by ``from ... import *``).
__all__ = ["WanDBAdapter", "WanDBAdapterFactory"]


class WanDBAdapter(LoggerAdapter):
    r"""WandB Logger Adapter class.

    This class logs data to Weights & Biases (WandB) for experiment tracking.

    Args:
        algo: Algorithm.
        experiment_name (str): Name of the experiment.
        n_steps_per_epoch: Number of steps per epoch.
        project: Project name.

    Raises:
        ImportError: If the ``wandb`` package cannot be imported.
        AssertionError: If ``algo.impl`` is not set.
    """

    def __init__(
        self,
        algo: AlgProtocol,
        experiment_name: str,
        n_steps_per_epoch: int,
        project: Optional[str] = None,
    ):
        """Starts a WandB run and registers the algorithm's modules with it."""
        # Imported here rather than at module level, so this module can be
        # imported even when wandb is not installed.
        try:
            import wandb
        except ImportError as e:
            raise ImportError("Please install wandb") from e

        # Explicit check instead of a bare ``assert`` so the precondition is
        # still enforced under ``python -O``. AssertionError is kept so the
        # exception type seen by callers does not change.
        if not algo.impl:
            raise AssertionError(
                "algo.impl is required to watch the model with wandb"
            )

        self.run = wandb.init(project=project, name=experiment_name)

        # Ask WandB to record gradients of every torch module of the
        # algorithm, once every ``n_steps_per_epoch`` steps.
        self.run.watch(
            tuple(algo.impl.modules.get_torch_modules().values()),
            log="gradients",
            log_freq=n_steps_per_epoch,
        )

        # NOTE(review): this flag is not read anywhere in this class and is
        # False even though ``watch`` was just called -- confirm intent.
        self._is_model_watched = False

    def write_params(self, params: dict[str, Any]) -> None:
        """Writes hyperparameters to WandB config.

        Args:
            params: Hyperparameters merged into the run's config.
        """
        self.run.config.update(params)

    def before_write_metric(self, epoch: int, step: int) -> None:
        """Callback executed before writing metric.

        This adapter does nothing here.
        """

    def write_metric(
        self, epoch: int, step: int, name: str, value: float
    ) -> None:
        """Writes metric to WandB.

        Args:
            epoch: Current epoch, logged next to the metric under ``"epoch"``.
            step: Step passed to WandB as the logging step.
            name: Metric name used as the logging key.
            value: Metric value.
        """
        self.run.log({name: value, "epoch": epoch}, step=step)

    def after_write_metric(self, epoch: int, step: int) -> None:
        """Callback executed after writing metric.

        This adapter does nothing here.
        """

    def save_model(self, epoch: int, algo: SaveProtocol) -> None:
        """Saves models to Weights & Biases.

        Not implemented for WandB.
        """
        # Implement saving model to wandb if needed

    def close(self) -> None:
        """Closes the logger and finishes the WandB run."""
        self.run.finish()

    def watch_model(
        self,
        epoch: int,
        step: int,
    ) -> None:
        """Does nothing; ``run.watch`` is already called in ``__init__``."""
        pass
class WanDBAdapterFactory(LoggerAdapterFactory):
    r"""Factory that builds :class:`WanDBAdapter` instances.

    The project name given at construction time is forwarded to every
    adapter this factory creates.

    Args:
        project (Optional[str], optional): The name of the WandB project.
            Defaults to None.
    """

    _project: Optional[str]

    def __init__(self, project: Optional[str] = None) -> None:
        """Stores the WandB project name for later use by ``create``."""
        self._project = project

    def create(
        self,
        algo: AlgProtocol,
        experiment_name: str,
        n_steps_per_epoch: int,
    ) -> LoggerAdapter:
        """Creates a WandB adapter for the given algorithm and experiment.

        Args:
            algo: Algorithm.
            experiment_name: Name of the experiment.
            n_steps_per_epoch: Number of steps per epoch.

        Returns:
            A new ``WanDBAdapter`` bound to this factory's project.
        """
        adapter = WanDBAdapter(
            algo,
            experiment_name,
            n_steps_per_epoch,
            project=self._project,
        )
        return adapter