# pylint: disable=unused-import,too-many-return-statements
import os
import re
from typing import Tuple
from urllib import request
import gym
import numpy as np
from .dataset import MDPDataset
from .envs import ChannelFirst
# Directory (relative path) where downloaded dataset files are stored.
DATA_DIRECTORY = "d3rlpy_data"

# Common prefix of the Dropbox share links below.
DROPBOX_URL = "https://www.dropbox.com/s"

# Download links of the bundled datasets.
CARTPOLE_URL = DROPBOX_URL + "/l1sdnq3zvoot2um/cartpole.h5?dl=1"
CARTPOLE_RANDOM_URL = DROPBOX_URL + "/rwf4pns5x0ku848/cartpole_random.h5?dl=1"
PENDULUM_URL = DROPBOX_URL + "/vsiz9pwvshj7sly/pendulum.h5?dl=1"
PENDULUM_RANDOM_URL = DROPBOX_URL + "/qldf2vjvvc5thsb/pendulum_random.h5?dl=1"
def get_cartpole(
    create_mask: bool = False, mask_size: int = 1, dataset_type: str = "replay"
) -> Tuple[MDPDataset, gym.Env]:
    """Returns cartpole dataset and environment.

    The dataset is automatically downloaded to
    ``d3rlpy_data/cartpole_replay.h5`` or ``d3rlpy_data/cartpole_random.h5``
    (depending on ``dataset_type``) if it does not exist.

    Args:
        create_mask: flag to create binary mask for bootstrapping.
        mask_size: ensemble size for binary mask.
        dataset_type: dataset type. Available options are
            ``['replay', 'random']``.

    Returns:
        tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    Raises:
        ValueError: if ``dataset_type`` is not one of the available options.

    """
    if dataset_type == "replay":
        url = CARTPOLE_URL
        file_name = "cartpole_replay.h5"
    elif dataset_type == "random":
        url = CARTPOLE_RANDOM_URL
        file_name = "cartpole_random.h5"
    else:
        raise ValueError(f"Invalid dataset_type: {dataset_type}.")

    data_path = os.path.join(DATA_DIRECTORY, file_name)

    # download dataset
    if not os.path.exists(data_path):
        os.makedirs(DATA_DIRECTORY, exist_ok=True)
        print(f"Downloading {file_name} into {data_path}...")
        # Fetch into a temporary file and move it into place only after the
        # transfer has completed, so that an interrupted download never
        # leaves a truncated file that later calls would accept as cached.
        partial_path = data_path + ".part"
        try:
            request.urlretrieve(url, partial_path)
            os.replace(partial_path, data_path)
        finally:
            # no-op on success; removes the leftover on failure
            if os.path.exists(partial_path):
                os.remove(partial_path)

    # load dataset
    dataset = MDPDataset.load(
        data_path, create_mask=create_mask, mask_size=mask_size
    )

    # environment
    env = gym.make("CartPole-v0")

    return dataset, env
def get_pendulum(
    create_mask: bool = False,
    mask_size: int = 1,
    dataset_type: str = "replay",
) -> Tuple[MDPDataset, gym.Env]:
    """Returns pendulum dataset and environment.

    The dataset is automatically downloaded to
    ``d3rlpy_data/pendulum_replay.h5`` or ``d3rlpy_data/pendulum_random.h5``
    (depending on ``dataset_type``) if it does not exist.

    Args:
        create_mask: flag to create binary mask for bootstrapping.
        mask_size: ensemble size for binary mask.
        dataset_type: dataset type. Available options are
            ``['replay', 'random']``.

    Returns:
        tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    Raises:
        ValueError: if ``dataset_type`` is not one of the available options.

    """
    if dataset_type == "replay":
        url = PENDULUM_URL
        file_name = "pendulum_replay.h5"
    elif dataset_type == "random":
        url = PENDULUM_RANDOM_URL
        file_name = "pendulum_random.h5"
    else:
        raise ValueError(f"Invalid dataset_type: {dataset_type}.")

    data_path = os.path.join(DATA_DIRECTORY, file_name)

    # download dataset
    if not os.path.exists(data_path):
        os.makedirs(DATA_DIRECTORY, exist_ok=True)
        print(f"Downloading {file_name} into {data_path}...")
        # Fetch into a temporary file and move it into place only after the
        # transfer has completed, so that an interrupted download never
        # leaves a truncated file that later calls would accept as cached.
        partial_path = data_path + ".part"
        try:
            request.urlretrieve(url, partial_path)
            os.replace(partial_path, data_path)
        finally:
            # no-op on success; removes the leftover on failure
            if os.path.exists(partial_path):
                os.remove(partial_path)

    # load dataset
    dataset = MDPDataset.load(
        data_path, create_mask=create_mask, mask_size=mask_size
    )

    # environment
    env = gym.make("Pendulum-v0")

    return dataset, env
def get_pybullet(
    env_name: str, create_mask: bool = False, mask_size: int = 1
) -> Tuple[MDPDataset, gym.Env]:
    """Returns pybullet dataset and environment.

    The dataset is provided through d4rl-pybullet. See more details including
    available dataset from its GitHub page.

    .. code-block:: python

        from d3rlpy.datasets import get_pybullet

        dataset, env = get_pybullet('hopper-bullet-mixed-v0')

    References:
        * https://github.com/takuseno/d4rl-pybullet

    Args:
        env_name: environment id of d4rl-pybullet dataset.
        create_mask: flag to create binary mask for bootstrapping.
        mask_size: ensemble size for binary mask.

    Returns:
        tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    Raises:
        ImportError: if d4rl-pybullet cannot be imported.

    """
    # Only the import itself is guarded: an ImportError raised later while
    # building the environment or dataset must not be reported as
    # "d4rl-pybullet is not installed".
    try:
        # The module is never referenced below; it is imported for its side
        # effects (presumably registering its environments with gym).
        import d4rl_pybullet  # type: ignore
    except ImportError as e:
        raise ImportError(
            "d4rl-pybullet is not installed.\n"
            "pip install git+https://github.com/takuseno/d4rl-pybullet"
        ) from e

    env = gym.make(env_name)
    dataset = MDPDataset(
        create_mask=create_mask, mask_size=mask_size, **env.get_dataset()
    )
    return dataset, env
def get_atari(
    env_name: str, create_mask: bool = False, mask_size: int = 1
) -> Tuple[MDPDataset, gym.Env]:
    """Returns atari dataset and environment.

    The dataset is provided through d4rl-atari. See more details including
    available dataset from its GitHub page.

    .. code-block:: python

        from d3rlpy.datasets import get_atari

        dataset, env = get_atari('breakout-mixed-v0')

    References:
        * https://github.com/takuseno/d4rl-atari

    Args:
        env_name: environment id of d4rl-atari dataset.
        create_mask: flag to create binary mask for bootstrapping.
        mask_size: ensemble size for binary mask.

    Returns:
        tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment
        (wrapped in ``ChannelFirst``).

    Raises:
        ImportError: if d4rl-atari cannot be imported.

    """
    # Only the import itself is guarded: an ImportError raised later while
    # building the environment or dataset must not be reported as
    # "d4rl-atari is not installed".
    try:
        # The module is never referenced below; it is imported for its side
        # effects (presumably registering its environments with gym).
        import d4rl_atari  # type: ignore
    except ImportError as e:
        raise ImportError(
            "d4rl-atari is not installed.\n"
            "pip install git+https://github.com/takuseno/d4rl-atari"
        ) from e

    # the dataset is fetched through the wrapper, not the raw environment
    env = ChannelFirst(gym.make(env_name))
    dataset = MDPDataset(
        discrete_action=True,
        create_mask=create_mask,
        mask_size=mask_size,
        **env.get_dataset(),
    )
    return dataset, env
def get_d4rl(
    env_name: str, create_mask: bool = False, mask_size: int = 1
) -> Tuple[MDPDataset, gym.Env]:
    """Returns d4rl dataset and environment.

    The dataset is provided through d4rl.

    .. code-block:: python

        from d3rlpy.datasets import get_d4rl

        dataset, env = get_d4rl('hopper-medium-v0')

    The raw d4rl arrays are re-aligned while being copied:

    * the reward stored with step ``t`` is ``rewards[t - 1]`` of the raw
      data (``0.0`` for the first step of an episode);
    * when ``timeouts[t]`` is set, step ``t`` closes the episode without
      being flagged as terminal and no extra step is added;
    * when ``terminals[t]`` is set (and ``timeouts[t]`` is not), one extra
      step is appended that repeats the observation and action of step
      ``t``, carries ``rewards[t]`` and is flagged as terminal.

    References:
        * `Fu et al., D4RL: Datasets for Deep Data-Driven Reinforcement
          Learning. <https://arxiv.org/abs/2004.07219>`_
        * https://github.com/rail-berkeley/d4rl

    Args:
        env_name: environment id of d4rl dataset.
        create_mask: flag to create binary mask for bootstrapping.
        mask_size: ensemble size for binary mask.

    Returns:
        tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    Raises:
        ImportError: if d4rl cannot be imported.

    """
    # Only the import itself is guarded: an ImportError raised later while
    # building the environment or dataset must not be reported as
    # "d4rl is not installed".
    try:
        # The module is never referenced below; it is imported for its side
        # effects (presumably registering its environments with gym).
        import d4rl  # type: ignore
    except ImportError as e:
        raise ImportError(
            "d4rl is not installed.\n"
            "pip install git+https://github.com/rail-berkeley/d4rl"
        ) from e

    env = gym.make(env_name)
    dataset = env.get_dataset()

    observations = []
    actions = []
    rewards = []
    terminals = []
    episode_terminals = []

    episode_step = 0  # 0 means "first step of a new episode"
    cursor = 0
    dataset_size = dataset["observations"].shape[0]
    while cursor < dataset_size:
        # collect data for step=t
        observation = dataset["observations"][cursor]
        action = dataset["actions"][cursor]
        if episode_step == 0:
            # no previous transition inside this episode
            reward = 0.0
        else:
            reward = dataset["rewards"][cursor - 1]

        observations.append(observation)
        actions.append(action)
        rewards.append(reward)
        terminals.append(0.0)

        # skip adding the last step when timeout
        # (timeout is checked first, so it wins over a terminal flag)
        if dataset["timeouts"][cursor]:
            episode_terminals.append(1.0)
            episode_step = 0
            cursor += 1
            continue

        episode_terminals.append(0.0)
        episode_step += 1

        if dataset["terminals"][cursor]:
            # collect data for step=t+1
            dummy_observation = observation.copy()
            dummy_action = action.copy()
            next_reward = dataset["rewards"][cursor]

            # the last observation is rarely used
            observations.append(dummy_observation)
            actions.append(dummy_action)
            rewards.append(next_reward)
            terminals.append(1.0)
            episode_terminals.append(1.0)
            episode_step = 0

        cursor += 1

    mdp_dataset = MDPDataset(
        observations=np.array(observations, dtype=np.float32),
        actions=np.array(actions, dtype=np.float32),
        rewards=np.array(rewards, dtype=np.float32),
        terminals=np.array(terminals, dtype=np.float32),
        episode_terminals=np.array(episode_terminals, dtype=np.float32),
        create_mask=create_mask,
        mask_size=mask_size,
    )

    return mdp_dataset, env
# Atari game names used to recognise d4rl-atari dataset ids: get_dataset()
# joins them into a single regex alternation and matches it against the
# beginning of ``env_name``.
ATARI_GAMES = [
"adventure",
"air-raid",
"alien",
"amidar",
"assault",
"asterix",
"asteroids",
"atlantis",
"bank-heist",
"battle-zone",
"beam-rider",
"berzerk",
"bowling",
"boxing",
"breakout",
"carnival",
"centipede",
"chopper-command",
"crazy-climber",
"defender",
"demon-attack",
"double-dunk",
"elevator-action",
"enduro",
"fishing-derby",
"freeway",
"frostbite",
"gopher",
"gravitar",
"hero",
"ice-hockey",
"jamesbond",
"journey-escape",
"kangaroo",
"krull",
"kung-fu-master",
"montezuma-revenge",
"ms-pacman",
"name-this-game",
"phoenix",
"pitfall",
"pong",
"pooyan",
"private-eye",
"qbert",
"riverraid",
"road-runner",
"robotank",
"seaquest",
"skiing",
"solaris",
"space-invaders",
"star-gunner",
"tennis",
"time-pilot",
"tutankham",
"up-n-down",
"venture",
"video-pinball",
"wizard-of-wor",
"yars-revenge",
"zaxxon",
]
def get_dataset(
    env_name: str, create_mask: bool = False, mask_size: int = 1
) -> Tuple[MDPDataset, gym.Env]:
    """Returns dataset and environment by guessing from name.

    This function returns dataset by matching name with the following datasets.

    - cartpole-replay
    - cartpole-random
    - pendulum-replay
    - pendulum-random
    - d4rl-pybullet
    - d4rl-atari
    - d4rl

    The name is tested against the rules below, in this order, and the first
    match wins:

    1. the four exact cartpole/pendulum names listed above;
    2. names starting with ``bullet-`` are loaded through d4rl;
    3. names containing ``-bullet-`` are loaded through d4rl-pybullet;
    4. names starting with ``hopper``, ``halfcheetah``, ``walker`` or ``ant``
       are loaded through d4rl;
    5. names starting with one of ``ATARI_GAMES`` are loaded through
       d4rl-atari.

    .. code-block:: python

        import d3rlpy

        # cartpole dataset
        dataset, env = d3rlpy.datasets.get_dataset('cartpole-replay')

        # pendulum dataset
        dataset, env = d3rlpy.datasets.get_dataset('pendulum-replay')

        # d4rl-pybullet dataset
        dataset, env = d3rlpy.datasets.get_dataset('hopper-bullet-mixed-v0')

        # d4rl-atari dataset
        dataset, env = d3rlpy.datasets.get_dataset('breakout-mixed-v0')

        # d4rl dataset
        dataset, env = d3rlpy.datasets.get_dataset('hopper-medium-v0')

    Args:
        env_name: environment id of the dataset.
        create_mask: flag to create binary mask for bootstrapping.
        mask_size: ensemble size for binary mask.

    Returns:
        tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    Raises:
        ValueError: if ``env_name`` matches none of the rules above.

    """
    # bundled datasets: exact names only
    if env_name == "cartpole-replay":
        return get_cartpole(create_mask, mask_size, dataset_type="replay")
    if env_name == "cartpole-random":
        return get_cartpole(create_mask, mask_size, dataset_type="random")
    if env_name == "pendulum-replay":
        return get_pendulum(create_mask, mask_size, dataset_type="replay")
    if env_name == "pendulum-random":
        return get_pendulum(create_mask, mask_size, dataset_type="random")

    # "bullet-..." must be tested before "...-bullet-..." and before the
    # d4rl name prefixes so that each id reaches the intended loader
    if re.match(r"^bullet-.+$", env_name):
        return get_d4rl(env_name, create_mask, mask_size)
    if re.match(r"^.+-bullet-.+$", env_name):
        return get_pybullet(env_name, create_mask, mask_size)
    if re.match(r"hopper|halfcheetah|walker|ant", env_name):
        return get_d4rl(env_name, create_mask, mask_size)

    # re.match anchors at the start, so this is a prefix test against every
    # known Atari game name
    if re.match("|".join(ATARI_GAMES), env_name):
        return get_atari(env_name, create_mask, mask_size)

    raise ValueError(f"Unrecognized env_name: {env_name}.")