from abc import abstractmethod
from collections import defaultdict
from typing import (
Callable,
Dict,
Generator,
Generic,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)
import numpy as np
import torch
from torch import nn
from tqdm.auto import tqdm, trange
from typing_extensions import Self
from ...base import ImplBase, LearnableBase, LearnableConfig, save_config
from ...constants import (
IMPL_NOT_INITIALIZED_ERROR,
ActionSpace,
LoggingStrategy,
)
from ...dataset import (
ReplayBufferBase,
TransitionMiniBatch,
check_non_1d_array,
create_fifo_replay_buffer,
is_tuple_shape,
)
from ...logging import (
LOG,
D3RLPyLogger,
FileAdapterFactory,
LoggerAdapterFactory,
)
from ...metrics import EvaluatorProtocol, evaluate_qlearning_with_environment
from ...models.torch import Policy
from ...torch_utility import (
TorchMiniBatch,
convert_to_torch,
convert_to_torch_recursively,
eval_api,
hard_sync,
sync_optimizer_state,
train_api,
)
from ...types import GymEnv, NDArray, Observation, TorchObservation
from ..utility import (
assert_action_space_with_dataset,
assert_action_space_with_env,
build_scalers_with_env,
build_scalers_with_transition_picker,
)
from .explorers import Explorer
__all__ = [
"QLearningAlgoImplBase",
"QLearningAlgoBase",
"TQLearningImpl",
"TQLearningConfig",
]
class QLearningAlgoImplBase(ImplBase):
@train_api
def update(self, batch: TorchMiniBatch, grad_step: int) -> Dict[str, float]:
return self.inner_update(batch, grad_step)
@abstractmethod
def inner_update(
self, batch: TorchMiniBatch, grad_step: int
) -> Dict[str, float]:
pass
@eval_api
def predict_best_action(self, x: TorchObservation) -> torch.Tensor:
return self.inner_predict_best_action(x)
@abstractmethod
def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
pass
@eval_api
def sample_action(self, x: TorchObservation) -> torch.Tensor:
return self.inner_sample_action(x)
@abstractmethod
def inner_sample_action(self, x: TorchObservation) -> torch.Tensor:
pass
@eval_api
def predict_value(
self, x: TorchObservation, action: torch.Tensor
) -> torch.Tensor:
return self.inner_predict_value(x, action)
@abstractmethod
def inner_predict_value(
self, x: TorchObservation, action: torch.Tensor
) -> torch.Tensor:
pass
@property
def policy(self) -> Policy:
raise NotImplementedError
def copy_policy_from(self, impl: "QLearningAlgoImplBase") -> None:
if not isinstance(impl.policy, type(self.policy)):
raise ValueError(
f"Invalid policy type: expected={type(self.policy)},"
f"actual={type(impl.policy)}"
)
hard_sync(self.policy, impl.policy)
@property
def policy_optim(self) -> torch.optim.Optimizer:
raise NotImplementedError
def copy_policy_optim_from(self, impl: "QLearningAlgoImplBase") -> None:
if not isinstance(impl.policy_optim, type(self.policy_optim)):
raise ValueError(
"Invalid policy optimizer type: "
f"expected={type(self.policy_optim)},"
f"actual={type(impl.policy_optim)}"
)
sync_optimizer_state(self.policy_optim, impl.policy_optim)
@property
def q_function(self) -> nn.ModuleList:
raise NotImplementedError
def copy_q_function_from(self, impl: "QLearningAlgoImplBase") -> None:
q_func = self.q_function[0]
if not isinstance(impl.q_function[0], type(q_func)):
raise ValueError(
f"Invalid Q-function type: expected={type(q_func)},"
f"actual={type(impl.q_function[0])}"
)
hard_sync(self.q_function, impl.q_function)
@property
def q_function_optim(self) -> torch.optim.Optimizer:
raise NotImplementedError
def copy_q_function_optim_from(self, impl: "QLearningAlgoImplBase") -> None:
if not isinstance(impl.q_function_optim, type(self.q_function_optim)):
raise ValueError(
"Invalid Q-function optimizer type: "
f"expected={type(self.q_function_optim)}",
f"actual={type(impl.q_function_optim)}",
)
sync_optimizer_state(self.q_function_optim, impl.q_function_optim)
def reset_optimizer_states(self) -> None:
self.modules.reset_optimizer_states()
TQLearningImpl = TypeVar("TQLearningImpl", bound=QLearningAlgoImplBase)
TQLearningConfig = TypeVar("TQLearningConfig", bound=LearnableConfig)
class QLearningAlgoBase(
Generic[TQLearningImpl, TQLearningConfig],
LearnableBase[TQLearningImpl, TQLearningConfig],
):
    def save_policy(self, fname: str) -> None:
"""Save the greedy-policy computational graph as TorchScript or ONNX.
The format will be automatically detected by the file name.
.. code-block:: python
# save as TorchScript
algo.save_policy('policy.pt')
# save as ONNX
algo.save_policy('policy.onnx')
        The artifacts saved with this method will work without d3rlpy.
        This method is especially useful for deploying the learned policy to
        production environments or embedded systems.
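        As a minimal sketch, the exported TorchScript policy can be loaded with
        plain PyTorch (the observation shape below is a placeholder):
        .. code-block:: python
            import torch
            # load the exported policy without importing d3rlpy
            policy = torch.jit.load('policy.pt')
            # the traced graph maps an observation batch to greedy actions
            with torch.no_grad():
                action = policy(torch.rand(1, 10))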
See also
* https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html (for Python).
* https://pytorch.org/tutorials/advanced/cpp_export.html (for C++).
* https://onnx.ai (for ONNX)
        Visit https://d3rlpy.readthedocs.io/en/stable/tutorials/after_training_policies.html#export-policies-as-torchscript for further usage.
Args:
fname: Destination file path.
"""
assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
if is_tuple_shape(self._impl.observation_shape):
dummy_x = [
torch.rand(1, *shape, device=self._device)
for shape in self._impl.observation_shape
]
num_inputs = len(self._impl.observation_shape)
else:
dummy_x = torch.rand(
1, *self._impl.observation_shape, device=self._device
)
num_inputs = 1
# workaround until version 1.6
self._impl.modules.freeze()
# local function to select best actions
def _func(*x: Sequence[torch.Tensor]) -> torch.Tensor:
assert self._impl
observation: TorchObservation = x
if len(observation) == 1:
observation = observation[0]
if self._config.observation_scaler:
observation = self._config.observation_scaler.transform(
observation
)
action = self._impl.predict_best_action(observation)
if self._config.action_scaler:
action = self._config.action_scaler.reverse_transform(action)
return action
traced_script = torch.jit.trace(_func, dummy_x, check_trace=False)
if fname.endswith(".onnx"):
            # currently, PyTorch cannot directly export a function as ONNX.
torch.onnx.export(
traced_script,
dummy_x,
fname,
export_params=True,
opset_version=11,
input_names=[f"input_{i}" for i in range(num_inputs)],
output_names=["output_0"],
)
elif fname.endswith(".pt"):
traced_script.save(fname)
else:
raise ValueError(
f"invalid format type: {fname}."
" .pt and .onnx extensions are currently supported."
)
# workaround until version 1.6
self._impl.modules.unfreeze()
    def predict(self, x: Observation) -> NDArray:
"""Returns greedy actions.
.. code-block:: python
# 100 observations with shape of (10,)
x = np.random.random((100, 10))
actions = algo.predict(x)
# actions.shape == (100, action size) for continuous control
# actions.shape == (100,) for discrete control
Args:
x: Observations
Returns:
Greedy actions
"""
assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
assert check_non_1d_array(x), "Input must have batch dimension."
torch_x = convert_to_torch_recursively(x, self._device)
with torch.no_grad():
if self._config.observation_scaler:
torch_x = self._config.observation_scaler.transform(torch_x)
action = self._impl.predict_best_action(torch_x)
if self._config.action_scaler:
action = self._config.action_scaler.reverse_transform(action)
return action.cpu().detach().numpy() # type: ignore
    def predict_value(self, x: Observation, action: NDArray) -> NDArray:
"""Returns predicted action-values.
.. code-block:: python
# 100 observations with shape of (10,)
x = np.random.random((100, 10))
# for continuous control
# 100 actions with shape of (2,)
actions = np.random.random((100, 2))
# for discrete control
# 100 actions in integer values
actions = np.random.randint(2, size=100)
values = algo.predict_value(x, actions)
# values.shape == (100,)
Args:
x: Observations
action: Actions
Returns:
Predicted action-values
"""
assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
assert check_non_1d_array(x), "Input must have batch dimension."
torch_x = convert_to_torch_recursively(x, self._device)
torch_action = convert_to_torch(action, self._device)
with torch.no_grad():
if self._config.observation_scaler:
torch_x = self._config.observation_scaler.transform(torch_x)
if self.get_action_type() == ActionSpace.CONTINUOUS:
if self._config.action_scaler:
torch_action = self._config.action_scaler.transform(
torch_action
)
elif self.get_action_type() == ActionSpace.DISCRETE:
torch_action = torch_action.long()
else:
raise ValueError("invalid action type")
value = self._impl.predict_value(torch_x, torch_action)
return value.cpu().detach().numpy() # type: ignore
    def sample_action(self, x: Observation) -> NDArray:
"""Returns sampled actions.
        The sampled actions are identical to the output of the ``predict``
        method if the policy is deterministic.
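        A minimal usage sketch (the observation shape is a placeholder):
        .. code-block:: python
            # 100 observations with shape of (10,)
            x = np.random.random((100, 10))
            # sample stochastic actions; shapes follow the ``predict`` method
            actions = algo.sample_action(x)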
Args:
x: Observations.
Returns:
Sampled actions.
"""
assert self._impl is not None, IMPL_NOT_INITIALIZED_ERROR
assert check_non_1d_array(x), "Input must have batch dimension."
torch_x = convert_to_torch_recursively(x, self._device)
with torch.no_grad():
if self._config.observation_scaler:
torch_x = self._config.observation_scaler.transform(torch_x)
action = self._impl.sample_action(torch_x)
# transform action back to the original range
if self._config.action_scaler:
action = self._config.action_scaler.reverse_transform(action)
return action.cpu().detach().numpy() # type: ignore
    def fit(
self,
dataset: ReplayBufferBase,
n_steps: int,
n_steps_per_epoch: int = 10000,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logging_steps: int = 500,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
save_interval: int = 1,
evaluators: Optional[Dict[str, EvaluatorProtocol]] = None,
callback: Optional[Callable[[Self, int, int], None]] = None,
epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
enable_ddp: bool = False,
) -> List[Tuple[int, Dict[str, float]]]:
"""Trains with given dataset.
.. code-block:: python
            algo.fit(dataset, n_steps=1000000)
Args:
dataset: ReplayBuffer object.
n_steps: Number of steps to train.
n_steps_per_epoch: Number of steps per epoch. This value will
be ignored when ``n_steps`` is ``None``.
experiment_name: Experiment name for logging. If not passed,
the directory name will be `{class name}_{timestamp}`.
with_timestamp: Flag to add timestamp string to the last of
directory name.
logging_steps: Number of steps to log metrics. This will be ignored
if logging_strategy is EPOCH.
logging_strategy: Logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
save_interval: Interval to save parameters.
            evaluators: Dictionary of metric names and evaluators.
            callback: Callable function that takes ``(algo, epoch, total_step)``,
                which is called every step.
epoch_callback: Callable function that takes
``(algo, epoch, total_step)``, which is called at the end of
every epoch.
            enable_ddp: Flag to wrap models with DistributedDataParallel.
Returns:
List of result tuples (epoch, metrics) per epoch.
"""
results = list(
self.fitter(
dataset=dataset,
n_steps=n_steps,
n_steps_per_epoch=n_steps_per_epoch,
experiment_name=experiment_name,
with_timestamp=with_timestamp,
logging_steps=logging_steps,
logging_strategy=logging_strategy,
logger_adapter=logger_adapter,
show_progress=show_progress,
save_interval=save_interval,
evaluators=evaluators,
callback=callback,
epoch_callback=epoch_callback,
enable_ddp=enable_ddp,
)
)
return results
    def fitter(
self,
dataset: ReplayBufferBase,
n_steps: int,
n_steps_per_epoch: int = 10000,
logging_steps: int = 500,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
save_interval: int = 1,
evaluators: Optional[Dict[str, EvaluatorProtocol]] = None,
callback: Optional[Callable[[Self, int, int], None]] = None,
epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
enable_ddp: bool = False,
) -> Generator[Tuple[int, Dict[str, float]], None, None]:
"""Iterate over epochs steps to train with the given dataset. At each
iteration algo methods and properties can be changed or queried.
.. code-block:: python
            for epoch, metrics in algo.fitter(dataset, n_steps=1000000):
my_plot(metrics)
algo.save_model(my_path)
Args:
dataset: Offline dataset to train.
n_steps: Number of steps to train.
n_steps_per_epoch: Number of steps per epoch. This value will
be ignored when ``n_steps`` is ``None``.
experiment_name: Experiment name for logging. If not passed,
the directory name will be `{class name}_{timestamp}`.
with_timestamp: Flag to add timestamp string to the last of
directory name.
logging_steps: Number of steps to log metrics. This will be ignored
                if logging_strategy is EPOCH.
logging_strategy: Logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
save_interval: Interval to save parameters.
            evaluators: Dictionary of metric names and evaluators.
            callback: Callable function that takes ``(algo, epoch, total_step)``,
                which is called every step.
epoch_callback: Callable function that takes
``(algo, epoch, total_step)``, which is called at the end of
every epoch.
            enable_ddp: Flag to wrap models with DistributedDataParallel.
Returns:
Iterator yielding current epoch and metrics dict.
"""
LOG.info("dataset info", dataset_info=dataset.dataset_info)
# check action space
assert_action_space_with_dataset(self, dataset.dataset_info)
# initialize scalers
build_scalers_with_transition_picker(self, dataset)
# setup logger
if experiment_name is None:
experiment_name = self.__class__.__name__
logger = D3RLPyLogger(
adapter_factory=logger_adapter,
experiment_name=experiment_name,
with_timestamp=with_timestamp,
)
# instantiate implementation
if self._impl is None:
LOG.debug("Building models...")
action_size = dataset.dataset_info.action_size
observation_shape = (
dataset.sample_transition().observation_signature.shape
)
if len(observation_shape) == 1:
observation_shape = observation_shape[0] # type: ignore
self.create_impl(observation_shape, action_size)
LOG.debug("Models have been built.")
else:
LOG.warning("Skip building models since they're already built.")
# wrap all PyTorch modules with DataDistributedParallel
if enable_ddp:
assert self._impl
self._impl.wrap_models_by_ddp()
# save hyperparameters
save_config(self, logger)
# training loop
n_epochs = n_steps // n_steps_per_epoch
total_step = 0
for epoch in range(1, n_epochs + 1):
# dict to add incremental mean losses to epoch
epoch_loss = defaultdict(list)
range_gen = tqdm(
range(n_steps_per_epoch),
disable=not show_progress,
desc=f"Epoch {int(epoch)}/{n_epochs}",
)
for itr in range_gen:
with logger.measure_time("step"):
# pick transitions
with logger.measure_time("sample_batch"):
batch = dataset.sample_transition_batch(
self._config.batch_size
)
# update parameters
with logger.measure_time("algorithm_update"):
loss = self.update(batch)
# record metrics
for name, val in loss.items():
logger.add_metric(name, val)
epoch_loss[name].append(val)
# update progress postfix with losses
if itr % 10 == 0:
mean_loss = {
k: np.mean(v) for k, v in epoch_loss.items()
}
range_gen.set_postfix(mean_loss)
total_step += 1
if (
logging_strategy == LoggingStrategy.STEPS
and total_step % logging_steps == 0
):
metrics = logger.commit(epoch, total_step)
# call callback if given
if callback:
callback(self, epoch, total_step)
# call epoch_callback if given
if epoch_callback:
epoch_callback(self, epoch, total_step)
if evaluators:
for name, evaluator in evaluators.items():
test_score = evaluator(self, dataset)
logger.add_metric(name, test_score)
# save metrics
if logging_strategy == LoggingStrategy.EPOCH:
metrics = logger.commit(epoch, total_step)
# save model parameters
if epoch % save_interval == 0:
logger.save_model(total_step, self)
yield epoch, metrics
logger.close()
    def fit_online(
self,
env: GymEnv,
buffer: Optional[ReplayBufferBase] = None,
explorer: Optional[Explorer] = None,
n_steps: int = 1000000,
n_steps_per_epoch: int = 10000,
update_interval: int = 1,
n_updates: int = 1,
update_start_step: int = 0,
random_steps: int = 0,
eval_env: Optional[GymEnv] = None,
eval_epsilon: float = 0.0,
save_interval: int = 1,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logging_steps: int = 500,
logging_strategy: LoggingStrategy = LoggingStrategy.EPOCH,
logger_adapter: LoggerAdapterFactory = FileAdapterFactory(),
show_progress: bool = True,
callback: Optional[Callable[[Self, int, int], None]] = None,
) -> None:
"""Start training loop of online deep reinforcement learning.
Args:
env: Gym-like environment.
            buffer: Replay buffer.
explorer: Action explorer.
n_steps: Number of total steps to train.
n_steps_per_epoch: Number of steps per epoch.
update_interval: Number of steps per update.
n_updates: Number of gradient steps at a time. The combination of
``update_interval`` and ``n_updates`` controls Update-To-Data
(UTD) ratio.
update_start_step: Steps before starting updates.
            random_steps: Steps for the initial random exploration.
eval_env: Gym-like environment. If None, evaluation is skipped.
eval_epsilon: :math:`\\epsilon`-greedy factor during evaluation.
save_interval: Number of epochs before saving models.
experiment_name: Experiment name for logging. If not passed,
the directory name will be ``{class name}_online_{timestamp}``.
with_timestamp: Flag to add timestamp string to the last of
directory name.
logging_steps: Number of steps to log metrics. This will be ignored
if logging_strategy is EPOCH.
logging_strategy: Logging strategy to use.
logger_adapter: LoggerAdapterFactory object.
show_progress: Flag to show progress bar for iterations.
            callback: Callable function that takes ``(algo, epoch, total_step)``,
                which is called every step.
"""
# create default replay buffer
if buffer is None:
buffer = create_fifo_replay_buffer(1000000)
# check action-space
assert_action_space_with_env(self, env)
# setup logger
if experiment_name is None:
experiment_name = self.__class__.__name__ + "_online"
logger = D3RLPyLogger(
adapter_factory=logger_adapter,
experiment_name=experiment_name,
with_timestamp=with_timestamp,
)
# initialize algorithm parameters
build_scalers_with_env(self, env)
# setup algorithm
if self.impl is None:
LOG.debug("Building model...")
self.build_with_env(env)
LOG.debug("Model has been built.")
else:
LOG.warning("Skip building models since they're already built.")
# save hyperparameters
save_config(self, logger)
# switch based on show_progress flag
xrange = trange if show_progress else range
# start training loop
observation, _ = env.reset()
rollout_return = 0.0
for total_step in xrange(1, n_steps + 1):
with logger.measure_time("step"):
# sample exploration action
with logger.measure_time("inference"):
if total_step < random_steps:
action = env.action_space.sample()
elif explorer:
x = observation.reshape((1,) + observation.shape)
action = explorer.sample(self, x, total_step)[0]
else:
action = self.sample_action(
np.expand_dims(observation, axis=0)
)[0]
# step environment
with logger.measure_time("environment_step"):
(
next_observation,
reward,
terminal,
truncated,
_,
) = env.step(action)
rollout_return += float(reward)
clip_episode = terminal or truncated
# store observation
buffer.append(observation, action, float(reward))
# reset if terminated
if clip_episode:
buffer.clip_episode(terminal)
observation, _ = env.reset()
logger.add_metric("rollout_return", rollout_return)
rollout_return = 0.0
else:
observation = next_observation
                # pseudo epoch count
epoch = total_step // n_steps_per_epoch
if (
total_step > update_start_step
and buffer.transition_count > self.batch_size
):
if total_step % update_interval == 0:
for _ in range(n_updates): # controls UTD ratio
# sample mini-batch
with logger.measure_time("sample_batch"):
batch = buffer.sample_transition_batch(
self.batch_size
)
# update parameters
with logger.measure_time("algorithm_update"):
loss = self.update(batch)
# record metrics
for name, val in loss.items():
logger.add_metric(name, val)
if (
logging_strategy == LoggingStrategy.STEPS
and total_step % logging_steps == 0
):
logger.commit(epoch, total_step)
# call callback if given
if callback:
callback(self, epoch, total_step)
if epoch > 0 and total_step % n_steps_per_epoch == 0:
# evaluation
if eval_env:
eval_score = evaluate_qlearning_with_environment(
self, eval_env, epsilon=eval_epsilon
)
logger.add_metric("evaluation", eval_score)
if epoch % save_interval == 0:
logger.save_model(total_step, self)
# save metrics
if logging_strategy == LoggingStrategy.EPOCH:
logger.commit(epoch, total_step)
# clip the last episode
buffer.clip_episode(False)
# close logger
logger.close()
    def collect(
self,
env: GymEnv,
buffer: Optional[ReplayBufferBase] = None,
explorer: Optional[Explorer] = None,
deterministic: bool = False,
n_steps: int = 1000000,
show_progress: bool = True,
) -> ReplayBufferBase:
"""Collects data via interaction with environment.
If ``buffer`` is not given, ``ReplayBuffer`` will be internally created.
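        A minimal usage sketch (the environment below is a placeholder):
        .. code-block:: python
            import gym
            env = gym.make("CartPole-v1")
            # gather 10000 environment steps into a replay buffer
            buffer = algo.collect(env, n_steps=10000)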
Args:
            env: Gym-like environment.
buffer: Replay buffer.
explorer: Action explorer.
deterministic: Flag to collect data with the greedy policy.
n_steps: Number of total steps to train.
show_progress: Flag to show progress bar for iterations.
Returns:
Replay buffer with the collected data.
"""
# create default replay buffer
if buffer is None:
buffer = create_fifo_replay_buffer(1000000, env=env)
# check action-space
assert_action_space_with_env(self, env)
# initialize algorithm parameters
build_scalers_with_env(self, env)
# setup algorithm
if self.impl is None:
LOG.debug("Building model...")
self.build_with_env(env)
LOG.debug("Model has been built.")
else:
LOG.warning("Skip building models since they're already built.")
# switch based on show_progress flag
xrange = trange if show_progress else range
# start training loop
observation, _ = env.reset()
for total_step in xrange(1, n_steps + 1):
# sample exploration action
if deterministic:
action = self.predict(np.expand_dims(observation, axis=0))[0]
else:
if explorer:
x = observation.reshape((1,) + observation.shape)
action = explorer.sample(self, x, total_step)[0]
else:
action = self.sample_action(
np.expand_dims(observation, axis=0)
)[0]
# step environment
next_observation, reward, terminal, truncated, _ = env.step(action)
clip_episode = terminal or truncated
# store observation
buffer.append(observation, action, float(reward))
# reset if terminated
if clip_episode:
buffer.clip_episode(terminal)
observation, _ = env.reset()
else:
observation = next_observation
# clip the last episode
buffer.clip_episode(False)
return buffer
    def update(self, batch: TransitionMiniBatch) -> Dict[str, float]:
"""Update parameters with mini-batch of data.
Args:
batch: Mini-batch data.
Returns:
Dictionary of metrics.
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
torch_batch = TorchMiniBatch.from_batch(
batch=batch,
gamma=self._config.gamma,
compute_returns_to_go=self.need_returns_to_go,
device=self._device,
observation_scaler=self._config.observation_scaler,
action_scaler=self._config.action_scaler,
reward_scaler=self._config.reward_scaler,
)
loss = self._impl.update(torch_batch, self._grad_step)
self._grad_step += 1
return loss
@property
def need_returns_to_go(self) -> bool:
return False
    def copy_policy_from(
self, algo: "QLearningAlgoBase[QLearningAlgoImplBase, LearnableConfig]"
) -> None:
"""Copies policy parameters from the given algorithm.
.. code-block:: python
# pretrain with static dataset
cql = d3rlpy.algos.CQL()
cql.fit(dataset, n_steps=100000)
# transfer to online algorithm
sac = d3rlpy.algos.SAC()
sac.create_impl(cql.observation_shape, cql.action_size)
sac.copy_policy_from(cql)
Args:
algo: Algorithm object.
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
assert isinstance(algo.impl, QLearningAlgoImplBase)
self._impl.copy_policy_from(algo.impl)
    def copy_policy_optim_from(
self, algo: "QLearningAlgoBase[QLearningAlgoImplBase, LearnableConfig]"
) -> None:
"""Copies policy optimizer states from the given algorithm.
.. code-block:: python
# pretrain with static dataset
cql = d3rlpy.algos.CQL()
cql.fit(dataset, n_steps=100000)
# transfer to online algorithm
sac = d3rlpy.algos.SAC()
sac.create_impl(cql.observation_shape, cql.action_size)
sac.copy_policy_optim_from(cql)
Args:
algo: Algorithm object.
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
assert isinstance(algo.impl, QLearningAlgoImplBase)
self._impl.copy_policy_optim_from(algo.impl)
    def copy_q_function_from(
self, algo: "QLearningAlgoBase[QLearningAlgoImplBase, LearnableConfig]"
) -> None:
"""Copies Q-function parameters from the given algorithm.
.. code-block:: python
# pretrain with static dataset
cql = d3rlpy.algos.CQL()
cql.fit(dataset, n_steps=100000)
            # transfer to online algorithm
sac = d3rlpy.algos.SAC()
sac.create_impl(cql.observation_shape, cql.action_size)
sac.copy_q_function_from(cql)
Args:
algo: Algorithm object.
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
assert isinstance(algo.impl, QLearningAlgoImplBase)
self._impl.copy_q_function_from(algo.impl)
    def copy_q_function_optim_from(
self, algo: "QLearningAlgoBase[QLearningAlgoImplBase, LearnableConfig]"
) -> None:
"""Copies Q-function optimizer states from the given algorithm.
.. code-block:: python
# pretrain with static dataset
cql = d3rlpy.algos.CQL()
cql.fit(dataset, n_steps=100000)
# transfer to online algorithm
sac = d3rlpy.algos.SAC()
sac.create_impl(cql.observation_shape, cql.action_size)
            sac.copy_q_function_optim_from(cql)
Args:
algo: Algorithm object.
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
assert isinstance(algo.impl, QLearningAlgoImplBase)
self._impl.copy_q_function_optim_from(algo.impl)
    def reset_optimizer_states(self) -> None:
"""Resets optimizer states.
        This is especially useful when fine-tuning policies with freshly
        initialized optimizer states.
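        A minimal usage sketch (``dataset`` and ``env`` are placeholders):
        .. code-block:: python
            # pretrain with a static dataset
            algo.fit(dataset, n_steps=100000)
            # reset optimizer states before online fine-tuning
            algo.reset_optimizer_states()
            algo.fit_online(env, n_steps=100000)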
"""
assert self._impl, IMPL_NOT_INITIALIZED_ERROR
self._impl.reset_optimizer_states()