Source code for d3rlpy.datasets

# pylint: disable=unused-import,too-many-return-statements

import os
import random
import re
from typing import Any, Dict, Optional, Tuple
from urllib import request

import gym
import numpy as np
from gym.wrappers.time_limit import TimeLimit

from .dataset import (
    Episode,
    EpisodeGenerator,
    FrameStackTransitionPicker,
    InfiniteBuffer,
    MDPDataset,
    ReplayBuffer,
    TrajectorySlicerProtocol,
    TransitionPickerProtocol,
    create_infinite_replay_buffer,
    load_v1,
)
from .envs import ChannelFirst, FrameStack
from .logging import LOG

# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    # cache directory and download URLs
    "DATA_DIRECTORY",
    "DROPBOX_URL",
    "CARTPOLE_URL",
    "CARTPOLE_RANDOM_URL",
    "PENDULUM_URL",
    "PENDULUM_RANDOM_URL",
    # dataset loaders
    "get_cartpole",
    "get_pendulum",
    "get_atari",
    "get_atari_transitions",
    "get_d4rl",
    "get_dataset",
]

# Directory the downloaded dataset files are cached in.
DATA_DIRECTORY = "d3rlpy_data"

# Common prefix of the Dropbox links the datasets are fetched from.
DROPBOX_URL = "https://www.dropbox.com/s"

# Download links for the cartpole datasets (replay / random).
CARTPOLE_URL = DROPBOX_URL + "/uep0lzlhxpi79pd/cartpole_v1.1.0.h5?dl=1"
CARTPOLE_RANDOM_URL = (
    DROPBOX_URL + "/4lgai7tgj84cbov/cartpole_random_v1.1.0.h5?dl=1"
)

# Download links for the pendulum datasets (replay / random).
PENDULUM_URL = DROPBOX_URL + "/ukkucouzys0jkfs/pendulum_v1.1.0.h5?dl=1"
PENDULUM_RANDOM_URL = (
    DROPBOX_URL + "/hhbq9i6ako24kzz/pendulum_random_v1.1.0.h5?dl=1"
)


def get_cartpole(
    dataset_type: str = "replay",
    transition_picker: Optional[TransitionPickerProtocol] = None,
    trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
    render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, int]]:
    """Returns cartpole dataset and environment.

    The dataset is automatically downloaded to
    ``d3rlpy_data/cartpole_<dataset_type>_v1.1.0.h5`` if it does not exist.

    Args:
        dataset_type: dataset type. Available options are
            ``['replay', 'random']``.
        transition_picker: TransitionPickerProtocol object.
        trajectory_slicer: TrajectorySlicerProtocol object.
        render_mode: Mode of rendering (``human``, ``rgb_array``).

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.

    Raises:
        ValueError: if ``dataset_type`` is neither ``'replay'`` nor
            ``'random'``.
    """
    if dataset_type == "replay":
        url = CARTPOLE_URL
        file_name = "cartpole_replay_v1.1.0.h5"
    elif dataset_type == "random":
        url = CARTPOLE_RANDOM_URL
        file_name = "cartpole_random_v1.1.0.h5"
    else:
        raise ValueError(f"Invalid dataset_type: {dataset_type}.")

    data_path = os.path.join(DATA_DIRECTORY, file_name)

    # download dataset
    if not os.path.exists(data_path):
        os.makedirs(DATA_DIRECTORY, exist_ok=True)
        print(f"Downloading {file_name} into {data_path}...")
        # Fetch into a side file and rename only on success. Writing straight
        # to data_path would leave a truncated file behind when the transfer
        # is interrupted, and the existence check above would then treat it
        # as a valid cache on every later call.
        partial_path = f"{data_path}.part"
        try:
            request.urlretrieve(url, partial_path)
            os.replace(partial_path, data_path)
        finally:
            # no-op after a successful rename; removes leftovers on failure
            if os.path.exists(partial_path):
                os.remove(partial_path)

    # load dataset
    with open(data_path, "rb") as f:
        episodes = load_v1(f)
    dataset = ReplayBuffer(
        InfiniteBuffer(),
        episodes=episodes,
        transition_picker=transition_picker,
        trajectory_slicer=trajectory_slicer,
    )

    # environment
    env = gym.make("CartPole-v1", render_mode=render_mode)

    return dataset, env
def get_pendulum(
    dataset_type: str = "replay",
    transition_picker: Optional[TransitionPickerProtocol] = None,
    trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
    render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, np.ndarray]]:
    """Returns pendulum dataset and environment.

    The dataset is automatically downloaded to
    ``d3rlpy_data/pendulum_<dataset_type>_v1.1.0.h5`` if it does not exist.

    Args:
        dataset_type: dataset type. Available options are
            ``['replay', 'random']``.
        transition_picker: TransitionPickerProtocol object.
        trajectory_slicer: TrajectorySlicerProtocol object.
        render_mode: Mode of rendering (``human``, ``rgb_array``).

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.

    Raises:
        ValueError: if ``dataset_type`` is neither ``'replay'`` nor
            ``'random'``.
    """
    if dataset_type == "replay":
        url = PENDULUM_URL
        file_name = "pendulum_replay_v1.1.0.h5"
    elif dataset_type == "random":
        url = PENDULUM_RANDOM_URL
        file_name = "pendulum_random_v1.1.0.h5"
    else:
        raise ValueError(f"Invalid dataset_type: {dataset_type}.")

    data_path = os.path.join(DATA_DIRECTORY, file_name)

    # download dataset
    if not os.path.exists(data_path):
        os.makedirs(DATA_DIRECTORY, exist_ok=True)
        print(f"Downloading {file_name} into {data_path}...")
        # Fetch into a side file and rename only on success. Writing straight
        # to data_path would leave a truncated file behind when the transfer
        # is interrupted, and the existence check above would then treat it
        # as a valid cache on every later call.
        partial_path = f"{data_path}.part"
        try:
            request.urlretrieve(url, partial_path)
            os.replace(partial_path, data_path)
        finally:
            # no-op after a successful rename; removes leftovers on failure
            if os.path.exists(partial_path):
                os.remove(partial_path)

    # load dataset
    with open(data_path, "rb") as f:
        episodes = load_v1(f)
    dataset = ReplayBuffer(
        InfiniteBuffer(),
        episodes=episodes,
        transition_picker=transition_picker,
        trajectory_slicer=trajectory_slicer,
    )

    # environment
    env = gym.make("Pendulum-v1", render_mode=render_mode)

    return dataset, env
def get_atari(
    env_name: str,
    num_stack: Optional[int] = None,
    render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, int]]:
    """Returns atari dataset and environment.

    The dataset is provided through d4rl-atari. See more details including
    available dataset from its GitHub page.

    .. code-block:: python

        from d3rlpy.datasets import get_atari

        dataset, env = get_atari('breakout-mixed-v0')

    References:
        * https://github.com/takuseno/d4rl-atari

    Args:
        env_name: environment id of d4rl-atari dataset.
        num_stack: the number of frames to stack. It configures the
            ``FrameStackTransitionPicker`` of the dataset (``1`` when
            omitted) and, when given, wraps the environment in
            ``FrameStack``; otherwise the environment is wrapped in
            ``ChannelFirst``.
        render_mode: Mode of rendering (``human``, ``rgb_array``).

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.

    Raises:
        ImportError: if d4rl-atari cannot be imported.
    """
    # Only the import is guarded, so an ImportError raised later while
    # building the environment or dataset is not misreported as a missing
    # d4rl-atari installation.
    try:
        # imported for its side effects only (presumably registers the
        # dataset environments with gym -- TODO confirm)
        import d4rl_atari  # type: ignore
    except ImportError as e:
        raise ImportError(
            "d4rl-atari is not installed.\n" "$ d3rlpy install d4rl_atari"
        ) from e

    env = gym.make(env_name, render_mode=render_mode)
    raw_dataset = env.get_dataset()  # type: ignore
    episode_generator = EpisodeGenerator(**raw_dataset)
    dataset = create_infinite_replay_buffer(
        episodes=episode_generator(),
        transition_picker=FrameStackTransitionPicker(num_stack or 1),
        trajectory_slicer=None,
    )

    if num_stack:
        env = FrameStack(env, num_stack=num_stack)
    else:
        env = ChannelFirst(env)

    return dataset, env
def get_atari_transitions(
    game_name: str,
    fraction: float = 0.01,
    index: int = 0,
    num_stack: Optional[int] = None,
    render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, int]]:
    """Returns atari dataset sampled from all epochs and environment.

    The dataset is provided through d4rl-atari. The difference from
    ``get_atari`` function is that this function will sample transitions
    from all epochs. This function is necessary for reproducing Atari
    experiments.

    .. code-block:: python

        from d3rlpy.datasets import get_atari_transitions

        # get 1% of transitions from all epochs (1M x 50 epoch x 1% = 0.5M)
        dataset, env = get_atari_transitions('breakout', fraction=0.01)

    References:
        * https://github.com/takuseno/d4rl-atari

    Args:
        game_name: Atari 2600 game name (e.g. ``'breakout'``). It is
            expanded to ``{game_name}-epoch-{1..50}-v{index}``.
        fraction: fraction of sampled transitions per epoch.
        index: index to specify which trial to load.
        num_stack: the number of frames to stack. It configures the
            ``FrameStackTransitionPicker`` of the dataset (``1`` when
            omitted) and, when given, wraps the environment in
            ``FrameStack``; otherwise the environment is wrapped in
            ``ChannelFirst``.
        render_mode: Mode of rendering (``human``, ``rgb_array``).

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
        The environment is the one created for the last epoch.

    Raises:
        ImportError: if d4rl-atari cannot be imported.
    """
    # Only the import is guarded, so an ImportError raised later while
    # building the environments or dataset is not misreported as a missing
    # d4rl-atari installation.
    try:
        # imported for its side effects only (presumably registers the
        # dataset environments with gym -- TODO confirm)
        import d4rl_atari  # type: ignore
    except ImportError as e:
        raise ImportError(
            "d4rl-atari is not installed.\n" "$ d3rlpy install d4rl_atari"
        ) from e

    # each epoch consists of 1M steps
    num_transitions_per_epoch = int(1000000 * fraction)

    copied_episodes = []
    for i in range(50):
        env_name = f"{game_name}-epoch-{i + 1}-v{index}"
        LOG.info(f"Collecting {env_name}...")
        env = gym.make(
            env_name,
            sticky_action=True,
            render_mode=render_mode,
        )
        raw_dataset = env.get_dataset()  # type: ignore
        episode_generator = EpisodeGenerator(**raw_dataset)
        episodes = list(episode_generator())

        # copy episode data to release memory of unused data
        random.shuffle(episodes)
        num_data = 0
        for episode in episodes:
            if num_data >= num_transitions_per_epoch:
                break

            copied_episode = Episode(
                observations=episode.observations.copy(),  # type: ignore
                actions=episode.actions.copy(),
                rewards=episode.rewards.copy(),
                terminated=episode.terminated,
            )

            # trim episode so that this epoch contributes exactly the
            # requested number of transitions; a cut episode no longer ends
            # at its real terminal step, hence terminated=False
            if num_data + copied_episode.size() > num_transitions_per_epoch:
                end = num_transitions_per_epoch - num_data
                copied_episode = Episode(
                    observations=copied_episode.observations[:end],
                    actions=copied_episode.actions[:end],
                    rewards=copied_episode.rewards[:end],
                    terminated=False,
                )

            copied_episodes.append(copied_episode)
            num_data += copied_episode.size()

    dataset = ReplayBuffer(
        InfiniteBuffer(),
        episodes=copied_episodes,
        transition_picker=FrameStackTransitionPicker(num_stack or 1),
    )

    # ``env`` is the environment of the last epoch in the loop above
    if num_stack:
        env = FrameStack(env, num_stack=num_stack)
    else:
        env = ChannelFirst(env)

    return dataset, env
def get_d4rl(
    env_name: str,
    transition_picker: Optional[TransitionPickerProtocol] = None,
    trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
    render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[np.ndarray, np.ndarray]]:
    """Returns d4rl dataset and environment.

    The dataset is provided through d4rl.

    .. code-block:: python

        from d3rlpy.datasets import get_d4rl

        dataset, env = get_d4rl('hopper-medium-v0')

    References:
        * `Fu et al., D4RL: Datasets for Deep Data-Driven Reinforcement
          Learning. <https://arxiv.org/abs/2004.07219>`_
        * https://github.com/rail-berkeley/d4rl

    Args:
        env_name: environment id of d4rl dataset.
        transition_picker: TransitionPickerProtocol object.
        trajectory_slicer: TrajectorySlicerProtocol object.
        render_mode: Mode of rendering (``human``, ``rgb_array``). It is
            set as an attribute on the unwrapped environment rather than
            passed to ``gym.make``.

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.
        The environment is the unwrapped environment re-wrapped in a
        ``TimeLimit`` of 1000 steps.

    Raises:
        ImportError: if d4rl cannot be imported.
    """
    # Only the import is guarded, so an ImportError raised later while
    # building the environment or dataset is not misreported as a missing
    # d4rl installation.
    try:
        # imported for its side effects only (presumably registers the
        # dataset environments with gym -- TODO confirm)
        import d4rl  # type: ignore
    except ImportError as e:
        raise ImportError(
            "d4rl is not installed.\n" "$ d3rlpy install d4rl"
        ) from e

    env = gym.make(env_name)
    raw_dataset: Dict[str, np.ndarray] = env.get_dataset()  # type: ignore

    dataset = MDPDataset(
        observations=raw_dataset["observations"],
        actions=raw_dataset["actions"],
        rewards=raw_dataset["rewards"],
        terminals=raw_dataset["terminals"],
        timeouts=raw_dataset["timeouts"],
        transition_picker=transition_picker,
        trajectory_slicer=trajectory_slicer,
    )

    # wrapped by NormalizedBoxEnv that is incompatible with newer Gym
    unwrapped_env: gym.Env[Any, Any] = env.env.env.env.wrapped_env  # type: ignore
    unwrapped_env.render_mode = render_mode  # overwrite

    return dataset, TimeLimit(unwrapped_env, max_episode_steps=1000)
# Atari 2600 game names, in alphabetical order.
ATARI_GAMES = [
    "adventure",
    "air-raid",
    "alien",
    "amidar",
    "assault",
    "asterix",
    "asteroids",
    "atlantis",
    "bank-heist",
    "battle-zone",
    "beam-rider",
    "berzerk",
    "bowling",
    "boxing",
    "breakout",
    "carnival",
    "centipede",
    "chopper-command",
    "crazy-climber",
    "defender",
    "demon-attack",
    "double-dunk",
    "elevator-action",
    "enduro",
    "fishing-derby",
    "freeway",
    "frostbite",
    "gopher",
    "gravitar",
    "hero",
    "ice-hockey",
    "jamesbond",
    "journey-escape",
    "kangaroo",
    "krull",
    "kung-fu-master",
    "montezuma-revenge",
    "ms-pacman",
    "name-this-game",
    "phoenix",
    "pitfall",
    "pong",
    "pooyan",
    "private-eye",
    "qbert",
    "riverraid",
    "road-runner",
    "robotank",
    "seaquest",
    "skiing",
    "solaris",
    "space-invaders",
    "star-gunner",
    "tennis",
    "time-pilot",
    "tutankham",
    "up-n-down",
    "venture",
    "video-pinball",
    "wizard-of-wor",
    "yars-revenge",
    "zaxxon",
]
def get_dataset(
    env_name: str,
    transition_picker: Optional[TransitionPickerProtocol] = None,
    trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
    render_mode: Optional[str] = None,
) -> Tuple[ReplayBuffer, gym.Env[Any, Any]]:
    """Returns dataset and environment by guessing from name.

    This function returns dataset by matching name with the following
    datasets.

    - ``cartpole-replay``
    - ``cartpole-random``
    - ``pendulum-replay``
    - ``pendulum-random``
    - names starting with ``bullet-`` (loaded through ``get_d4rl``)
    - names starting with ``hopper``, ``halfcheetah``, ``walker`` or
      ``ant`` (loaded through ``get_d4rl``)

    No branch dispatches to ``get_atari``; load Atari datasets by calling
    ``get_atari`` or ``get_atari_transitions`` directly.

    .. code-block:: python

        import d3rlpy

        # cartpole dataset
        dataset, env = d3rlpy.datasets.get_dataset('cartpole-replay')

        # pendulum dataset
        dataset, env = d3rlpy.datasets.get_dataset('pendulum-replay')

        # d4rl dataset
        dataset, env = d3rlpy.datasets.get_dataset('hopper-medium-v0')

    Args:
        env_name: environment id of the dataset.
        transition_picker: TransitionPickerProtocol object.
        trajectory_slicer: TrajectorySlicerProtocol object.
        render_mode: Mode of rendering (``human``, ``rgb_array``).

    Returns:
        tuple of :class:`d3rlpy.dataset.ReplayBuffer` and gym environment.

    Raises:
        ValueError: if ``env_name`` matches none of the names above.
    """
    # options forwarded unchanged to whichever loader is selected
    loader_kwargs: Dict[str, Any] = {
        "transition_picker": transition_picker,
        "trajectory_slicer": trajectory_slicer,
        "render_mode": render_mode,
    }

    # exact names of the bundled datasets; the suffix is the dataset type
    if env_name in ("cartpole-replay", "cartpole-random"):
        return get_cartpole(
            dataset_type=env_name.split("-", 1)[1], **loader_kwargs
        )
    if env_name in ("pendulum-replay", "pendulum-random"):
        return get_pendulum(
            dataset_type=env_name.split("-", 1)[1], **loader_kwargs
        )

    # re.match anchors at the start, so both patterns are prefix checks
    if re.match(r"^bullet-.+$", env_name) or re.match(
        r"hopper|halfcheetah|walker|ant", env_name
    ):
        return get_d4rl(env_name, **loader_kwargs)

    raise ValueError(f"Unrecognized env_name: {env_name}.")