import os
import pickle
import urllib
import urllib.request

import gym
import numpy as np

from .dataset import MDPDataset
# Cache directory for downloaded datasets. It is a relative path, so it is
# resolved against the current working directory at call time.
DATA_DIRECTORY = "d3rlpy_data"

# Remote locations of the pickled datasets fetched by get_cartpole() and
# get_pendulum() when no cached copy exists under DATA_DIRECTORY.
CARTPOLE_URL = "https://www.dropbox.com/s/2tmo7ul00268l3e/cartpole.pkl?dl=1"
PENDULUM_URL = "https://www.dropbox.com/s/90z7a84ngndrqt4/pendulum.pkl?dl=1"
def get_cartpole():
    """ Returns cartpole dataset and environment.

    The dataset is automatically downloaded to `d3rlpy_data/cartpole.pkl` if it
    does not exist. The pickle is expected to hold a 4-tuple of
    ``(observations, actions, rewards, terminals)``; observations are cast to
    ``float32`` and the dataset is built with discrete actions.

    Returns:
        tuple: tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    """
    data_path = os.path.join(DATA_DIRECTORY, 'cartpole.pkl')

    # download dataset
    if not os.path.exists(data_path):
        os.makedirs(DATA_DIRECTORY, exist_ok=True)
        print('Downloading cartpole.pkl into %s...' % data_path)
        # Download to a sibling temporary path and move it into place only on
        # success. Otherwise an interrupted download would leave a truncated
        # file at data_path, which later calls would treat as cached and then
        # fail to unpickle.
        tmp_path = data_path + '.part'
        try:
            urllib.request.urlretrieve(CARTPOLE_URL, tmp_path)
            os.replace(tmp_path, data_path)
        finally:
            # only still present if the download or the move failed
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # load dataset
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; this is only safe while CARTPOLE_URL and the local cache
    # directory are trusted.
    with open(data_path, 'rb') as f:
        observations, actions, rewards, terminals = pickle.load(f)

    # environment
    env = gym.make('CartPole-v0')

    dataset = MDPDataset(observations=np.array(observations, dtype=np.float32),
                         actions=actions,
                         rewards=rewards,
                         terminals=terminals,
                         discrete_action=True)

    return dataset, env
def get_pendulum():
    """ Returns pendulum dataset and environment.

    The dataset is automatically downloaded to `d3rlpy_data/pendulum.pkl` if it
    does not exist. The pickle is expected to hold a 4-tuple of
    ``(observations, actions, rewards, terminals)``; observations are cast to
    ``float32``.

    Returns:
        tuple: tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    """
    data_path = os.path.join(DATA_DIRECTORY, 'pendulum.pkl')

    # download dataset
    if not os.path.exists(data_path):
        os.makedirs(DATA_DIRECTORY, exist_ok=True)
        print('Downloading pendulum.pkl into %s...' % data_path)
        # Download to a sibling temporary path and move it into place only on
        # success. Otherwise an interrupted download would leave a truncated
        # file at data_path, which later calls would treat as cached and then
        # fail to unpickle.
        tmp_path = data_path + '.part'
        try:
            urllib.request.urlretrieve(PENDULUM_URL, tmp_path)
            os.replace(tmp_path, data_path)
        finally:
            # only still present if the download or the move failed
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # load dataset
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; this is only safe while PENDULUM_URL and the local cache
    # directory are trusted.
    with open(data_path, 'rb') as f:
        observations, actions, rewards, terminals = pickle.load(f)

    # environment
    env = gym.make('Pendulum-v0')

    dataset = MDPDataset(observations=np.array(observations, dtype=np.float32),
                         actions=actions,
                         rewards=rewards,
                         terminals=terminals)

    return dataset, env
def get_pybullet(env_name):
    """ Returns pybullet dataset and environment.

    The dataset is provided through d4rl-pybullet. See its GitHub page for
    more details, including the list of available datasets.

    .. code-block:: python

        from d3rlpy.datasets import get_pybullet

        dataset, env = get_pybullet('hopper-bullet-mixed-v0')

    References:
        * https://github.com/takuseno/d4rl-pybullet

    Args:
        env_name (str): environment id of d4rl-pybullet dataset.

    Returns:
        tuple: tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    Raises:
        ImportError: if d4rl-pybullet is not installed.

    """
    # Only the import itself is guarded, so ImportErrors raised later (e.g.
    # from inside gym.make) are not misreported as a missing d4rl-pybullet.
    try:
        # The module name is not used below. Presumably it is imported for its
        # side effect of registering the d4rl-pybullet env ids with gym --
        # TODO confirm.
        import d4rl_pybullet  # noqa: F401
    except ImportError as e:
        raise ImportError(
            'd4rl-pybullet is not installed.\n'
            'pip install git+https://github.com/takuseno/d4rl-pybullet') from e

    env = gym.make(env_name)
    # the mapping returned by get_dataset() is forwarded as MDPDataset kwargs
    dataset = MDPDataset(**env.get_dataset())
    return dataset, env
def get_atari(env_name):
    """ Returns atari dataset and environment.

    The dataset is provided through d4rl-atari. See its GitHub page for more
    details, including the list of available datasets.

    .. code-block:: python

        from d3rlpy.datasets import get_atari

        dataset, env = get_atari('breakout-mixed-v0')

    References:
        * https://github.com/takuseno/d4rl-atari

    Args:
        env_name (str): environment id of d4rl-atari dataset.

    Returns:
        tuple: tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    Raises:
        ImportError: if d4rl-atari is not installed.

    """
    # Only the import itself is guarded, so ImportErrors raised later (e.g.
    # from inside gym.make) are not misreported as a missing d4rl-atari.
    try:
        # The module name is not used below. Presumably it is imported for its
        # side effect of registering the d4rl-atari env ids with gym --
        # TODO confirm.
        import d4rl_atari  # noqa: F401
    except ImportError as e:
        raise ImportError(
            'd4rl-atari is not installed.\n'
            'pip install git+https://github.com/takuseno/d4rl-atari') from e

    # NOTE(review): stack=False presumably disables frame stacking in the
    # d4rl-atari env -- confirm against d4rl-atari.
    env = gym.make(env_name, stack=False)
    # the mapping returned by get_dataset() is forwarded as MDPDataset kwargs
    dataset = MDPDataset(**env.get_dataset(), discrete_action=True)
    return dataset, env