d3rlpy - A data-driven deep reinforcement learning library as an out-of-the-box tool.

d3rlpy is an easy-to-use data-driven deep reinforcement learning library.

d3rlpy provides state-of-the-art data-driven deep reinforcement learning algorithms through out-of-the-box scikit-learn-style APIs. Unlike other RL libraries, the provided algorithms can achieve performance beyond that reported in the original papers thanks to several tweaks.

Getting Started

This tutorial is also available on Google Colaboratory

Install

First of all, let’s install d3rlpy on your machine:

$ pip install d3rlpy

Note

d3rlpy supports Python 3.6+. Make sure to check which Python version you are using.

Note

If you use a GPU, please set up CUDA first.

Prepare Dataset

You can make your own dataset without much effort. In this tutorial, let's start with the built-in datasets. If you want to build a new dataset, see MDPDataset.

d3rlpy provides suites of datasets for testing algorithms and for research. See more documentation at Datasets.

from d3rlpy.datasets import get_cartpole # CartPole-v0 dataset
from d3rlpy.datasets import get_pendulum # Pendulum-v0 dataset
from d3rlpy.datasets import get_pybullet # PyBullet task datasets
from d3rlpy.datasets import get_atari    # Atari 2600 task datasets

Here, we use the CartPole dataset to instantly check training results.

dataset, env = get_cartpole()

One interesting feature of d3rlpy is full compatibility with scikit-learn utilities. You can split the dataset into training and test datasets just as in supervised learning, as follows.

from sklearn.model_selection import train_test_split

train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

Setup Algorithm

There are many algorithms available in d3rlpy. Since CartPole is a simple task, let's start with DQN, the Q-learning algorithm proposed as the first deep reinforcement learning algorithm.

from d3rlpy.algos import DQN

# if you don't use GPU, set use_gpu=False instead.
dqn = DQN(use_gpu=True)

See more algorithms and configurations at Algorithms.

Setup Metrics

Collecting evaluation metrics is important to train algorithms properly. In d3rlpy, metrics are computed through scikit-learn-style scorer functions.

from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer

# calculate metrics with test dataset
td_error = td_error_scorer(dqn, test_episodes)

Since evaluating algorithms without access to the environment is still difficult, the algorithm can also be evaluated directly with the evaluate_on_environment function if an environment is available for interaction.

from d3rlpy.metrics.scorer import evaluate_on_environment

# set environment in scorer function
evaluate_scorer = evaluate_on_environment(env)

# evaluate algorithm on the environment
rewards = evaluate_scorer(dqn)

See more metrics and configurations at Metrics.

Start Training

Now you have everything you need to start data-driven training.

dqn.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=10,
        scorers={
            'td_error': td_error_scorer,
            'value_scale': average_value_estimation_scorer,
            'environment': evaluate_scorer
        })

Then, you will see the training progress in the console as shown below:

augmentation=[]
batch_size=32
bootstrap=False
dynamics=None
encoder_params={}
eps=0.00015
gamma=0.99
learning_rate=6.25e-05
n_augmentations=1
n_critics=1
n_frames=1
q_func_type=mean
scaler=None
share_encoder=False
target_update_interval=8000.0
use_batch_norm=True
use_gpu=None
observation_shape=(4,)
action_size=2
100%|███████████████████████████████████| 2490/2490 [00:24<00:00, 100.63it/s]
epoch=0 step=2490 value_loss=0.190237
epoch=0 step=2490 td_error=1.483964
epoch=0 step=2490 value_scale=1.241220
epoch=0 step=2490 environment=157.400000
100%|███████████████████████████████████| 2490/2490 [00:24<00:00, 100.63it/s]
.
.
.

See more about logging at Logging.

Once the training is done, your algorithm is ready to make decisions.

observation = env.reset()

# return actions based on the greedy-policy
action = dqn.predict([observation])[0]

# estimate action-values
value = dqn.predict_value([observation], [action])[0]

Save and Load

d3rlpy provides several ways to save trained models.

# save full parameters
dqn.save_model('dqn.pt')

# load full parameters
dqn2 = DQN()
dqn2.load_model('dqn.pt')

# save the greedy-policy as TorchScript
dqn.save_policy('policy.pt')

# save the greedy-policy as ONNX
dqn.save_policy('policy.onnx', as_onnx=True)

See more information at Save and Load.

API Reference

Algorithms

d3rlpy provides state-of-the-art data-driven deep reinforcement learning algorithms as well as online algorithms for the base implementations.

Continuous control algorithms

d3rlpy.algos.BC Behavior Cloning algorithm.
d3rlpy.algos.DDPG Deep Deterministic Policy Gradients algorithm.
d3rlpy.algos.TD3 Twin Delayed Deep Deterministic Policy Gradients algorithm.
d3rlpy.algos.SAC Soft Actor-Critic algorithm.
d3rlpy.algos.BCQ Batch-Constrained Q-learning algorithm.
d3rlpy.algos.BEAR Bootstrapping Error Accumulation Reduction algorithm.
d3rlpy.algos.CQL Conservative Q-Learning algorithm.
d3rlpy.algos.AWR Advantage-Weighted Regression algorithm.
d3rlpy.algos.AWAC Advantage Weighted Actor-Critic algorithm.

Discrete control algorithms

d3rlpy.algos.DiscreteBC Behavior Cloning algorithm for discrete control.
d3rlpy.algos.DQN Deep Q-Network algorithm.
d3rlpy.algos.DoubleDQN Double Deep Q-Network algorithm.
d3rlpy.algos.DiscreteBCQ Discrete version of Batch-Constrained Q-learning algorithm.
d3rlpy.algos.DiscreteCQL Discrete version of Conservative Q-Learning algorithm.
d3rlpy.algos.DiscreteAWR Discrete version of Advantage-Weighted Regression algorithm.

Q Functions

d3rlpy provides various Q functions, including state-of-the-art ones, which are used internally in algorithm objects. You can switch Q functions by passing the q_func_type argument at algorithm initialization.

from d3rlpy.algos import CQL

cql = CQL(q_func_type='qr') # use Quantile Regression Q function

The default Q function is the mean approximator, which estimates expected scalar action-values. However, recent advances in deep reinforcement learning have introduced a new type of action-value approximator called the distributional Q function.

Unlike the mean approximator, distributional Q functions estimate the distribution of action-values. These distributional approaches have consistently shown much stronger performance than the mean approximator.

Here is a list of available Q functions in ascending order of performance. Currently, there is a trade-off between performance and computational complexity: higher performance comes at a higher computational cost.

Available Q functions:

q_func_type          reference
-------------------  --------------------------------------
mean (default)       N/A
qr                   Quantile Regression
iqn                  Implicit Quantile Network
fqf (experimental)   Fully-parameterized Quantile Function
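
For example, switching to the Implicit Quantile Network Q function is a one-line change (a minimal sketch; all other hyperparameters are left at their defaults):

from d3rlpy.algos import DQN

# use the Implicit Quantile Network Q function instead of the mean approximator
dqn = DQN(q_func_type='iqn')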

MDPDataset

d3rlpy provides a useful dataset structure for data-driven deep reinforcement learning. In supervised learning, the training script iterates over input data \(X\) and label data \(Y\). In reinforcement learning, however, mini-batches consist of tuples \((s_t, a_t, r_{t+1}, s_{t+1})\) together with episode terminal flags. Converting a flat sequence of observations, actions, rewards and terminal flags into these tuples is tedious and requires some coding.

Therefore, d3rlpy provides the MDPDataset class, which lets you handle reinforcement learning datasets without any effort.

import numpy as np

from d3rlpy.dataset import MDPDataset

# 1000 steps of observations with shape of (100,)
observations = np.random.random((1000, 100))
# 1000 steps of actions with shape of (4,)
actions = np.random.random((1000, 4))
# 1000 steps of rewards
rewards = np.random.random(1000)
# 1000 steps of terminal flags
terminals = np.random.randint(2, size=1000)

dataset = MDPDataset(observations, actions, rewards, terminals)

# automatically split into d3rlpy.dataset.Episode objects
dataset.episodes

# each episode is also split into d3rlpy.dataset.Transition objects
episode = dataset.episodes[0]
episode[0].observation
episode[0].action
episode[0].next_reward
episode[0].next_observation
episode[0].terminal

# each d3rlpy.dataset.Transition object has pointers to the previous and
# next transitions, like a linked list.
transition = episode[0]
while transition.next_transition:
    transition = transition.next_transition

# save as HDF5
dataset.dump('dataset.h5')

# load from HDF5
new_dataset = MDPDataset.load('dataset.h5')

d3rlpy.dataset.MDPDataset Markov-Decision Process Dataset class.
d3rlpy.dataset.Episode Episode class.
d3rlpy.dataset.Transition Transition class.
d3rlpy.dataset.TransitionMiniBatch mini-batch of Transition objects.
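
TransitionMiniBatch stacks Transition objects into batched arrays for training. A minimal sketch, assuming an Episode exposes its transitions via the transitions attribute:

from d3rlpy.dataset import TransitionMiniBatch

# build a mini-batch from the transitions of the first episode
episode = dataset.episodes[0]
batch = TransitionMiniBatch(episode.transitions)

# batched arrays stacked along the first axis (attribute names assumed)
print(batch.observations.shape)
print(batch.actions.shape)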

Datasets

d3rlpy provides datasets for experimenting with data-driven deep reinforcement learning algorithms.

d3rlpy.datasets.get_cartpole Returns cartpole dataset and environment.
d3rlpy.datasets.get_pendulum Returns pendulum dataset and environment.
d3rlpy.datasets.get_pybullet Returns pybullet dataset and environment.
d3rlpy.datasets.get_atari Returns atari dataset and environment.
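
get_pybullet and get_atari take a dataset name. The identifiers below are illustrative examples; the available names depend on the installed d4rl-pybullet and d4rl-atari packages:

from d3rlpy.datasets import get_pybullet, get_atari

# dataset names are examples; check the d4rl-pybullet / d4rl-atari docs
# for the identifiers available in your environment
dataset, env = get_pybullet('hopper-bullet-mixed-v0')
dataset, env = get_atari('breakout-mixed-v0')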

Preprocessing

d3rlpy provides several preprocessors tightly incorporated with algorithms. Each preprocessor is implemented with PyTorch operations, so it is included in the model exported by the save_policy method.

import torch

from d3rlpy.algos import CQL
from d3rlpy.dataset import MDPDataset

dataset = MDPDataset(...)

# choose from ['pixel', 'min_max', 'standard'] or None
cql = CQL(scaler='standard')

# scaler is fitted from the given episodes
cql.fit(dataset.episodes)

# preprocessing is included in TorchScript
cql.save_policy('policy.pt')

# you don't need to take care of preprocessing in production
policy = torch.jit.load('policy.pt')
action = policy(unpreprocessed_x)

You can also initialize scalers yourself.

from d3rlpy.preprocessing import StandardScaler

scaler = StandardScaler(mean=..., std=...)

cql = CQL(scaler=scaler)

d3rlpy.preprocessing.PixelScaler Pixel normalization preprocessing.
d3rlpy.preprocessing.MinMaxScaler Min-Max normalization preprocessing.
d3rlpy.preprocessing.StandardScaler Standardization preprocessing.

Data Augmentation

d3rlpy provides data augmentation techniques tightly integrated with reinforcement learning algorithms.

  1. Kostrikov et al., Image Augmentation Is All You Need: Regularizing Deep Reinforcement Learning from Pixels.
  2. Laskin et al., Reinforcement Learning with Augmented Data.

Efficient data augmentation can potentially boost algorithm performance significantly.

from d3rlpy.algos import DiscreteCQL

# choose data augmentation types
cql = DiscreteCQL(augmentation=['random_shift', 'intensity'],
                  n_augmentations=2)

You can also tune the data augmentation parameters yourself.

from d3rlpy.augmentation.image import RandomShift

random_shift = RandomShift(shift_size=10)

cql = DiscreteCQL(augmentation=[random_shift, 'intensity'],
                  n_augmentations=2)

Image Observation

d3rlpy.augmentation.image.RandomShift Random shift augmentation.
d3rlpy.augmentation.image.Cutout Cutout augmentation.
d3rlpy.augmentation.image.HorizontalFlip Horizontal flip augmentation.
d3rlpy.augmentation.image.VerticalFlip Vertical flip augmentation.
d3rlpy.augmentation.image.RandomRotation Random rotation augmentation.
d3rlpy.augmentation.image.Intensity Intensity augmentation.
d3rlpy.augmentation.image.ColorJitter Color Jitter augmentation.

Vector Observation

d3rlpy.augmentation.vector.SingleAmplitudeScaling Single Amplitude Scaling augmentation.
d3rlpy.augmentation.vector.MultipleAmplitudeScaling Multiple Amplitude Scaling augmentation.
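
Vector augmentations can be passed in the augmentation list just like the image ones. A minimal sketch with default constructor arguments (the exact parameter names for the scaling range are not shown here and may differ):

from d3rlpy.algos import CQL
from d3rlpy.augmentation.vector import SingleAmplitudeScaling

# apply amplitude scaling to vector observations
cql = CQL(augmentation=[SingleAmplitudeScaling()], n_augmentations=2)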

Metrics

d3rlpy provides scoring functions without compromising scikit-learn compatibility. You can evaluate many metrics with test episodes during training.

from d3rlpy.datasets import get_cartpole
from d3rlpy.algos import DQN
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from d3rlpy.metrics.scorer import evaluate_on_environment
from sklearn.model_selection import train_test_split

dataset, env = get_cartpole()

train_episodes, test_episodes = train_test_split(dataset)

dqn = DQN()

dqn.fit(train_episodes,
        eval_episodes=test_episodes,
        scorers={
            'td_error': td_error_scorer,
            'value_scale': average_value_estimation_scorer,
            'environment': evaluate_on_environment(env)
        })

You can also use them with scikit-learn utilities.

from sklearn.model_selection import cross_validate

scores = cross_validate(dqn,
                        dataset,
                        scoring={
                            'td_error': td_error_scorer,
                            'environment': evaluate_on_environment(env)
                        })

Algorithms

d3rlpy.metrics.scorer.td_error_scorer Returns average TD error (in negative scale).
d3rlpy.metrics.scorer.discounted_sum_of_advantage_scorer Returns average of discounted sum of advantage (in negative scale).
d3rlpy.metrics.scorer.average_value_estimation_scorer Returns average value estimation (in negative scale).
d3rlpy.metrics.scorer.value_estimation_std_scorer Returns standard deviation of value estimation (in negative scale).
d3rlpy.metrics.scorer.initial_state_value_estimation_scorer Returns mean estimated action-values at the initial states.
d3rlpy.metrics.scorer.soft_opc_scorer Returns Soft Off-Policy Classification metrics.
d3rlpy.metrics.scorer.continuous_action_diff_scorer Returns squared difference of actions between algorithm and dataset.
d3rlpy.metrics.scorer.discrete_action_match_scorer Returns percentage of identical actions between algorithm and dataset.
d3rlpy.metrics.scorer.evaluate_on_environment Returns scorer function of evaluation on environment.
d3rlpy.metrics.comparer.compare_continuous_action_diff Returns scorer function of action difference between algorithms.
d3rlpy.metrics.comparer.compare_discrete_action_match Returns scorer function of action matches between algorithms.
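
The comparer functions take a reference algorithm and return a scorer that measures how closely another algorithm's actions match it. A minimal sketch, assuming base_algo and another_algo are trained algorithms and test_episodes is a list of episodes:

from d3rlpy.metrics.comparer import compare_continuous_action_diff

# build a scorer that compares actions against the reference algorithm
compare_scorer = compare_continuous_action_diff(base_algo)

# evaluate another algorithm against the reference on test episodes
score = compare_scorer(another_algo, test_episodes)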

Dynamics

d3rlpy.metrics.scorer.dynamics_observation_prediction_error_scorer Returns MSE of observation prediction (in negative scale).
d3rlpy.metrics.scorer.dynamics_reward_prediction_error_scorer Returns MSE of reward prediction (in negative scale).
d3rlpy.metrics.scorer.dynamics_prediction_variance_scorer Returns prediction variance of ensemble dynamics (in negative scale).

Save and Load

save_model and load_model

from d3rlpy.datasets import get_cartpole
from d3rlpy.algos import DQN

dataset, env = get_cartpole()

dqn = DQN()
dqn.fit(dataset.episodes, n_epochs=1)

# save entire model parameters.
dqn.save_model('model.pt')

# load entire model parameters.
dqn.load_model('model.pt')

The save_model method saves all parameters including optimizer states, which is useful when inspecting all the outputs or re-training from snapshots.
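
For example, a training run can be paused and resumed from a snapshot (a minimal sketch; the file name and epoch counts are only illustrative):

from d3rlpy.datasets import get_cartpole
from d3rlpy.algos import DQN

dataset, _ = get_cartpole()

# train for a while and take a snapshot
dqn = DQN()
dqn.fit(dataset.episodes, n_epochs=1)
dqn.save_model('snapshot.pt')

# later: restore parameters and optimizer states, then continue training
dqn.load_model('snapshot.pt')
dqn.fit(dataset.episodes, n_epochs=1)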

from_json

It is tedious to set the same hyperparameters again to initialize an algorithm when loading model parameters. In d3rlpy, params.json is saved at the beginning of the fit method and includes all hyperparameters of the algorithm object. You can recreate algorithm objects from params.json via the from_json method.

from d3rlpy.algos import DQN

dqn = DQN.from_json('d3rlpy_logs/<path-to-json>/params.json')

# ready to load
dqn.load_model('model.pt')

save_policy

The save_policy method saves only the greedy-policy computation graph as TorchScript or ONNX. When save_policy is called, the greedy-policy graph is constructed and traced via the torch.jit.trace function.

from d3rlpy.datasets import get_cartpole
from d3rlpy.algos import DQN

dataset, env = get_cartpole()

dqn = DQN()
dqn.fit(dataset.episodes, n_epochs=1)

# save greedy-policy as TorchScript
dqn.save_policy('policy.pt')

# save greedy-policy as ONNX
dqn.save_policy('policy.onnx', as_onnx=True)

TorchScript

TorchScript is an optimizable graph representation provided by PyTorch. The saved policy can be loaded without any dependencies except PyTorch.

import torch

# load greedy-policy only with PyTorch
policy = torch.jit.load('policy.pt')

# returns greedy actions
actions = policy(torch.rand(32, 6))

This is especially useful when deploying trained models to production. The computation can be faster and you don't need to install d3rlpy. Moreover, a TorchScript model can easily be loaded even from C++, which will empower your robotics and embedded systems projects.

#include <torch/script.h>

int main(int argc, char* argv[]) {
  torch::jit::script::Module module;
  try {
    module = torch::jit::load("policy.pt");
  } catch (const c10::Error& e) {
    return -1;
  }
  return 0;
}

You can get more information about TorchScript here.

ONNX

ONNX is an open format built to represent machine learning models. This is also useful when deploying the trained model to production with various programming languages including Python, C++, JavaScript and more.

The following example is written with onnxruntime.

import numpy as np
import onnxruntime as ort

# load ONNX policy via onnxruntime
ort_session = ort.InferenceSession('policy.onnx')

# observation
observation = np.random.rand(1, 6).astype(np.float32)

# returns greedy action
action = ort_session.run(None, {'input_0': observation})[0]

You can get more information about ONNX here.

Logging

d3rlpy algorithms automatically save model parameters and metrics under d3rlpy_logs directory.

from d3rlpy.datasets import get_cartpole
from d3rlpy.algos import DQN

dataset, env = get_cartpole()

dqn = DQN()

# metrics and parameters are saved in `d3rlpy_logs/DQN_YYYYMMDDHHmmss`
dqn.fit(dataset.episodes)

You can also specify the directory.

# the directory will be `custom_logs/custom_YYYYMMDDHHmmss`
dqn.fit(dataset.episodes, logdir='custom_logs', experiment_name='custom')

If you want to disable all logging, you can pass save_metrics=False.

dqn.fit(dataset.episodes, save_metrics=False)

TensorBoard

The same information is also automatically saved for TensorBoard under the runs directory, so you can easily visualize training metrics interactively.

$ pip install tensorboard
$ tensorboard --logdir runs

TensorBoard logging can be disabled by passing tensorboard=False.

dqn.fit(dataset.episodes, tensorboard=False)

scikit-learn compatibility

d3rlpy provides complete scikit-learn compatible APIs.

train_test_split

d3rlpy.dataset.MDPDataset is compatible with splitting functions in scikit-learn.

from d3rlpy.algos import DQN
from d3rlpy.datasets import get_cartpole
from d3rlpy.metrics.scorer import td_error_scorer
from sklearn.model_selection import train_test_split

dataset, env = get_cartpole()

train_episodes, test_episodes = train_test_split(dataset, test_size=0.2)

dqn = DQN()
dqn.fit(train_episodes,
        eval_episodes=test_episodes,
        n_epochs=1,
        scorers={'td_error': td_error_scorer})

cross_validate

Cross validation is also easily performed.

from d3rlpy.algos import DQN
from d3rlpy.datasets import get_cartpole
from d3rlpy.metrics import td_error_scorer
from sklearn.model_selection import cross_validate

dataset, env = get_cartpole()

dqn = DQN()

scores = cross_validate(dqn,
                        dataset,
                        scoring={'td_error': td_error_scorer},
                        fit_params={'n_epochs': 1})

GridSearchCV

You can also perform grid search to find good hyperparameters.

from d3rlpy.algos import DQN
from d3rlpy.datasets import get_cartpole
from d3rlpy.metrics import td_error_scorer
from sklearn.model_selection import GridSearchCV

dataset, env = get_cartpole()

dqn = DQN()

gscv = GridSearchCV(estimator=dqn,
                    param_grid={'learning_rate': [1e-4, 3e-4, 1e-3]},
                    scoring={'td_error': td_error_scorer},
                    refit=False)

gscv.fit(dataset.episodes, n_epochs=1)

parallel execution with multiple GPUs

Some scikit-learn utilities provide an n_jobs option, which enables the fitting process to run in parallel to boost productivity. Ideally, if you have multiple GPUs, the processes should use different GPUs for computational efficiency.

d3rlpy provides a special device-assignment mechanism to realize this.

from d3rlpy.algos import DQN
from d3rlpy.datasets import get_cartpole
from d3rlpy.metrics import td_error_scorer
from d3rlpy.context import parallel
from sklearn.model_selection import cross_validate

dataset, env = get_cartpole()

# enable GPU
dqn = DQN(use_gpu=True)

# automatically assign different GPUs for the 4 processes.
with parallel():
    scores = cross_validate(dqn,
                            dataset,
                            scoring={'td_error': td_error_scorer},
                            fit_params={'n_epochs': 1},
                            n_jobs=4)

If use_gpu=True is passed, d3rlpy internally manages the GPU device id via a d3rlpy.gpu.Device object. This object is designed for scikit-learn's multi-process implementation, which makes deep copies of the estimator object before dispatching. The Device object increments its device id when deep-copied inside the parallel context.

import copy
from d3rlpy.context import parallel
from d3rlpy.gpu import Device

device = Device(0)
# device.get_id() == 0

new_device = copy.deepcopy(device)
# new_device.get_id() == 0

with parallel():
    new_device = copy.deepcopy(device)
    # new_device.get_id() == 1
    # device.get_id() == 1

    new_device = copy.deepcopy(device)
    # if you have only 2 GPUs, it goes back to 0.
    # new_device.get_id() == 0
    # device.get_id() == 0

from d3rlpy.algos import DQN

dqn = DQN(use_gpu=Device(0)) # assign id=0
dqn = DQN(use_gpu=Device(1)) # assign id=1

Online Training

d3rlpy provides not only offline training, but also online training utilities. Despite being designed for offline training algorithms, d3rlpy is flexible enough to be trained in an online manner with a few more utilities.

import gym

from d3rlpy.algos import DQN
from d3rlpy.online.buffers import ReplayBuffer
from d3rlpy.online.explorers import LinearDecayEpsilonGreedy

# setup environment
env = gym.make('CartPole-v0')
eval_env = gym.make('CartPole-v0')

# setup algorithm
dqn = DQN(batch_size=32,
          learning_rate=2.5e-4,
          target_update_interval=100,
          use_gpu=True)

# setup replay buffer
buffer = ReplayBuffer(maxlen=1000000, env=env)

# setup explorers
explorer = LinearDecayEpsilonGreedy(start_epsilon=1.0,
                                    end_epsilon=0.1,
                                    duration=10000)

# start training
dqn.fit_online(env,
               buffer,
               explorer=explorer, # you don't need this with probabilistic policy algorithms
               eval_env=eval_env,
               n_epochs=30,
               n_steps_per_epoch=1000,
               n_updates_per_epoch=100)

Replay Buffer

d3rlpy.online.buffers.ReplayBuffer Standard Replay Buffer.

Explorers

d3rlpy.online.explorers.LinearDecayEpsilonGreedy \(\epsilon\)-greedy explorer with linear decay schedule.
d3rlpy.online.explorers.NormalNoise Normal noise explorer.
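
For continuous control, NormalNoise adds Gaussian noise to the deterministic policy output during exploration. A minimal sketch (the mean and std parameter names are an assumption; check the API reference for the exact signature):

import gym

from d3rlpy.algos import TD3
from d3rlpy.online.buffers import ReplayBuffer
from d3rlpy.online.explorers import NormalNoise

env = gym.make('Pendulum-v0')

td3 = TD3()
buffer = ReplayBuffer(maxlen=100000, env=env)

# explore with Gaussian noise around the greedy action (parameter names assumed)
explorer = NormalNoise(mean=0.0, std=0.1)

td3.fit_online(env, buffer, explorer=explorer, n_epochs=10)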

Iterators

d3rlpy.online.iterators.train Start training loop of online deep reinforcement learning.

Model-based Data Augmentation

d3rlpy provides model-based reinforcement learning algorithms. In d3rlpy, model-based algorithms are treated as data augmentation techniques, which can potentially boost performance beyond that of model-free algorithms.

from d3rlpy.datasets import get_pendulum
from d3rlpy.dynamics import MOPO
from d3rlpy.metrics.scorer import dynamics_observation_prediction_error_scorer
from d3rlpy.metrics.scorer import dynamics_reward_prediction_error_scorer
from d3rlpy.metrics.scorer import dynamics_prediction_variance_scorer
from sklearn.model_selection import train_test_split

dataset, _ = get_pendulum()

train_episodes, test_episodes = train_test_split(dataset)

mopo = MOPO(learning_rate=1e-4, use_gpu=True)

# same as algorithms
mopo.fit(train_episodes,
         eval_episodes=test_episodes,
         n_epochs=100,
         scorers={
            'observation_error': dynamics_observation_prediction_error_scorer,
            'reward_error': dynamics_reward_prediction_error_scorer,
            'variance': dynamics_prediction_variance_scorer,
         })

Pick the best dynamics model based on the evaluation metrics.

from d3rlpy.dynamics import MOPO
from d3rlpy.algos import CQL

# load trained dynamics model
mopo = MOPO.from_json('<path-to-params.json>/params.json')
mopo.load_model('<path-to-model>/model_xx.pt')
mopo.n_transitions = 400 # tunable parameter
mopo.horizon = 5 # tunable parameter
mopo.lam = 1.0 # tunable parameter

# give mopo as dynamics argument.
cql = CQL(dynamics=mopo)

If you pass a dynamics model to algorithms, new transitions are generated at the beginning of every epoch.
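
Continuing the example above, the combined training then looks like a normal offline run (a minimal sketch; the dataset and epoch count are only illustrative):

from d3rlpy.datasets import get_pendulum

dataset, _ = get_pendulum()

# CQL generates model-based rollouts from the dynamics model at the
# beginning of every epoch
cql.fit(dataset.episodes, n_epochs=100)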

d3rlpy.dynamics.mopo.MOPO Model-based Offline Policy Optimization.

Installation

Install d3rlpy

Install via PyPI

pip is the recommended way to install d3rlpy:

$ pip install d3rlpy

Install from source

You can also install from the GitHub repository:

$ git clone https://github.com/takuseno/d3rlpy
$ cd d3rlpy
$ pip install Cython numpy # if you have not installed them.
$ pip install -e .

License

MIT License

Copyright (c) 2020 Takuma Seno

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
