d3rlpy.augmentation.pipeline.DrQPipeline

class d3rlpy.augmentation.pipeline.DrQPipeline(augmentations=None, n_mean=1)

Data-regularized Q (DrQ) augmentation pipeline.

Parameters
  • augmentations (list(d3rlpy.augmentation.base.Augmentation or str)) – list of augmentations or augmentation types.

  • n_mean (int) – the number of times the computation is repeated on freshly augmented inputs; the results are averaged.
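The augmentations parameter accepts either augmentation objects or augmentation type strings. The sketch below shows one way such a mixed list could be normalized. It is a pure-Python illustration, not d3rlpy code: HypotheticalAugmentation, HypotheticalRandomShift, REGISTRY and resolve_augmentations are hypothetical names, and treating None as an empty pipeline is an assumption based on the default value in the signature.

```python
from typing import Dict, List, Optional, Type, Union


class HypotheticalAugmentation:
    """Illustrative stand-in for d3rlpy.augmentation.base.Augmentation."""

    TYPE = "base"

    def get_type(self) -> str:
        return self.TYPE


class HypotheticalRandomShift(HypotheticalAugmentation):
    TYPE = "random_shift"


# Hypothetical lookup table from type strings to augmentation classes.
REGISTRY: Dict[str, Type[HypotheticalAugmentation]] = {
    "random_shift": HypotheticalRandomShift,
}

AugLike = Union[HypotheticalAugmentation, str]


def resolve_augmentations(
    augmentations: Optional[List[AugLike]] = None,
) -> List[HypotheticalAugmentation]:
    # Assumption: None is treated as an empty pipeline.
    if augmentations is None:
        return []
    resolved = []
    for aug in augmentations:
        # Type strings become new instances; objects are kept as-is.
        resolved.append(REGISTRY[aug]() if isinstance(aug, str) else aug)
    return resolved


augs = resolve_augmentations(["random_shift", HypotheticalAugmentation()])
print([a.get_type() for a in augs])  # ['random_shift', 'base']
```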

Methods

append(augmentation)

Append augmentation to pipeline.

Parameters

augmentation (d3rlpy.augmentation.base.Augmentation) – augmentation.

Return type

None

get_augmentation_params()

Returns augmentation parameters.

Parameters

deep – flag to deeply copy objects.

Returns

list of augmentation parameters.

Return type

List[Dict[str, Any]]

get_augmentation_types()

Returns augmentation types.

Returns

list of augmentation types.

Return type

List[str]
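get_augmentation_types() and get_augmentation_params() return parallel lists, which is enough information to record a pipeline and rebuild it later. The following is a minimal sketch of that round trip. It assumes each augmentation exposes get_type() and get_params() and can be rebuilt from its type plus its parameters. HypotheticalCutout, REGISTRY, dump and load are illustrative names, not d3rlpy API.

```python
import json
from typing import Any, Dict, List


class HypotheticalCutout:
    """Illustrative augmentation with an assumed get_type()/get_params() interface."""

    TYPE = "cutout"

    def __init__(self, probability: float = 0.5) -> None:
        self.probability = probability

    def get_type(self) -> str:
        return self.TYPE

    def get_params(self) -> Dict[str, Any]:
        return {"probability": self.probability}


# Hypothetical lookup from type string to class.
REGISTRY = {"cutout": HypotheticalCutout}


def dump(augmentations: List[HypotheticalCutout]) -> str:
    # Parallel lists, mirroring get_augmentation_types() and get_augmentation_params().
    types = [aug.get_type() for aug in augmentations]
    params = [aug.get_params() for aug in augmentations]
    return json.dumps({"types": types, "params": params})


def load(payload: str) -> List[HypotheticalCutout]:
    data = json.loads(payload)
    # Pair each type with its parameters and rebuild the augmentation.
    return [REGISTRY[t](**p) for t, p in zip(data["types"], data["params"])]


restored = load(dump([HypotheticalCutout(probability=0.3)]))
print(restored[0].get_type(), restored[0].get_params())  # cutout {'probability': 0.3}
```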

get_params(deep=False)

Returns pipeline parameters.

Parameters

deep (bool) – flag to deeply copy objects.

Returns

pipeline parameters.

Return type

Dict[str, Any]
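The deep flag controls whether callers receive independent copies or references to the pipeline's internal objects. The sketch below illustrates the difference with copy.deepcopy. HypotheticalPipeline is not the d3rlpy class, and the "augmentations" and "n_mean" keys are assumptions: the documentation above only says the result is a Dict[str, Any].

```python
import copy
from typing import Any, Dict, List


class HypotheticalPipeline:
    """Illustrative stand-in; not the d3rlpy implementation."""

    def __init__(self, augmentations: List[Any], n_mean: int = 1) -> None:
        self._augmentations = augmentations
        self._n_mean = n_mean

    def get_params(self, deep: bool = False) -> Dict[str, Any]:
        # Assumed keys, chosen for illustration only.
        params = {"augmentations": self._augmentations, "n_mean": self._n_mean}
        # deep=True returns independent copies, so mutating the result
        # cannot change the pipeline's own state.
        return copy.deepcopy(params) if deep else params


pipe = HypotheticalPipeline([{"type": "cutout"}], n_mean=2)
shallow_params = pipe.get_params()
deep_params = pipe.get_params(deep=True)
# The shallow result shares the internal list, so this mutates the pipeline.
shallow_params["augmentations"].append({"type": "flip"})
print(len(pipe._augmentations), len(deep_params["augmentations"]))  # 2 1
```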

process(func, inputs, targets)

Runs a given function on the inputs, augmenting the arguments named in targets, and returns the result averaged over n_mean runs.

Parameters
  • func (Callable[..., torch.Tensor]) – function to compute.

  • inputs (Dict[str, torch.Tensor]) – inputs to func, keyed by argument name.

  • targets (List[str]) – list of argument names to augment.

Returns

the computation result.

Return type

torch.Tensor
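The averaging that process performs can be modeled in a few lines. The following is a pure-Python sketch under stated assumptions: floats stand in for torch tensors, each of the n_mean runs draws a fresh augmentation, only the arguments listed in targets are augmented, and the caller's inputs are left unmodified. drq_process, jitter and q_value are hypothetical names; this mirrors the documented behavior and is not the library's source.

```python
import random
from typing import Callable, Dict, List


def drq_process(
    func: Callable[..., float],
    inputs: Dict[str, float],
    targets: List[str],
    transform: Callable[[float], float],
    n_mean: int = 1,
) -> float:
    """Average func over n_mean independently augmented copies of inputs."""
    total = 0.0
    for _ in range(n_mean):
        # Copy so the caller's inputs are never modified.
        kwargs = dict(inputs)
        # Augment only the arguments named in targets.
        for name in targets:
            kwargs[name] = transform(kwargs[name])
        total += func(**kwargs)
    return total / n_mean


rng = random.Random(0)


def jitter(x: float) -> float:
    # Hypothetical augmentation: add a small random offset.
    return x + rng.uniform(-0.1, 0.1)


def q_value(obs: float, action: float) -> float:
    # Toy stand-in for a Q-function.
    return 2.0 * obs + action


print(drq_process(q_value, {"obs": 1.0, "action": 0.5}, ["obs"], jitter, n_mean=4))
```

Larger n_mean values reduce the variance that random augmentation introduces into the estimate, at the cost of n_mean calls to func.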

transform(x)

Returns the observation after it has been processed by every augmentation in the pipeline.

Parameters

x (torch.Tensor) – observation tensor.

Returns

processed observation tensor.

Return type

torch.Tensor
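Because transform applies every augmentation, the order of the pipeline can matter when augmentations do not commute. The sketch below assumes the augmentations are applied one after another in list order, each consuming the previous output. Plain lists of floats stand in for tensors, and add_one and double are hypothetical augmentations.

```python
from typing import Callable, List, Sequence

Augmentation = Callable[[List[float]], List[float]]


def transform(x: List[float], augmentations: Sequence[Augmentation]) -> List[float]:
    # Apply each augmentation to the output of the previous one, in list order.
    for aug in augmentations:
        x = aug(x)
    return x


def add_one(x: List[float]) -> List[float]:
    return [v + 1.0 for v in x]


def double(x: List[float]) -> List[float]:
    return [v * 2.0 for v in x]


print(transform([1.0, 2.0], [add_one, double]))  # [4.0, 6.0]
print(transform([1.0, 2.0], [double, add_one]))  # [3.0, 5.0]
```

An empty pipeline returns the observation unchanged.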

Attributes

augmentations