d3rlpy.models.IQNQFunctionFactory

class d3rlpy.models.IQNQFunctionFactory(share_encoder=False, n_quantiles=64, n_greedy_quantiles=32, embed_size=64)

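Implicit Quantile Network (IQN) Q function factory class. IQN Q functions model the distribution of returns by predicting quantile values at sampled quantile fractions.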
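IQN Q functions are usually fitted with the quantile Huber loss from the IQN paper (Dabney et al., 2018). The plain-Python sketch below writes that loss out for a single sample so the role of the quantile fraction τ is visible. The threshold `kappa = 1.0` and the reduction (sum over predicted fractions, mean over target samples) follow the paper and are assumptions here; d3rlpy's own reduction may differ, and the helper names are hypothetical.

```python
from typing import List


def huber(u: float, kappa: float = 1.0) -> float:
    """Huber function L_kappa(u): quadratic near zero, linear in the tails."""
    if abs(u) <= kappa:
        return 0.5 * u * u
    return kappa * (abs(u) - 0.5 * kappa)


def quantile_huber_loss(
    pred_quantiles: List[float],
    taus: List[float],
    target_samples: List[float],
    kappa: float = 1.0,
) -> float:
    """Paper-style IQN loss: sum over predicted fractions, mean over target samples."""
    total = 0.0
    for pred, tau in zip(pred_quantiles, taus):
        for target in target_samples:
            u = target - pred  # pairwise TD error
            # Under-estimates (u >= 0) are weighted by tau, over-estimates by 1 - tau.
            weight = abs(tau - (1.0 if u < 0 else 0.0))
            total += weight * huber(u, kappa) / kappa
    return total / len(target_samples)


# With tau = 0.9, predicting too low costs more than predicting too high.
low = quantile_huber_loss([0.0], [0.9], [1.0])
high = quantile_huber_loss([0.0], [0.9], [-1.0])
print(low > high)  # True
```

The asymmetric weight is what pushes the prediction for fraction τ toward the τ-quantile of the target distribution rather than toward its mean.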

Parameters
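  • share_encoder (bool) – whether to share a single encoder across multiple Q functions (for example, an ensemble of critics).

  • n_quantiles (int) – the number of quantile fractions sampled during training.

  • n_greedy_quantiles (int) – the number of quantile fractions used at inference, when computing Q values for action selection.

  • embed_size (int) – the size of the cosine embedding computed for each quantile fraction.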

Return type

None
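To show how `n_quantiles`, `n_greedy_quantiles`, and `embed_size` relate, the sketch below follows the scheme in the IQN paper: quantile fractions τ are drawn uniformly from [0, 1), and each τ is expanded into `embed_size` cosine features cos(π·i·τ). It is a standard-library mock-up with hypothetical helper names (`sample_taus`, `cosine_embedding`); d3rlpy may index the cosine terms or choose inference-time fractions differently.

```python
import math
import random
from typing import List


def sample_taus(n: int, rng: random.Random) -> List[float]:
    """Draw n quantile fractions uniformly from [0, 1) (hypothetical helper)."""
    return [rng.random() for _ in range(n)]


def cosine_embedding(tau: float, embed_size: int) -> List[float]:
    """Paper-style embedding cos(pi * i * tau) for i = 0, ..., embed_size - 1."""
    return [math.cos(math.pi * i * tau) for i in range(embed_size)]


rng = random.Random(0)
n_quantiles, n_greedy_quantiles, embed_size = 64, 32, 64  # the factory defaults

# Training: n_quantiles fractions per sample feed the quantile regression loss.
train_taus = sample_taus(n_quantiles, rng)
# Inference: n_greedy_quantiles fractions are averaged to estimate Q values.
greedy_taus = sample_taus(n_greedy_quantiles, rng)

# Each fraction becomes one embed_size-dimensional cosine feature vector.
train_features = [cosine_embedding(t, embed_size) for t in train_taus]
print(len(train_features), len(train_features[0]))  # 64 64
```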

Methods

create_continuous(encoder, hidden_size)

Returns a PyTorch Q function module for continuous action spaces, together with its forwarder.

Parameters
  • encoder (d3rlpy.models.torch.encoders.EncoderWithAction) – Encoder module that processes the observation and action to obtain feature representations.

  • hidden_size (int) – Dimension of the encoder output.

Returns

A tuple of the continuous Q function and its forwarder.

Return type

Tuple[d3rlpy.models.torch.q_functions.iqn_q_function.ContinuousIQNQFunction, d3rlpy.models.torch.q_functions.iqn_q_function.ContinuousIQNQFunctionForwarder]
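The continuous Q function maps one encoder feature per state-action pair to a set of quantile values. A minimal standard-library sketch of the IQN head described in the paper follows: the encoder feature `h` (length `hidden_size`) is multiplied element-wise by a ReLU projection of each τ's cosine embedding, a final linear layer produces one quantile value per τ, and their mean serves as the Q estimate. All weights, sizes, and helper names are hypothetical; this does not reproduce `ContinuousIQNQFunction` line for line.

```python
import math
import random
from typing import List, Tuple

Vector = List[float]
Matrix = List[Vector]


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """Dense layer y = W x + b; weight has shape (out_features, in_features)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def cosine_embedding(tau: float, embed_size: int) -> Vector:
    return [math.cos(math.pi * i * tau) for i in range(embed_size)]


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


def continuous_iqn_head(
    h: Vector, taus: Vector, embed_w: Matrix, embed_b: Vector, out_w: Matrix, out_b: Vector
) -> Tuple[float, Vector]:
    """Return (Q estimate, per-fraction quantile values) for one state-action feature h."""
    embed_size = len(embed_w[0])
    quantiles = []
    for tau in taus:
        # phi(tau) = ReLU(linear(cosine features)), projected to len(h) == hidden_size.
        phi = [max(0.0, v) for v in linear(cosine_embedding(tau, embed_size), embed_w, embed_b)]
        mixed = [hi * pi for hi, pi in zip(h, phi)]  # element-wise product h * phi(tau)
        quantiles.append(linear(mixed, out_w, out_b)[0])  # one scalar per fraction
    return sum(quantiles) / len(quantiles), quantiles


rng = random.Random(0)
hidden_size, embed_size, n_quantiles = 8, 16, 4

h = [rng.uniform(-1.0, 1.0) for _ in range(hidden_size)]  # stands in for encoder(obs, action)
embed_w, embed_b = random_matrix(hidden_size, embed_size, rng), [0.0] * hidden_size
out_w, out_b = random_matrix(1, hidden_size, rng), [0.0]
taus = [rng.random() for _ in range(n_quantiles)]

q_value, quantiles = continuous_iqn_head(h, taus, embed_w, embed_b, out_w, out_b)
```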

create_discrete(encoder, hidden_size, action_size)

Returns a PyTorch Q function module for discrete action spaces, together with its forwarder.

Parameters
  • encoder (d3rlpy.models.torch.encoders.Encoder) – Encoder module that processes the observation to obtain feature representations.

  • hidden_size (int) – Dimension of the encoder output.

  • action_size (int) – Size of the discrete action space (number of actions).

Returns

A tuple of the discrete Q function and its forwarder.

Return type

Tuple[d3rlpy.models.torch.q_functions.iqn_q_function.DiscreteIQNQFunction, d3rlpy.models.torch.q_functions.iqn_q_function.DiscreteIQNQFunctionForwarder]
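For discrete actions, the Q function yields one quantile value per action for each fraction τ, and a greedy policy typically averages over the `n_greedy_quantiles` fractions before taking the argmax, as in the IQN paper. The sketch below shows only that reduction on a hand-made table, with hypothetical helper names; it is not the code of `DiscreteIQNQFunctionForwarder`.

```python
from typing import List


def mean_q_values(quantile_table: List[List[float]]) -> List[float]:
    """Average each action's column over the sampled fractions (rows)."""
    n_taus = len(quantile_table)
    action_size = len(quantile_table[0])
    return [sum(row[a] for row in quantile_table) / n_taus for a in range(action_size)]


def greedy_action(quantile_table: List[List[float]]) -> int:
    """Index of the action with the largest mean quantile value."""
    q_values = mean_q_values(quantile_table)
    return max(range(len(q_values)), key=q_values.__getitem__)


# Rows: 3 sampled fractions (e.g. n_greedy_quantiles = 3); columns: action_size = 2.
table = [
    [0.0, 2.0],
    [1.0, 1.5],
    [5.0, 1.0],
]
print(mean_q_values(table))  # [2.0, 1.5]
print(greedy_action(table))  # 0
```

Action 1 has the better worst case, but action 0 has the higher mean, so a risk-neutral argmax over the averaged quantiles picks action 0.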

classmethod deserialize(serialized_config)

Parameters

serialized_config (str) – the serialized configuration string.

Return type

d3rlpy.serializable_config.TConfig

classmethod deserialize_from_dict(dict_config)

Parameters

dict_config (Dict[str, Any]) – the configuration as a dictionary.

Return type

d3rlpy.serializable_config.TConfig

classmethod deserialize_from_file(path)

Parameters

path (str) – path to a file containing a serialized configuration.

Return type

d3rlpy.serializable_config.TConfig

classmethod from_dict(kvs, *, infer_missing=False)

Parameters

kvs (Optional[Union[dict, list, str, int, float, bool]]) – JSON-compatible value, typically a dictionary of field values.

Return type

dataclasses_json.api.A

classmethod from_json(s, *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw)

Parameters

s (Union[str, bytes, bytearray]) – the JSON document to decode.

Return type

dataclasses_json.api.A

static get_type()

Returns the Q function type.

Returns

Q function type.

Return type

str

classmethod schema(*, infer_missing=False, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)

Parameters
  • infer_missing (bool)

  • many (bool)

  • partial (bool)

Return type

dataclasses_json.mm.SchemaF[dataclasses_json.api.A]

serialize()

Serializes the configuration into a string.

Return type

str

serialize_to_dict()

Converts the configuration into a dictionary.

Return type

Dict[str, Any]
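The serialization helpers above convert the factory's configuration to and from strings, dictionaries, and files. The sketch below imitates that round trip with a plain dataclass and the standard `json` module so it runs without d3rlpy installed. `IQNConfigSketch` is hypothetical, and the assumption that `serialize()` produces JSON (and that `deserialize_from_file` reads such a string from disk) should be checked against the installed version.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass
from typing import Any, Dict


@dataclass
class IQNConfigSketch:
    """Hypothetical stand-in with the same fields and defaults as IQNQFunctionFactory."""

    share_encoder: bool = False
    n_quantiles: int = 64
    n_greedy_quantiles: int = 32
    embed_size: int = 64

    def serialize(self) -> str:
        return json.dumps(asdict(self))  # assumption: JSON text, as in the real method

    def serialize_to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @classmethod
    def deserialize(cls, serialized_config: str) -> "IQNConfigSketch":
        return cls(**json.loads(serialized_config))

    @classmethod
    def deserialize_from_dict(cls, dict_config: Dict[str, Any]) -> "IQNConfigSketch":
        return cls(**dict_config)  # missing keys fall back to the dataclass defaults

    @classmethod
    def deserialize_from_file(cls, path: str) -> "IQNConfigSketch":
        with open(path, "r", encoding="utf-8") as f:
            return cls.deserialize(f.read())


config = IQNConfigSketch(n_quantiles=128)

# Round trip through a temporary file, mirroring deserialize_from_file.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "q_func.json")
    with open(path, "w", encoding="utf-8") as f:
        f.write(config.serialize())
    restored = IQNConfigSketch.deserialize_from_file(path)

print(restored == config)  # True
```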

to_dict(encode_json=False)

Return type

Dict[str, Optional[Union[dict, list, str, int, float, bool]]]

to_json(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw)

Encodes the instance as a JSON string. The keyword arguments are passed through to Python's json.dumps.

Parameters
  • skipkeys (bool)

  • ensure_ascii (bool)

  • check_circular (bool)

  • allow_nan (bool)

  • indent (Optional[Union[int, str]])

  • separators (Optional[Tuple[str, str]])

  • default (Optional[Callable])

  • sort_keys (bool)

Return type

str
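The keyword arguments of `to_json` share their names and meanings with Python's `json.dumps`, which dataclasses_json calls to produce the string. As a hedged illustration that does not call d3rlpy, the sketch applies `sort_keys`, `separators`, and `indent` to a plain dict holding the factory's default field values, standing in for `to_dict()` output.

```python
import json

# Plain dict with the factory's default field values (stand-in for to_dict() output).
fields = {"share_encoder": False, "n_quantiles": 64, "n_greedy_quantiles": 32, "embed_size": 64}

# Compact, key-sorted output: no spaces after ',' or ':'.
compact = json.dumps(fields, sort_keys=True, separators=(",", ":"))
# Human-readable output: one key per line, indented by two spaces.
pretty = json.dumps(fields, sort_keys=True, indent=2)

print(compact)  # {"embed_size":64,"n_greedy_quantiles":32,"n_quantiles":64,"share_encoder":false}
```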

Attributes

embed_size: int = 64
n_greedy_quantiles: int = 32
n_quantiles: int = 64
share_encoder: bool = False