Data Collection¶
d3rlpy provides APIs to support data collection from environments. This feature is especially useful if you want to build your own datasets for research or practice purposes.
Prepare Environment¶
d3rlpy supports environments with the OpenAI Gym interface. In this tutorial, let’s use the simple CartPole environment.
import gym
env = gym.make("CartPole-v1")
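Recent versions of d3rlpy also accept environments that follow the newer Gymnasium API. A minimal sketch, assuming a d3rlpy version with Gymnasium support:

import gymnasium

# Gymnasium environments can be passed to the same collection APIs
env = gymnasium.make("CartPole-v1")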
Data Collection with Random Policy¶
If you want to collect experiences with a uniformly random policy, you can use RandomPolicy (for continuous action spaces) or DiscreteRandomPolicy (for discrete action spaces). This procedure corresponds to the random datasets in D4RL.
import d3rlpy
# setup algorithm
random_policy = d3rlpy.algos.DiscreteRandomPolicyConfig().create()
# prepare experience replay buffer
buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=100000, env=env)
# start data collection
random_policy.collect(env, buffer, n_steps=100000)
# save ReplayBuffer
with open("random_policy_dataset.h5", "w+b") as f:
buffer.dump(f)
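The saved file can be loaded back later for offline training. A minimal sketch, assuming d3rlpy v2’s ReplayBuffer.load classmethod, with InfiniteBuffer keeping every transition in memory:

# load the saved dataset back into a ReplayBuffer
with open("random_policy_dataset.h5", "rb") as f:
    dataset = d3rlpy.dataset.ReplayBuffer.load(f, d3rlpy.dataset.InfiniteBuffer())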
Data Collection with Trained Policy¶
If you want to collect experiences with a previously trained policy, you can still use the same set of APIs. Here, let’s say a DQN model has been saved as dqn_model.d3. This procedure corresponds to the medium datasets in D4RL.
# prepare pretrained algorithm
dqn = d3rlpy.load_learnable("dqn_model.d3")
# prepare experience replay buffer
buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=100000, env=env)
# start data collection
dqn.collect(env, buffer, n_steps=100000)
# save ReplayBuffer
with open("trained_policy_dataset.h5", "w+b") as f:
buffer.dump(f)
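If you do not have a pretrained model yet, one way to produce dqn_model.d3 is to train DQN online and save it. A minimal sketch, where the buffer size, exploration rate, and step count are illustrative choices rather than recommendations:

# train DQN online with its own replay buffer
dqn = d3rlpy.algos.DQNConfig().create()
train_buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=100000, env=env)
explorer = d3rlpy.algos.ConstantEpsilonGreedy(0.1)
dqn.fit_online(env, train_buffer, explorer, n_steps=100000)

# export the trained model so it can be reloaded with d3rlpy.load_learnable
dqn.save("dqn_model.d3")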
Data Collection while Training Policy¶
If you want to use the experiences collected during training to build a new dataset, you can simply use fit_online and save the replay buffer afterwards. This procedure corresponds to the replay datasets in D4RL.
# setup algorithm
dqn = d3rlpy.algos.DQNConfig().create()
# prepare experience replay buffer
buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=100000, env=env)
# prepare exploration strategy if necessary
explorer = d3rlpy.algos.ConstantEpsilonGreedy(0.3)
# start data collection
dqn.fit_online(env, buffer, explorer, n_steps=100000)
# save ReplayBuffer
with open("replay_dataset.h5", "w+b") as f:
buffer.dump(f)
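Whichever way the dataset was collected, it can be fed to an offline algorithm afterwards. A minimal sketch, assuming d3rlpy v2’s ReplayBuffer.load classmethod and using DiscreteCQL as an example offline learner (the step count is arbitrary):

# load the collected dataset back into memory
with open("replay_dataset.h5", "rb") as f:
    dataset = d3rlpy.dataset.ReplayBuffer.load(f, d3rlpy.dataset.InfiniteBuffer())

# train an offline RL algorithm on the collected experiences
cql = d3rlpy.algos.DiscreteCQLConfig().create()
cql.fit(dataset, n_steps=10000)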