# Source code for d3rlpy.datasets

# pylint: disable=unused-import

import urllib.request as request
import os
import pickle
from typing import Tuple

import numpy as np
import gym

from .dataset import MDPDataset

# Directory (a relative path, so resolved against the current working
# directory) where downloaded dataset files are cached.
DATA_DIRECTORY = "d3rlpy_data"

# Download locations of the pickled example datasets used by
# get_cartpole() and get_pendulum() respectively.
CARTPOLE_URL = (
    "https://www.dropbox.com/s/2tmo7ul00268l3e/"
    "cartpole.pkl?dl=1"
)
PENDULUM_URL = (
    "https://www.dropbox.com/s/90z7a84ngndrqt4/"
    "pendulum.pkl?dl=1"
)


def get_cartpole() -> Tuple[MDPDataset, gym.Env]:
    """Returns cartpole dataset and environment.

    The dataset is automatically downloaded to `d3rlpy_data/cartpole.pkl`
    (relative to the current working directory) if it does not exist.

    Returns:
        tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    """
    data_path = os.path.join(DATA_DIRECTORY, "cartpole.pkl")

    # download dataset
    if not os.path.exists(data_path):
        os.makedirs(DATA_DIRECTORY, exist_ok=True)
        print("Downloading cartpole.pkl into %s..." % data_path)
        # Download into a temporary file and move it into place only after
        # the transfer finished. Otherwise an interrupted or short download
        # would leave a truncated pickle at ``data_path`` that the existence
        # check above would treat as a valid cache on every later call.
        tmp_path = data_path + ".part"
        try:
            request.urlretrieve(CARTPOLE_URL, tmp_path)
            os.replace(tmp_path, data_path)
        finally:
            # only still present if the download or the rename failed
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # load dataset
    # NOTE(security): pickle.load can execute arbitrary code; this trusts the
    # file fetched from CARTPOLE_URL (or whatever already sits at data_path).
    with open(data_path, "rb") as f:
        observations, actions, rewards, terminals = pickle.load(f)

    # environment
    env = gym.make("CartPole-v0")

    dataset = MDPDataset(
        observations=np.array(observations, dtype=np.float32),
        actions=actions,
        rewards=rewards,
        terminals=terminals,
        discrete_action=True,
    )

    return dataset, env
def get_pendulum() -> Tuple[MDPDataset, gym.Env]:
    """Returns pendulum dataset and environment.

    The dataset is automatically downloaded to `d3rlpy_data/pendulum.pkl`
    (relative to the current working directory) if it does not exist.

    Returns:
        tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    """
    data_path = os.path.join(DATA_DIRECTORY, "pendulum.pkl")

    # download dataset
    if not os.path.exists(data_path):
        os.makedirs(DATA_DIRECTORY, exist_ok=True)
        print("Downloading pendulum.pkl into %s..." % data_path)
        # Download into a temporary file and move it into place only after
        # the transfer finished. Otherwise an interrupted or short download
        # would leave a truncated pickle at ``data_path`` that the existence
        # check above would treat as a valid cache on every later call.
        tmp_path = data_path + ".part"
        try:
            request.urlretrieve(PENDULUM_URL, tmp_path)
            os.replace(tmp_path, data_path)
        finally:
            # only still present if the download or the rename failed
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # load dataset
    # NOTE(security): pickle.load can execute arbitrary code; this trusts the
    # file fetched from PENDULUM_URL (or whatever already sits at data_path).
    with open(data_path, "rb") as f:
        observations, actions, rewards, terminals = pickle.load(f)

    # environment
    env = gym.make("Pendulum-v0")

    # no discrete_action flag here, unlike the cartpole dataset
    dataset = MDPDataset(
        observations=np.array(observations, dtype=np.float32),
        actions=actions,
        rewards=rewards,
        terminals=terminals,
    )

    return dataset, env
def get_pybullet(env_name: str) -> Tuple[MDPDataset, gym.Env]:
    """Returns pybullet dataset and environment.

    The dataset is provided through d4rl-pybullet. See more details including
    available dataset from its GitHub page.

    .. code-block:: python

        from d3rlpy.datasets import get_pybullet

        dataset, env = get_pybullet('hopper-bullet-mixed-v0')

    References:
        * https://github.com/takuseno/d4rl-pybullet

    Args:
        env_name: environment id of d4rl-pybullet dataset.

    Returns:
        tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    Raises:
        ImportError: if d4rl-pybullet is not installed.

    """
    try:
        # The module is not referenced below; it is imported for its side
        # effects (presumably registering the d4rl-pybullet environment ids
        # with gym so that gym.make can find them).
        import d4rl_pybullet  # type: ignore
    except ImportError as e:
        raise ImportError(
            "d4rl-pybullet is not installed.\n"
            "pip install git+https://github.com/takuseno/d4rl-pybullet"
        ) from e

    # Kept outside the ``try`` above so that an ImportError raised while
    # building the environment or its dataset propagates unchanged instead of
    # being misreported as a missing d4rl-pybullet package.
    env = gym.make(env_name)
    dataset = MDPDataset(**env.get_dataset())
    return dataset, env
def get_atari(env_name: str) -> Tuple[MDPDataset, gym.Env]:
    """Returns atari dataset and environment.

    The dataset is provided through d4rl-atari. See more details including
    available dataset from its GitHub page.

    .. code-block:: python

        from d3rlpy.datasets import get_atari

        dataset, env = get_atari('breakout-mixed-v0')

    References:
        * https://github.com/takuseno/d4rl-atari

    Args:
        env_name: environment id of d4rl-atari dataset.

    Returns:
        tuple of :class:`d3rlpy.dataset.MDPDataset` and gym environment.

    Raises:
        ImportError: if d4rl-atari is not installed.

    """
    try:
        # The module is not referenced below; it is imported for its side
        # effects (presumably registering the d4rl-atari environment ids with
        # gym so that gym.make can find them).
        import d4rl_atari  # type: ignore
    except ImportError as e:
        raise ImportError(
            "d4rl-atari is not installed.\n"
            "pip install git+https://github.com/takuseno/d4rl-atari"
        ) from e

    # Kept outside the ``try`` above so that an ImportError raised while
    # building the environment or its dataset propagates unchanged instead of
    # being misreported as a missing d4rl-atari package.
    env = gym.make(env_name, stack=False)
    dataset = MDPDataset(discrete_action=True, **env.get_dataset())
    return dataset, env