Source code for d3rlpy.datasets

# pylint: disable=unused-import

import urllib.request as request
import os
import pickle
from typing import Tuple

import numpy as np
import gym

from .dataset import MDPDataset
from .envs import ChannelFirst

DATA_DIRECTORY = "d3rlpy_data"
CARTPOLE_URL = "https://www.dropbox.com/s/2tmo7ul00268l3e/cartpole.pkl?dl=1"
PENDULUM_URL = "https://www.dropbox.com/s/90z7a84ngndrqt4/pendulum.pkl?dl=1"


[docs]def get_cartpole( create_mask: bool = False, mask_size: int = 1 ) -> Tuple[MDPDataset, gym.Env]: """Returns cartpole dataset and environment. The dataset is automatically downloaded to `d3rlpy_data/cartpole.pkl` if it does not exist. Args: create_mask: flag to create binary mask for bootstrapping. mask_size: ensemble size for binary mask. Returns: tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment. """ data_path = os.path.join(DATA_DIRECTORY, "cartpole.pkl") # download dataset if not os.path.exists(data_path): os.makedirs(DATA_DIRECTORY, exist_ok=True) print("Donwloading cartpole.pkl into %s..." % data_path) request.urlretrieve(CARTPOLE_URL, data_path) # load dataset with open(data_path, "rb") as f: observations, actions, rewards, terminals = pickle.load(f) # environment env = gym.make("CartPole-v0") dataset = MDPDataset( observations=np.array(observations, dtype=np.float32), actions=actions, rewards=rewards, terminals=terminals, discrete_action=True, create_mask=create_mask, mask_size=mask_size, ) return dataset, env
[docs]def get_pendulum( create_mask: bool = False, mask_size: int = 1 ) -> Tuple[MDPDataset, gym.Env]: """Returns pendulum dataset and environment. The dataset is automatically downloaded to `d3rlpy_data/pendulum.pkl` if it does not exist. Args: create_mask: flag to create binary mask for bootstrapping. mask_size: ensemble size for binary mask. Returns: tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment. """ data_path = os.path.join(DATA_DIRECTORY, "pendulum.pkl") if not os.path.exists(data_path): os.makedirs(DATA_DIRECTORY, exist_ok=True) print("Donwloading pendulum.pkl into %s..." % data_path) request.urlretrieve(PENDULUM_URL, data_path) # load dataset with open(data_path, "rb") as f: observations, actions, rewards, terminals = pickle.load(f) # environment env = gym.make("Pendulum-v0") dataset = MDPDataset( observations=np.array(observations, dtype=np.float32), actions=actions, rewards=rewards, terminals=terminals, create_mask=create_mask, mask_size=mask_size, ) return dataset, env
[docs]def get_pybullet( env_name: str, create_mask: bool = False, mask_size: int = 1 ) -> Tuple[MDPDataset, gym.Env]: """Returns pybullet dataset and envrironment. The dataset is provided through d4rl-pybullet. See more details including available dataset from its GitHub page. .. code-block:: python from d3rlpy.datasets import get_pybullet dataset, env = get_pybullet('hopper-bullet-mixed-v0') References: * https://github.com/takuseno/d4rl-pybullet Args: env_name: environment id of d4rl-pybullet dataset. create_mask: flag to create binary mask for bootstrapping. mask_size: ensemble size for binary mask. Returns: tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment. """ try: import d4rl_pybullet # type: ignore env = gym.make(env_name) dataset = MDPDataset( create_mask=create_mask, mask_size=mask_size, **env.get_dataset() ) return dataset, env except ImportError as e: raise ImportError( "d4rl-pybullet is not installed.\n" "pip install git+https://github.com/takuseno/d4rl-pybullet" ) from e
[docs]def get_atari( env_name: str, create_mask: bool = False, mask_size: int = 1 ) -> Tuple[MDPDataset, gym.Env]: """Returns atari dataset and envrironment. The dataset is provided through d4rl-atari. See more details including available dataset from its GitHub page. .. code-block:: python from d3rlpy.datasets import get_atari dataset, env = get_atari('breakout-mixed-v0') References: * https://github.com/takuseno/d4rl-atari Args: env_name: environment id of d4rl-atari dataset. create_mask: flag to create binary mask for bootstrapping. mask_size: ensemble size for binary mask. Returns: tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment. """ try: import d4rl_atari # type: ignore env = ChannelFirst(gym.make(env_name)) dataset = MDPDataset( discrete_action=True, create_mask=create_mask, mask_size=mask_size, **env.get_dataset() ) return dataset, env except ImportError as e: raise ImportError( "d4rl-atari is not installed.\n" "pip install git+https://github.com/takuseno/d4rl-atari" ) from e
[docs]def get_d4rl( env_name: str, create_mask: bool = False, mask_size: int = 1 ) -> Tuple[MDPDataset, gym.Env]: """Returns d4rl dataset and envrironment. The dataset is provided through d4rl. .. code-block:: python from d3rlpy.datasets import get_d4rl dataset, env = get_d4rl('hopper-medium-v0') References: * `Fu et al., D4RL: Datasets for Deep Data-Driven Reinforcement Learning. <https://arxiv.org/abs/2004.07219>`_ * https://github.com/rail-berkeley/d4rl Args: env_name: environment id of d4rl dataset. create_mask: flag to create binary mask for bootstrapping. mask_size: ensemble size for binary mask. Returns: tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment. """ try: import d4rl # type: ignore env = gym.make(env_name) dataset = env.get_dataset() observations = dataset["observations"] actions = dataset["actions"] rewards = dataset["rewards"] terminals = np.logical_and( dataset["terminals"], np.logical_not(dataset["timeouts"]) ) episode_terminals = np.logical_or( dataset["terminals"], dataset["timeouts"] ) mdp_dataset = MDPDataset( observations=observations, actions=actions, rewards=rewards, terminals=terminals, episode_terminals=episode_terminals, create_mask=create_mask, mask_size=mask_size, ) return mdp_dataset, env except ImportError as e: raise ImportError( "d4rl is not installed.\n" "pip install git+https://github.com/rail-berkeley/d4rl" ) from e