Policy

AlphaZeroPolicy

class lzero.policy.alphazero.AlphaZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)

Bases: Policy

Overview:

The policy class for AlphaZero.

__init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
Overview:

Initialize the policy instance according to the input configuration and model. This method initializes the different fields of the policy: learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. The enable_field argument specifies which fields to initialize; if it is None, all fields are initialized.

Parameters:
  • cfg (-) – The final merged config used to initialize the policy. For the default config, see the config attribute of the policy class and its comments.

  • model (-) – The neural network model used to initialize the policy. If it is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model instance passed in by the caller is used.

  • enable_field (-) – The field list to initialize. If it is None, then all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which is beneficial to save resources.

Note

A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.
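As a rough illustration of how enable_field could decide which _init_* methods run, here is a minimal sketch. ToyPolicy and its initialized list are hypothetical and are not DI-engine's Policy class; only the field names learn, collect and eval come from the text above.

```python
from typing import List, Optional


class ToyPolicy:
    """Hypothetical stand-in showing how enable_field can select modes."""

    total_field = {'learn', 'collect', 'eval'}

    def __init__(self, enable_field: Optional[List[str]] = None) -> None:
        # None means "initialize every mode"; otherwise only the listed ones.
        self._enable_field = self.total_field if enable_field is None else set(enable_field)
        self.initialized: List[str] = []
        for field in ('learn', 'collect', 'eval'):
            if field in self._enable_field:
                # Dispatch to _init_learn / _init_collect / _init_eval.
                getattr(self, f'_init_{field}')()

    def _init_learn(self) -> None:
        self.initialized.append('learn')

    def _init_collect(self) -> None:
        self.initialized.append('collect')

    def _init_eval(self) -> None:
        self.initialized.append('eval')


print(ToyPolicy().initialized)          # ['learn', 'collect', 'eval']
print(ToyPolicy(['eval']).initialized)  # ['eval']
```

Initializing only the modes you need (for example, just eval in a standalone evaluator) avoids building optimizers or collect-time utilities that would never be used.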

_abc_impl = <_abc._abc_data object>
_create_model(cfg: EasyDict, model: Module | None = None) Module
Overview:

Create or validate the neural network model according to the input configuration and model. If the input model is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model is verified to be an instance of torch.nn.Module, and the model instance passed in by the caller is used.

Parameters:
  • cfg (-) – The final merged config used to initialize policy.

  • model (-) – The neural network model used to initialize the policy. Users can refer to the default model defined in the corresponding policy to customize their own model.

Returns:

The created neural network model. The different modes of the policy add distinct wrappers and plugins to the model, which are used for training, collecting and evaluation.

Return type:

  • model (torch.nn.Module)

Raises:

- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.
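The create-or-validate logic described above can be sketched as follows. To keep the sketch runnable without PyTorch, Module is a local stand-in for torch.nn.Module, and DefaultNet and create_model are hypothetical names, not LightZero functions.

```python
from typing import Callable, Optional


class Module:
    """Stand-in for torch.nn.Module so the sketch runs without PyTorch."""


class DefaultNet(Module):
    """Hypothetical default model built from the cfg.model fields."""

    def __init__(self, **model_cfg) -> None:
        self.model_cfg = model_cfg


def create_model(model_cfg: dict, model: Optional[object] = None,
                 build_default: Callable[..., Module] = DefaultNet) -> Module:
    if model is None:
        # No model given: build the default one from the model config.
        return build_default(**model_cfg)
    if not isinstance(model, Module):
        # A caller-supplied model must be a Module instance.
        raise RuntimeError(f'model must be a Module instance, got {type(model).__name__}')
    return model
```

Passing an already constructed Module returns it unchanged, while passing any other object (for example an int) raises RuntimeError.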

_forward_collect(obs: Dict, temperature: float = 1) Dict[str, Tensor]
Overview:

The forward function for collecting data in collect mode. It uses the real environment (rather than a learned model) to execute the MCTS search.

Parameters:
  • obs (-) – The dict of obs, the key is env_id and the value is the corresponding obs in this timestep.

  • temperature (-) – The temperature for MCTS search.

Returns:

The dict of outputs: the key is env_id and the value is the corresponding policy output in this timestep, including action, probs and so on.

Return type:

  • output (Dict[str, torch.Tensor])
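The temperature argument is commonly applied to the MCTS root visit counts as pi(a) proportional to N(a)^(1/T), the standard AlphaZero rule. The sketch below assumes that rule; whether LightZero's implementation matches it exactly is not stated here, and visit_count_policy and sample_action are hypothetical helpers.

```python
import random
from typing import Dict, List, Tuple


def visit_count_policy(visit_counts: Dict[int, int], temperature: float) -> Tuple[List[int], List[float]]:
    """Turn root visit counts into action probabilities, pi(a) ~ N(a)^(1/T)."""
    actions = sorted(visit_counts)
    if temperature == 0:
        # T -> 0 limit: all probability mass on the most-visited action.
        best = max(actions, key=lambda a: visit_counts[a])
        return actions, [1.0 if a == best else 0.0 for a in actions]
    weights = [visit_counts[a] ** (1.0 / temperature) for a in actions]
    total = sum(weights)
    return actions, [w / total for w in weights]


def sample_action(visit_counts: Dict[int, int], temperature: float, rng: random.Random) -> int:
    """Sample one action from the temperature-adjusted visit distribution."""
    actions, probs = visit_count_policy(visit_counts, temperature)
    return rng.choices(actions, weights=probs, k=1)[0]
```

With visit counts {0: 10, 1: 30, 2: 10}, T = 1 gives probabilities [0.2, 0.6, 0.2]; lowering T sharpens the distribution toward the most-visited action, which is why a small fixed value (the config's fixed_temperature_value is 0.25) makes collection more greedy.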

_forward_eval(obs: Dict) Dict[str, Tensor]
Overview:

The forward function for evaluating the current policy in eval mode, similar to self._forward_collect.

Parameters:

obs (-) – The dict of obs, the key is env_id and the value is the corresponding obs in this timestep.

Returns:

The dict of outputs: the key is env_id and the value is the corresponding policy output in this timestep, including action, probs and so on.

Return type:

  • output (Dict[str, torch.Tensor])

_forward_learn(inputs: Dict[str, Tensor]) Dict[str, float]
Overview:

Policy forward function of learn mode (training the policy and updating parameters). Here, forward means that the policy takes a batch of training data from the replay buffer and returns the output, including various training information such as loss value, policy entropy, q value, priority, and so on. This method is left to be implemented by the subclass, and more arguments can be added to the data items if necessary.

Parameters:

inputs (-) – The input data used for policy forward, including a batch of training samples. The key of the dict is the name of a data item and the value is the corresponding data. Usually, in the _forward_learn method, the data should be stacked along the batch dimension by utility functions such as default_preprocess_learn.

Returns:

The training information of policy forward, including metrics for monitoring training such as loss, priority, q value and policy entropy, and some data for the next training step such as priority. Note that the output data items should be Python native scalars rather than PyTorch tensors, which is convenient for outside callers.

Return type:

  • output (Dict[str, float])
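For orientation, the standard AlphaZero objective for one sample is (z - v)^2 - sum_a pi(a) log p(a), where z is the game outcome, v the predicted value, pi the MCTS visit distribution and p the network policy. The sketch below computes that per-sample loss with a value_weight factor (named after the value_weight config key); treating the L2 term as the optimizer's weight_decay is an assumption, and alphazero_loss is a hypothetical helper, not LightZero's implementation.

```python
import math
from typing import Sequence


def alphazero_loss(value_pred: float, value_target: float,
                   policy_logits: Sequence[float], mcts_probs: Sequence[float],
                   value_weight: float = 1.0) -> float:
    """value_weight * (z - v)^2 - sum_a pi(a) * log p(a) for a single sample."""
    # Numerically stable log-softmax of the network's policy logits.
    m = max(policy_logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in policy_logits))
    log_p = [x - log_z for x in policy_logits]
    # Cross-entropy between the MCTS visit distribution and the network policy.
    policy_loss = -sum(pi * lp for pi, lp in zip(mcts_probs, log_p))
    # Squared error between the game outcome and the predicted value.
    value_loss = (value_target - value_pred) ** 2
    # The L2 regularizer is assumed to be handled by the optimizer's weight decay.
    return value_weight * value_loss + policy_loss
```

Returning a Python float, as here, matches the note above that logged outputs should be native scalars rather than tensors.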

_get_attribute(name: str) Any
Overview:

To control access to the policy attributes, the policy exposes different modes to the outside instead of the policy instance itself. This method gets the attribute of the policy in these different modes.

Parameters:

name (-) – The name of the attribute.

Returns:

The value of the attribute.

Return type:

  • value (Any)

Note

DI-engine’s policy will first try to access _get_{name} method, and then try to access _{name} attribute. If both of them are not found, it will raise a NotImplementedError.
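The lookup order in the note above can be sketched like this. AttrLookupDemo is a hypothetical class written only to mirror the described order (a _get_{name} method first, then an _{name} attribute, else NotImplementedError); it is not DI-engine source.

```python
from typing import Any


class AttrLookupDemo:
    """Hypothetical class mimicking the attribute lookup order in the note."""

    def __init__(self) -> None:
        self._batch_size = 256              # reachable as 'batch_size'

    def _get_n_sample(self) -> int:         # reachable through the _get_ method
        return 8

    def get_attribute(self, name: str) -> Any:
        getter = getattr(self, f'_get_{name}', None)
        if callable(getter):
            return getter()                 # 1) prefer a _get_{name} method
        if hasattr(self, f'_{name}'):
            return getattr(self, f'_{name}')  # 2) fall back to an _{name} attribute
        raise NotImplementedError(name)     # 3) neither exists
```

So get_attribute('n_sample') calls the getter, get_attribute('batch_size') reads _batch_size, and any other name raises NotImplementedError.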

_get_batch_size() int | Dict[str, int]
_get_n_episode() int | None
_get_n_sample() int | None
_get_simulation_env()
_get_train_sample(data)
Overview:

Process a given trajectory (transitions, a list of transitions) into a list of samples that can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) or several multi-timestep transitions (DRQN). This method is usually used in collectors to execute the necessary RL data preprocessing before training, which helps the learner amortize the relevant time consumption. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.

Parameters:

data (-) – The trajectory data (a list of transitions); each element has the same format as the return value of the self._process_transition method.

Returns:

The processed train samples. Each element has a format similar to the input transitions, but may contain more data for training, such as nstep reward, advantage, etc.

Return type:

  • samples (List[Dict[str, Any]])

Note

We will vectorize the process_transition and get_train_sample methods in a following release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
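As one example of the nstep reward preprocessing mentioned above, the sketch below attaches a truncated n-step discounted reward sum (no bootstrap value) to each transition. The reward, done and nstep_return keys and the add_nstep_return helper are hypothetical; this is not AlphaZeroPolicy's own _get_train_sample.

```python
from typing import Any, Dict, List


def add_nstep_return(transitions: List[Dict[str, Any]], nstep: int, gamma: float) -> List[Dict[str, Any]]:
    """Return copies of the transitions with an extra 'nstep_return' key."""
    samples = []
    for t in range(len(transitions)):
        ret, discount = 0.0, 1.0
        for k in range(t, min(t + nstep, len(transitions))):
            ret += discount * transitions[k]['reward']
            discount *= gamma
            if transitions[k]['done']:
                break  # stop accumulating rewards at the end of the episode
        # Copy the transition so the input trajectory is left untouched.
        samples.append({**transitions[t], 'nstep_return': ret})
    return samples
```

Doing this once in the collector means the learner can read nstep_return directly instead of recomputing it for every sampled batch.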

_init_collect() None
Overview:

Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.

_init_eval() None
Overview:

Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.

_init_learn() None
Overview:

Initialize the learn mode of the policy, including related attributes and modules. This method is called in the __init__ method if learn is in enable_field. Almost every policy has its own learn mode, so this method must be overridden in the subclass.

Note

For the member variables that need to be saved and loaded, please refer to the _state_dict_learn and _load_state_dict_learn methods.

Note

For the member variables that need to be monitored, please refer to the _monitor_vars_learn method.

Note

If you want to set special member variables in the _init_learn method, it is best to give them the prefix _learn_ to avoid conflicts with other modes, such as self._learn_attr1.

_init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
Overview:

Initialize the multi-GPU data parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function to allreduce the gradients of the model parameters.

Parameters:
  • model (-) – The neural network model to be trained.

  • bp_update_sync (-) – Whether to update the model parameters synchronously after allreducing their gradients. Asynchronous updates can run in parallel across different network layers, like a pipeline, which saves time.

_load_state_dict_collect(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy collect mode, for example when loading a pretrained state_dict, an auto-recovered checkpoint, or a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy collect state saved before.

Tip

If you want to load only some parts of the model, you can set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_eval(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy eval mode, for example when loading an auto-recovered checkpoint or a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy eval state saved before.

Tip

If you want to load only some parts of the model, you can set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_learn(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy learn mode.

Parameters:

state_dict (-) – The dict of policy learn state saved before.

Tip

If you want to load only some parts of the model, you can set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_monitor_vars_learn() List[str]
Overview:

Register the variables to be monitored in learn mode. The registered variables are logged in TensorBoard according to the return value of _forward_learn.

_policy_value_fn(env: Env) Tuple[Dict[int, ndarray], float]
_process_transition(obs: Dict, model_output: Dict[str, Tensor], timestep: namedtuple) Dict
Overview:

Generate the dict-type transition (one timestep) data for policy learning.

_reset_collect(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for collect mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different environments/episodes being collected in data_id will have different RNN hidden states.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

Implementing this method is optional. Subclasses can override it if necessary.
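The data_id semantics described above (None resets everything, otherwise only the listed ids) can be sketched with a per-environment state store. PerEnvState is a hypothetical class, and a float stands in for something like an RNN hidden state.

```python
from typing import Dict, List, Optional


class PerEnvState:
    """Hypothetical per-environment state store, e.g. for RNN hidden states."""

    def __init__(self, env_num: int, init_value: float = 0.0) -> None:
        self.init_value = init_value
        self.state: Dict[int, float] = {i: init_value for i in range(env_num)}

    def reset(self, data_id: Optional[List[int]] = None) -> None:
        # None resets every environment; otherwise only the listed env ids.
        ids = list(self.state) if data_id is None else data_id
        for i in ids:
            self.state[i] = self.init_value
```

Resetting only the ids whose episodes just ended lets the other environments keep their in-progress state.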

_reset_eval(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for eval mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different environments/episodes being evaluated in data_id will have different RNN hidden states.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

Implementing this method is optional. Subclasses can override it if necessary.

_reset_learn(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different trajectories in data_id will have different RNN hidden states.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

Implementing this method is optional. Subclasses can override it if necessary.

_set_attribute(name: str, value: Any) None
Overview:

To control access to the policy attributes, the policy exposes different modes to the outside instead of the policy instance itself. This method sets the attribute of the policy in these different modes. The new attribute is named _{name}.

Parameters:
  • name (-) – The name of the attribute.

  • value (-) – The value of the attribute.
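The naming rule above (name becomes _name) amounts to a setattr with a leading underscore. AttrSetDemo is a hypothetical minimal class illustrating it, not DI-engine source.

```python
from typing import Any


class AttrSetDemo:
    """Hypothetical class mimicking the naming rule of _set_attribute."""

    def set_attribute(self, name: str, value: Any) -> None:
        # The value is stored under a leading-underscore name: name -> _name.
        setattr(self, '_' + name, value)


demo = AttrSetDemo()
demo.set_attribute('temperature', 0.25)
print(demo._temperature)  # 0.25
```

Because the stored name carries the underscore prefix, a later lookup by name can fall back to the _{name} attribute.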

_state_dict_collect() Dict[str, Any]
Overview:

Return the state_dict of collect mode, which usually includes only the model. It is needed in distributed training scenarios to auto-recover collectors.

Returns:

The dict of current policy collect state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover collectors; sometimes we can simply shut down the crashed collector and start a new one.

_state_dict_eval() Dict[str, Any]
Overview:

Return the state_dict of eval mode, which usually includes only the model. It is needed in distributed training scenarios to auto-recover evaluators.

Returns:

The dict of current policy eval state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover evaluators; sometimes we can simply shut down the crashed evaluator and start a new one.

_state_dict_learn() Dict[str, Any]
Overview:

Return the state_dict of learn mode, usually including model and optimizer.

Returns:

The dict of current policy learn state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

property cfg: EasyDict
class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new collect_function object from a sequence or iterable

_replace(**kwds)

Return a new collect_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

get_train_sample

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 7

process_transition

Alias for field number 1

reset

Alias for field number 3

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property collect_mode: collect_function
Overview:

Return the interfaces of the collect mode of the policy, which is used to collect data for training. Here we use namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.

Returns:

The interfaces of collect mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.collect_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
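The namedtuple-based interface described above can be sketched with a toy policy. CollectFunction here is a hypothetical two-field reduction of collect_function (the real one has eight fields), and TinyPolicy's echo rule is purely illustrative.

```python
from collections import namedtuple

# Hypothetical reduced interface; the real collect_function has 8 fields.
CollectFunction = namedtuple('CollectFunction', ['forward', 'reset'])


class TinyPolicy:
    def _forward_collect(self, obs):
        # Toy rule: return action 0 for every env_id.
        return {env_id: {'action': 0} for env_id in obs}

    def _reset_collect(self, data_id=None):
        return None

    @property
    def collect_mode(self) -> CollectFunction:
        # Expose bound methods through an immutable namedtuple.
        return CollectFunction(self._forward_collect, self._reset_collect)


policy_collect = TinyPolicy().collect_mode
output = policy_collect.forward({0: 'obs0', 1: 'obs1'})
print(output)  # {0: {'action': 0}, 1: {'action': 0}}
```

Assigning to a field such as policy_collect.forward raises AttributeError, which is the immutability the overview relies on to keep each mode's interface fixed.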
config = {
    'batch_size': 256,
    'collector_env_num': 8,
    'cuda': False,
    'evaluator_env_num': 3,
    'fixed_temperature_value': 0.25,
    'grad_clip_value': 10,
    'gumbel_algo': False,
    'learning_rate': 0.2,
    'lr_piecewise_constant_decay': True,
    'manual_temperature_decay': False,
    'mcts': {
        'max_moves': 512,
        'num_simulations': 50,
        'pb_c_base': 19652,
        'pb_c_init': 1.25,
        'root_dirichlet_alpha': 0.3,
        'root_noise_weight': 0.25
    },
    'model': {
        'num_channels': 32,
        'num_res_blocks': 1,
        'observation_shape': (3, 6, 6)
    },
    'momentum': 0.9,
    'multi_gpu': False,
    'optim_type': 'SGD',
    'other': {
        'replay_buffer': {
            'replay_buffer_size': 1000000,
            'save_episode': False
        }
    },
    'replay_ratio': 0.25,
    'sampled_algo': False,
    'tensor_float_32': False,
    'threshold_training_steps_for_final_lr': 500000,
    'threshold_training_steps_for_final_temperature': 100000,
    'torch_compile': False,
    'update_per_collect': None,
    'value_weight': 1.0,
    'weight_decay': 0.0001
}
classmethod default_config() EasyDict
Overview:

Get the default config of policy. This method is used to create the default config of policy.

Returns:

The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.

Return type:

  • cfg (EasyDict)

Tip

This method deep-copies the config attribute of the class and returns the result, so users can modify the returned config without affecting the class attribute.
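The recursive merge and deepcopy behavior described above can be sketched with plain dicts (EasyDict is not used, to keep the sketch dependency-free). deep_merge is a hypothetical helper, not DI-engine's own merge function; default_cfg is an excerpt of the AlphaZeroPolicy config, and user_cfg is an example override.

```python
import copy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge override into a deep copy of base."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: merge them key by key.
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Excerpt of the AlphaZeroPolicy config, plus an example user override.
default_cfg = {'batch_size': 256, 'mcts': {'num_simulations': 50, 'pb_c_init': 1.25}}
user_cfg = {'mcts': {'num_simulations': 100}}
cfg = deep_merge(default_cfg, user_cfg)
```

Only mcts.num_simulations changes in the result; pb_c_init and batch_size keep their defaults, and default_cfg itself is left untouched because the merge works on a deep copy.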

default_model() Tuple[str, List[str]]
Overview:

Return the default model setting of this algorithm for demonstration.

Returns:

  • model_type (str) – The model type used in this algorithm, which is registered in ModelRegistry.

  • import_names (List[str]) – The model class path list used in this algorithm.

Return type:

  • Tuple[str, List[str]]

class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new eval_function object from a sequence or iterable

_replace(**kwds)

Return a new eval_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 5

reset

Alias for field number 1

set_attribute

Alias for field number 3

state_dict

Alias for field number 4

property eval_mode: eval_function
Overview:

Return the interfaces of the eval mode of the policy, which is used to evaluate the model. Here we use namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.

Returns:

The interfaces of eval mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.eval_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new learn_function object from a sequence or iterable

_replace(**kwds)

Return a new learn_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

info

Alias for field number 2

load_state_dict

Alias for field number 7

monitor_vars

Alias for field number 3

reset

Alias for field number 1

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property learn_mode: learn_function
Overview:

Return the interfaces of the learn mode of the policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own learn mode.

Returns:

The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.learn_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
sync_gradients(model: Module) None
Overview:

Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.

Parameters:

model (-) – The model to synchronize gradients.

Note

This method is only used in multi-GPU training, and it should be called after the backward pass and before the optimizer step. Users can also use the bp_update_sync config to control whether the gradient allreduce and the optimizer update are synchronized.
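Conceptually, the gradient allreduce leaves every worker with the same aggregated gradient. The pure-Python simulation below assumes the aggregate is the mean (sum divided by world size); whether DI-engine averages or only sums is an assumption here, and allreduce_mean is a hypothetical helper, not a torch.distributed call.

```python
from typing import List


def allreduce_mean(worker_grads: List[List[float]]) -> List[List[float]]:
    """Simulate an all-reduce (sum) followed by division by world size."""
    world_size = len(worker_grads)
    # Sum each parameter's gradient across workers.
    summed = [sum(vals) for vals in zip(*worker_grads)]
    mean = [s / world_size for s in summed]
    # After the all-reduce, every worker holds an identical copy of the result.
    return [list(mean) for _ in range(world_size)]
```

For two workers with gradients [1.0, 2.0] and [3.0, 6.0], both end up with [2.0, 4.0], so their subsequent optimizer steps keep the model replicas identical.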

total_field = {'collect', 'eval', 'learn'}

MuZeroPolicy

class lzero.policy.muzero.MuZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)

Bases: Policy

Overview:
if self._cfg.model.model_type in ["conv", "mlp"]:

The policy class for MuZero.

if self._cfg.model.model_type in ["conv_context", "mlp_context"]:

The policy class for MuZero w/ Context, a variant of MuZero. This variant keeps the same training settings as MuZero but differs during inference by using a k-step recursively predicted latent representation at the root node, as proposed in the UniZero paper https://arxiv.org/abs/2406.10667.
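The variant selection above relies on list membership tests. A hypothetical helper, muzero_variant, makes that dispatch explicit; raising ValueError for other values is this sketch's own choice, not documented LightZero behavior.

```python
def muzero_variant(model_type: str) -> str:
    """Map cfg.model.model_type to the MuZero variant it selects."""
    # Use membership tests: comparing a string to a list with == is always False.
    if model_type in ['conv', 'mlp']:
        return 'MuZero'
    if model_type in ['conv_context', 'mlp_context']:
        return 'MuZero w/ Context'
    raise ValueError(f'unsupported model_type: {model_type!r}')
```

For example, muzero_variant('mlp_context') returns 'MuZero w/ Context', whereas the expression 'conv_context' == ['conv_context', 'mlp_context'] evaluates to False.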

__init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
Overview:

Initialize the policy instance according to the input configuration and model. This method initializes the different fields of the policy: learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. The enable_field argument specifies which fields to initialize; if it is None, all fields are initialized.

Parameters:
  • cfg (-) – The final merged config used to initialize the policy. For the default config, see the config attribute of the policy class and its comments.

  • model (-) – The neural network model used to initialize the policy. If it is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model instance passed in by the caller is used.

  • enable_field (-) – The field list to initialize. If it is None, then all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which is beneficial to save resources.

Note

A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.

_abc_impl = <_abc._abc_data object>
_create_model(cfg: EasyDict, model: Module | None = None) Module
Overview:

Create or validate the neural network model according to the input configuration and model. If the input model is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model is verified to be an instance of torch.nn.Module, and the model instance passed in by the caller is used.

Parameters:
  • cfg (-) – The final merged config used to initialize policy.

  • model (-) – The neural network model used to initialize the policy. Users can refer to the default model defined in the corresponding policy to customize their own model.

Returns:

The created neural network model. The different modes of the policy add distinct wrappers and plugins to the model, which are used for training, collecting and evaluation.

Return type:

  • model (torch.nn.Module)

Raises:

- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.

_forward_collect(data: Tensor, action_mask: list = None, temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, ready_env_id: array = None) Dict
Overview:

The forward function for collecting data in collect mode. It uses the model to execute the MCTS search and chooses the action by sampling in collect mode.

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask, which indicates the actions that can be selected (legal actions are marked with 1).

  • temperature (-) – The temperature of the policy.

  • to_play (-) – The player to play.

  • epsilon (-) – The epsilon of the epsilon-greedy exploration.

  • ready_env_id (-) – The id of the env that is ready to collect.

Shape:
  • data (torch.Tensor):
    • For Atari, \((N, C*S, H, W)\), where N is the number of collect_env, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For lunarlander, \((N, O)\), where N is the number of collect_env, O is the observation space size.

  • action_mask: \((N, action_space_size)\), where N is the number of collect_env.

  • temperature: \((1, )\).

  • to_play: \((N, 1)\), where N is the number of collect_env.

  • epsilon: \((1, )\).

  • ready_env_id: None

Returns:

Dict type data keyed by env id; each entry includes the keys action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.

Return type:

  • output (Dict[int, Any])
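The action_mask and epsilon arguments can be illustrated with two small helpers. The sketch assumes a mask entry of 1 marks a legal action and uses the generic epsilon-greedy rule; how LightZero combines epsilon with the MCTS output is not described here, and both function names are hypothetical.

```python
import random
from typing import List, Sequence


def legal_actions_from_mask(action_mask: Sequence[Sequence[int]]) -> List[List[int]]:
    """For each env, list the action indices whose mask entry is 1."""
    return [[a for a, m in enumerate(row) if m == 1] for row in action_mask]


def epsilon_greedy(best_action: int, legal: List[int], epsilon: float, rng: random.Random) -> int:
    """With probability epsilon pick a random legal action, else keep best_action."""
    if rng.random() < epsilon:
        return rng.choice(legal)
    return best_action
```

Drawing the random action only from the legal list ensures that exploration never selects a masked action.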

_forward_eval(data: Tensor, action_mask: list, to_play: int = -1, ready_env_id: array = None) Dict
Overview:

The forward function for evaluating the current policy in eval mode. It uses the model to execute the MCTS search and chooses the action deterministically via argmax rather than by sampling in eval mode.

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask, which indicates the actions that can be selected (legal actions are marked with 1).

  • to_play (-) – The player to play.

  • ready_env_id (-) – The id of the env that is ready to evaluate.

Shape:
  • data (torch.Tensor):
    • For Atari, \((N, C*S, H, W)\), where N is the number of evaluator environments, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For lunarlander, \((N, O)\), where N is the number of evaluator environments, O is the observation space size.

  • action_mask: \((N, action_space_size)\), where N is the number of evaluator environments.

  • to_play: \((N, 1)\), where N is the number of evaluator environments.

  • ready_env_id: None

Returns:

Dict type data keyed by env id; each entry includes the keys action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.

Return type:

  • output (Dict[int, Any])

_forward_learn(data: Tuple[Tensor]) Dict[str, float | int]
Overview:

The forward function for learning the policy in learn mode, which is the core of the learning process. The data is sampled from the replay buffer, the loss is computed by the loss function, and the loss is backpropagated to update the model.

Parameters:

data (-) – The data sampled from replay buffer, which is a tuple of tensors. The first tensor is the current_batch, the second tensor is the target_batch.

Returns:

The information dict to be logged, which contains current learning loss and learning statistics.

Return type:

  • info_dict (Dict[str, Union[float, int]])

_get_attribute(name: str) Any
Overview:

To control access to the policy attributes, the policy exposes different modes to the outside instead of the policy instance itself. This method gets the attribute of the policy in these different modes.

Parameters:

name (-) – The name of the attribute.

Returns:

The value of the attribute.

Return type:

  • value (Any)

Note

DI-engine’s policy will first try to access _get_{name} method, and then try to access _{name} attribute. If both of them are not found, it will raise a NotImplementedError.

_get_batch_size() int | Dict[str, int]
_get_n_episode() int | None
_get_n_sample() int | None
_get_target_obs_index_in_step_k(step)
Overview:

Get the begin index and end index of the target obs in step k.

Parameters:

step (-) – The current step k.

Returns:

  • beg_index (int) – The begin index of the target obs in step k.

  • end_index (int) – The end index of the target obs in step k.

Return type:

  • Tuple[int, int]

Examples

>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
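One formula consistent with the example (image_channel = 3, frame_stack_num = 4, step 0 gives (0, 12)) is sketched below. It assumes each unroll step shifts the window by one frame of image_channel channels and only covers the conv case; the formula for steps other than 0 and for mlp models is an assumption, and target_obs_index_in_step_k is a hypothetical free function.

```python
from typing import Tuple


def target_obs_index_in_step_k(step: int, image_channel: int, frame_stack_num: int) -> Tuple[int, int]:
    """Slice bounds of the stacked target observation for unroll step k."""
    # Each step shifts the window by one frame (image_channel channels),
    # and the window spans frame_stack_num frames.
    beg_index = image_channel * step
    end_index = image_channel * (step + frame_stack_num)
    return beg_index, end_index
```

Under this assumption, step 1 would give (3, 15): the window drops the oldest frame and appends the next one.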
_get_train_sample(data)
Overview:

Process a given trajectory (transitions, a list of transitions) into a list of samples that can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) or several multi-timestep transitions (DRQN). This method is usually used in collectors to execute the necessary RL data preprocessing before training, which helps the learner amortize the relevant time consumption. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.

Parameters:

data (-) – The trajectory data (a list of transitions); each element has the same format as the return value of the self._process_transition method.

Returns:

The processed train samples. Each element has a format similar to the input transitions, but may contain more data for training, such as nstep reward, advantage, etc.

Return type:

  • samples (List[Dict[str, Any]])

Note

We will vectorize the process_transition and get_train_sample methods in a following release. Users can customize this data processing procedure by overriding these two methods and the collector itself.

_init_collect() None
Overview:

Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.

_init_eval() None
Overview:

Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.

_init_learn() None
Overview:

Learn mode init method. Called by self.__init__. Initialize the learn model, optimizer and MCTS utils.

_init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
Overview:

Initialize the multi-GPU data parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function to allreduce the gradients of the model parameters.

Parameters:
  • model (-) – The neural network model to be trained.

  • bp_update_sync (-) – Whether to update the model parameters synchronously after allreducing their gradients. Asynchronous updates can run in parallel across different network layers, like a pipeline, which saves time.

_load_state_dict_collect(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy collect mode, for example when loading a pretrained state_dict, an auto-recovered checkpoint, or a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy collect state saved before.

Tip

If you want to load only some parts of the model, you can set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_eval(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy eval mode, for example when loading an auto-recovered checkpoint or a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy eval state saved before.

Tip

If you want to load only some parts of the model, you can set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_learn(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy learn mode.

Parameters:

state_dict (-) – The dict of policy learn state saved before.

_monitor_vars_learn() List[str][source]
Overview:

Register the variables to be monitored in learn mode. The registered variables are logged in TensorBoard according to the return value of _forward_learn.

_process_transition(obs, policy_output, timestep)[source]
Overview:

Process and pack the transition data of one timestep into a dict, such as <s, a, r, s’, done>. Some policies need special processing to pack their own necessary attributes (e.g. hidden state and logit), so this method is left to be implemented by the subclass.

Parameters:
  • obs (-) – The observation of the current timestep.

  • policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.

  • timestep (-) – The execution result namedtuple returned by the environment step method, with all elements converted to tensors. Usually, it contains the next obs, reward, done, info, etc.

Returns:

The processed transition data of the current timestep.

Return type:

  • transition (Dict[str, torch.Tensor])
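A minimal sketch of what such a subclass implementation might look like, assuming a hypothetical Timestep namedtuple and illustrative dict keys (obs, action, reward, next_obs, done); the policies documented here may pack different or additional fields.

```python
from collections import namedtuple
from typing import Any, Dict

# Hypothetical timestep container; real environments return their own namedtuple.
Timestep = namedtuple("Timestep", ["obs", "reward", "done", "info"])


def process_transition(obs: Any, policy_output: Dict[str, Any], timestep: Timestep) -> Dict[str, Any]:
    """Pack one step into an <s, a, r, s', done> dict (illustrative keys)."""
    return {
        "obs": obs,                          # s
        "action": policy_output["action"],   # a
        "reward": timestep.reward,           # r
        "next_obs": timestep.obs,            # s'
        "done": timestep.done,
    }


step = Timestep(obs=[0.1, 0.2], reward=1.0, done=False, info={})
transition = process_transition([0.0, 0.0], {"action": 3}, step)
```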

_reset_collect(data_id: List[int] | None = None) None[source]
Overview:

Reset the observation and action for the collector environment.

Parameters:

data_id (-) – List of data ids to reset (not used in this implementation).

_reset_eval(data_id: List[int] | None = None) None[source]
Overview:

Reset the observation and action for the evaluator environment.

Parameters:

data_id (-) – List of data ids to reset (not used in this implementation).

_reset_learn(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset; for example, different trajectories in data_id have different RNN hidden states.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

Implementing this method is optional. Subclasses can override it if necessary.

_set_attribute(name: str, value: Any) None
Overview:

To control access to the policy attributes, we expose different modes to the outside rather than using the policy instance directly. We also provide a method to set an attribute of the policy in different modes; the new attribute is named _{name}.

Parameters:
  • name (-) – The name of the attribute.

  • value (-) – The value of the attribute.

_state_dict_collect() Dict[str, Any]
Overview:

Return the state_dict of collect mode, which usually includes only the model. It is needed in distributed training scenarios to auto-recover collectors.

Returns:

The dict of current policy collect state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover collectors; sometimes we can simply shut down the crashed collector and start a new one.

_state_dict_eval() Dict[str, Any]
Overview:

Return the state_dict of eval mode, which usually includes only the model. It is needed in distributed training scenarios to auto-recover evaluators.

Returns:

The dict of current policy eval state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover evaluators; sometimes we can simply shut down the crashed evaluator and start a new one.

_state_dict_learn() Dict[str, Any][source]
Overview:

Return the state_dict of learn mode, usually including model, target_model and optimizer.

Returns:

The dict of current policy learn state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])
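A minimal sketch of bundling and restoring learn-mode state, assuming a toy linear model and an SGD optimizer configured with the learning_rate, momentum and weight_decay values that appear in the default config. The function names are hypothetical stand-ins for _state_dict_learn and _load_state_dict_learn.

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(3, 1)
target_model = copy.deepcopy(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=1e-4)


def state_dict_learn() -> dict:
    # Bundle every stateful component needed to resume training.
    return {
        "model": model.state_dict(),
        "target_model": target_model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }


def load_state_dict_learn(state: dict) -> None:
    # Restore each component from the bundle produced above.
    model.load_state_dict(state["model"])
    target_model.load_state_dict(state["target_model"])
    optimizer.load_state_dict(state["optimizer"])


# state_dict() returns tensors that share storage with the live parameters,
# so deep-copy the bundle before mutating the model.
checkpoint = copy.deepcopy(state_dict_learn())
with torch.no_grad():
    model.weight.add_(1.0)
load_state_dict_learn(checkpoint)
```

Saving the bundle with `torch.save` would give the same isolation as the deep copy used here.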

property cfg: EasyDict
class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new collect_function object from a sequence or iterable

_replace(**kwds)

Return a new collect_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

get_train_sample

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 7

process_transition

Alias for field number 1

reset

Alias for field number 3

set_attribute

Alias for field number 5

state_dict

Alias for field number 6
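collect_function behaves like any `collections.namedtuple`. The sketch below recreates a tuple with the documented field order, filled with a hypothetical placeholder callable, to show that each field-number alias is plain positional access and that instances are immutable.

```python
from collections import namedtuple

# Same field order as the collect_function documented above.
collect_function = namedtuple(
    "collect_function",
    ["forward", "process_transition", "get_train_sample", "reset",
     "get_attribute", "set_attribute", "state_dict", "load_state_dict"],
)


def noop(*args, **kwargs):
    """Hypothetical placeholder for the policy's internal methods."""
    return None


interfaces = collect_function(*([noop] * len(collect_function._fields)))

# "reset" is an alias for field number 3, so name and index access agree.
assert interfaces.reset is interfaces[3]

# Fields cannot be reassigned; _replace returns a new object instead.
patched = interfaces._replace(forward=lambda obs: {"action": 0})
```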

property collect_mode: collect_function
Overview:

Return the interfaces of collect mode of policy, which is used to collect training data. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.

Returns:

The interfaces of collect mode of policy. It is a namedtuple whose fields are different internal methods.

Return type:

  • interfaces (Policy.collect_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
config = {'action_type': 'fixed_action_space', 'analysis_dormant_ratio': False, 'analysis_sim_norm': False, 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'cal_dormant_ratio': False, 'collect_with_pure_policy': False, 'collector_env_num': 8, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_offline': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 200, 'grad_clip_value': 10, 'gray_scale': False, 'gumbel_algo': False, 'ignore_done': False, 'learning_rate': 0.2, 'lr_piecewise_constant_decay': True, 'manual_temperature_decay': False, 'mcts_ctree': True, 'model': {'analysis_dormant_ratio': False, 'analysis_sim_norm': False, 'bias': True, 'categorical_distribution': True, 'continuous_action_space': False, 'discrete_action_encoding_type': 'one_hot', 'frame_stack_num': 1, 'harmony_balance': False, 'image_channel': 1, 'model_type': 'conv', 'norm_type': 'BN', 'num_channels': 64, 'num_res_blocks': 1, 'observation_shape': (4, 96, 96), 'res_connection_in_dynamics': True, 'self_supervised_learning_loss': False, 'support_scale': 300}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'multi_gpu': False, 'n_episode': 8, 'num_simulations': 50, 'num_unroll_steps': 5, 'optim_type': 'SGD', 'policy_entropy_loss_weight': 0, 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'reanalyze_noise': True, 'replay_ratio': 0.25, 'reuse_search': False, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sampled_algo': False, 'ssl_loss_weight': 0, 'target_update_freq': 100, 'target_update_freq_for_intrinsic_reward': 1000, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 100000, 'transform2string': False, 'update_per_collect': None, 'use_augmentation': False, 'use_priority': False, 'use_rnd_model': False, 'use_ture_chance_label_in_chance_encoder': False, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
classmethod default_config() EasyDict
Overview:

Return the default config of the policy class.

Returns:

The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.

Return type:

  • cfg (EasyDict)

Tip

This method deep-copies the config attribute of the class and returns the result, so users can modify the returned config without affecting the class attribute.
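A minimal sketch of the recursive merge and deep copy described above, assuming plain dicts instead of EasyDict, hypothetical BasePolicy/DerivedPolicy classes, and a simple walk over the method resolution order; DI-engine's own merge utility may differ in detail.

```python
import copy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict in which `override` recursively updates `base`."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


class BasePolicy:
    config = {"cuda": True, "model": {"norm_type": "BN", "bias": True}}

    @classmethod
    def default_config(cls) -> Dict[str, Any]:
        # Walk from the most basic class to the most derived one,
        # so subclass values override those of their parents.
        cfg: Dict[str, Any] = {}
        for klass in reversed(cls.__mro__):
            if "config" in vars(klass):
                cfg = deep_merge(cfg, klass.config)
        return cfg


class DerivedPolicy(BasePolicy):
    config = {"model": {"lstm_hidden_size": 512}, "ssl_loss_weight": 2}


cfg = DerivedPolicy.default_config()
cfg["model"]["norm_type"] = "LN"  # only the copy changes; the class attribute does not
```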

default_model() Tuple[str, List[str]][source]
Overview:

Return the default model setting of this algorithm for demonstration.

Returns:

The model name and model import names.
  • model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.

  • import_names (List[str]): The model class path list used in this algorithm.

Return type:

  • model_info (Tuple[str, List[str]])

Note

The user can define and use a customized network model, but it must follow the same interface definition as the model indicated by the import_names path. For MuZero, this is lzero.model.muzero_model.MuZeroModel.

class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new eval_function object from a sequence or iterable

_replace(**kwds)

Return a new eval_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 5

reset

Alias for field number 1

set_attribute

Alias for field number 3

state_dict

Alias for field number 4

property eval_mode: eval_function
Overview:

Return the interfaces of eval mode of policy, which is used to evaluate the policy. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.

Returns:

The interfaces of eval mode of policy. It is a namedtuple whose fields are different internal methods.

Return type:

  • interfaces (Policy.eval_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new learn_function object from a sequence or iterable

_replace(**kwds)

Return a new learn_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

info

Alias for field number 2

load_state_dict

Alias for field number 7

monitor_vars

Alias for field number 3

reset

Alias for field number 1

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property learn_mode: learn_function
Overview:

Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own learn mode.

Returns:

The interfaces of learn mode of policy. It is a namedtuple whose fields are different internal methods.

Return type:

  • interfaces (Policy.learn_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
sync_gradients(model: Module) None
Overview:

Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.

Parameters:

model (-) – The model to synchronize gradients.

Note

This method is only used in multi-GPU training, and it should be called after the backward method and before the step method. The user can also use the bp_update_sync config to control whether gradient allreduce and optimizer updates are synchronized.
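Conceptually, the allreduce performed by sync_gradients leaves every worker with the same aggregated gradient. The pure-Python simulation below is a sketch only: the function is hypothetical, plain lists stand in for gradient tensors, and it assumes gradients are averaged across workers, which is the common data-parallel convention.

```python
from typing import List


def allreduce_mean(worker_grads: List[List[float]]) -> List[List[float]]:
    """Simulate an allreduce: every worker receives the element-wise mean."""
    num_workers = len(worker_grads)
    summed = [sum(values) for values in zip(*worker_grads)]
    mean = [total / num_workers for total in summed]
    # Each worker gets its own copy of the synchronized gradients.
    return [list(mean) for _ in range(num_workers)]


# Two hypothetical workers, each holding gradients for three parameters.
grads = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
synced = allreduce_mean(grads)
```

With torch.distributed, the per-parameter equivalent is typically `dist.all_reduce(param.grad)` (which sums by default) followed by division by the world size, executed after `loss.backward()` and before `optimizer.step()`.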

total_field = {'collect', 'eval', 'learn'}

EfficientZeroPolicy

class lzero.policy.efficientzero.EfficientZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]

Bases: MuZeroPolicy

Overview:

The policy class for EfficientZero proposed in the paper https://arxiv.org/abs/2111.00210.

__init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
Overview:

Initialize a policy instance according to the input config and model. This method initializes the different fields of the policy, including learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. The enable_field argument specifies which fields to initialize; if it is None, all fields are initialized.

Parameters:
  • cfg (-) – The final merged config used to initialize policy. For the default config, see the config attribute and its comments of policy class.

  • model (-) – The neural network model used to initialize policy. If it is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be set to the model instance created by the outside caller.

  • enable_field (-) – The field list to initialize. If it is None, then all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which saves resources.

Note

Derived policy classes should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.

_abc_impl = <_abc._abc_data object>
_create_model(cfg: EasyDict, model: Module | None = None) Module
Overview:

Create or validate the neural network model according to the input configuration and model. If the input model is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model is verified to be an instance of torch.nn.Module and used as the model instance created by the outside caller.

Parameters:
  • cfg (-) – The final merged config used to initialize policy.

  • model (-) – The neural network model used to initialize policy. User can refer to the default model defined in the corresponding policy to customize its own model.

Returns:

The created neural network model. The different modes of policy will add distinct wrappers and plugins to the model, which is used to train, collect and evaluate.

Return type:

  • model (torch.nn.Module)

Raises:

- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.

_forward_collect(data: Tensor, action_mask: list = None, temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, ready_env_id: array = None)[source]
Overview:

The forward function for collecting data in collect mode. The model is used to execute MCTS search, and the action is chosen by sampling during collect mode.

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask, marking which actions are legal (1) and which cannot be selected (0).

  • temperature (-) – The temperature of the policy.

  • to_play (-) – The player to play.

  • ready_env_id (-) – The id of the env that is ready to collect.

Shape:
  • data (torch.Tensor):
    • For Atari, \((N, C*S, H, W)\), where N is the number of collect_env, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For lunarlander, \((N, O)\), where N is the number of collect_env, O is the observation space size.

  • action_mask: \((N, action_space_size)\), where N is the number of collect_env.

  • temperature: \((1, )\).

  • to_play: \((N, 1)\), where N is the number of collect_env.

  • ready_env_id: None

Returns:

Dict type data, the keys including action, distributions, visit_count_distribution_entropy, value, pred_value, policy_logits.

Return type:

  • output (Dict[int, Any])
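A minimal sketch of temperature-based action sampling from root visit counts. It assumes that an action_mask entry of 1 marks a legal action, that the search returns one visit count per legal action, and that sampling probabilities are proportional to N(a)^(1/temperature). The helper name is hypothetical, and the actual selection routine may add details such as epsilon-greedy exploration.

```python
import random
from typing import List, Sequence


def select_collect_action(visit_counts: Sequence[int], legal_actions: List[int],
                          temperature: float, rng: random.Random) -> int:
    """Sample an action with probability proportional to N(a) ** (1 / temperature)."""
    weights = [count ** (1.0 / temperature) for count in visit_counts]
    index = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    # Visit counts are indexed over legal actions only, so map back to the action id.
    return legal_actions[index]


action_mask = [1, 0, 1, 1]  # assumed convention: 1 marks a legal action
legal_actions = [a for a, legal in enumerate(action_mask) if legal == 1]
visit_counts = [10, 30, 10]  # one count per legal action, e.g. from 50 simulations
action = select_collect_action(visit_counts, legal_actions, temperature=1.0, rng=random.Random(0))
```

Lower temperatures sharpen the distribution toward the most visited action; higher ones make collection more exploratory.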

_forward_eval(data: Tensor, action_mask: list, to_play: -1, ready_env_id: array = None)[source]
Overview:

The forward function for evaluating the current policy in eval mode. The model is used to execute MCTS search, and during eval mode the action with the highest root visit count (argmax) is chosen instead of being sampled.

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask, marking which actions are legal (1) and which cannot be selected (0).

  • to_play (-) – The player to play.

  • ready_env_id (-) – The id of the env that is ready to evaluate.

Shape:
  • data (torch.Tensor):
    • For Atari, \((N, C*S, H, W)\), where N is the number of evaluator envs, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For lunarlander, \((N, O)\), where N is the number of evaluator envs, O is the observation space size.

  • action_mask: \((N, action_space_size)\), where N is the number of evaluator envs.

  • to_play: \((N, 1)\), where N is the number of evaluator envs.

  • ready_env_id: None

Returns:

Dict type data, the keys including action, distributions, visit_count_distribution_entropy, value, pred_value, policy_logits.

Return type:

  • output (Dict[int, Any])
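For evaluation, a deterministic counterpart is sketched below, assuming the argmax is taken over the root visit counts of the legal actions (ties go to the first legal action). The helper name and the mask convention (1 marks a legal action) are assumptions.

```python
from typing import List, Sequence


def select_eval_action(visit_counts: Sequence[int], legal_actions: List[int]) -> int:
    """Deterministically pick the legal action with the most MCTS visits."""
    best_index = max(range(len(visit_counts)), key=lambda i: visit_counts[i])
    return legal_actions[best_index]


action_mask = [0, 1, 1, 0, 1]  # assumed convention: 1 marks a legal action
legal_actions = [a for a, legal in enumerate(action_mask) if legal == 1]
action = select_eval_action([5, 40, 5], legal_actions)
```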

_forward_learn(data: Tensor) Dict[str, float | int][source]
Overview:

The forward function for training the policy in learn mode, which is the core of the learning process. The data is sampled from the replay buffer; the loss is computed and backpropagated to update the model.

Parameters:

data (-) – The data sampled from the replay buffer, which is a tuple of tensors. The first tensor is the current_batch, the second tensor is the target_batch.

Returns:

The information dict to be logged, which contains current learning loss and learning statistics.

Return type:

  • info_dict (Dict[str, Union[float, int]])

_get_attribute(name: str) Any
Overview:

To control access to the policy attributes, we expose different modes to the outside rather than using the policy instance directly. We also provide a method to get an attribute of the policy in different modes.

Parameters:

name (-) – The name of the attribute.

Returns:

The value of the attribute.

Return type:

  • value (Any)

Note

DI-engine’s policy will first try to access _get_{name} method, and then try to access _{name} attribute. If both of them are not found, it will raise a NotImplementedError.
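A minimal sketch of the lookup order described in the Note, using a hypothetical AttributeAccessSketch class: a _get_{name} method takes priority over a plain _{name} attribute, _set_attribute stores values under _{name}, and an unknown name raises NotImplementedError.

```python
from typing import Any


class AttributeAccessSketch:
    """Hypothetical policy illustrating the documented attribute lookup order."""

    def __init__(self) -> None:
        self._batch_size = 256  # reachable as "batch_size" via the _{name} fallback

    def _get_n_episode(self) -> int:
        return 8  # a _get_{name} method is tried before the _{name} attribute

    def _get_attribute(self, name: str) -> Any:
        getter = getattr(self, "_get_" + name, None)
        if callable(getter):
            return getter()
        if hasattr(self, "_" + name):
            return getattr(self, "_" + name)
        raise NotImplementedError(name)

    def _set_attribute(self, name: str, value: Any) -> None:
        setattr(self, "_" + name, value)  # stored as _{name}


policy = AttributeAccessSketch()
policy._set_attribute("learning_rate", 0.2)
```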

_get_batch_size() int | Dict[str, int]
_get_n_episode() int | None
_get_n_sample() int | None
_get_target_obs_index_in_step_k(step)
Overview:

Get the begin index and end index of the target obs in step k.

Parameters:

step (-) – The current step k.

Returns:

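The begin index and end index of the target obs in step k.
  • beg_index (int): The begin index of the target obs in step k.

  • end_index (int): The end index of the target obs in step k.

Return type:

  • beg_index, end_index (Tuple[int, int])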

Examples

>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
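A minimal sketch consistent with the conv example above, assuming stacked observations are concatenated along the channel axis so that each unroll step shifts the window by image_channel channels. The free function is hypothetical and covers only this conv case.

```python
from typing import Tuple


def target_obs_index_in_step_k(step: int, image_channel: int, frame_stack_num: int) -> Tuple[int, int]:
    """Channel slice [beg, end) holding the stacked frames for unroll step k."""
    beg_index = image_channel * step
    end_index = image_channel * (step + frame_stack_num)
    return beg_index, end_index
```

With image_channel=3 and frame_stack_num=4, step 0 maps to channels [0, 12) and step 1 to [3, 15).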
_get_train_sample(data)[source]
Overview:

For a given trajectory (transitions, a list of transitions), process it into a list of samples that can be used directly for training. A train sample can be a processed transition (DQN with n-step TD) or a multi-timestep transition (DRQN). This method is usually called in collectors to run the necessary RL data preprocessing before training, which helps the learner amortize the relevant time cost. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.

Parameters:

transitions (-) – The trajectory data (a list of transitions), each element in the same format as the return value of the self._process_transition method.

Returns:

The processed train samples. Each element has a format similar to the input transitions, but may contain more data for training, such as n-step reward, advantage, etc.

Return type:

  • samples (List[Dict[str, Any]])

Note

We will vectorize the process_transition and get_train_sample methods in a future release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
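To illustrate the kind of preprocessing mentioned above (n-step reward), here is a minimal sketch of a generic DQN-style truncated n-step return, assuming each transition is a dict with reward and done keys. The helper is hypothetical and not taken from the policies documented here, which may build their training samples differently.

```python
from typing import Any, Dict, List


def add_nstep_reward(transitions: List[Dict[str, Any]], nstep: int, gamma: float) -> List[Dict[str, Any]]:
    """Attach a truncated n-step discounted return to each transition."""
    samples = []
    for t, transition in enumerate(transitions):
        ret, discount = 0.0, 1.0
        for k in range(t, min(t + nstep, len(transitions))):
            ret += discount * transitions[k]["reward"]
            discount *= gamma
            if transitions[k]["done"]:
                break  # stop summing at the end of the episode
        # Copy the transition so the input trajectory is left unchanged.
        samples.append({**transition, "nstep_reward": ret})
    return samples


trajectory = [
    {"reward": 1.0, "done": False},
    {"reward": 0.0, "done": False},
    {"reward": 2.0, "done": True},
]
samples = add_nstep_reward(trajectory, nstep=2, gamma=0.5)
```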

_init_collect() None[source]
Overview:

Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.

_init_eval() None[source]
Overview:

Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.

_init_learn() None[source]
Overview:

Learn mode init method. Called by self.__init__. Initialize the learn model, optimizer and MCTS utils.

_init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
Overview:

Initialize the multi-GPU data-parallel training setting: broadcast the model parameters at the beginning of training, and prepare hook functions that allreduce the gradients of the model parameters.

Parameters:
  • model (-) – The neural network model to be trained.

  • bp_update_sync (-) – Whether to update the model parameters synchronously after the gradients have been allreduced. Asynchronous updates can run in parallel across network layers, like a pipeline, which saves time.

_load_state_dict_collect(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy collect mode, for example when loading a pretrained state_dict, an auto-recover checkpoint, or a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy collect state saved before.

Tip

If you want to load only some parts of the model, you can set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_eval(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy eval mode, for example when loading an auto-recover checkpoint or a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy eval state saved before.

Tip

If you want to load only some parts of the model, you can set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_learn(state_dict: Dict[str, Any]) None[source]
Overview:

Load the state_dict variable into policy learn mode.

Parameters:

state_dict (-) – The dict of policy learn state saved before.

_monitor_vars_learn() List[str][source]
Overview:

Register the variables to be monitored in learn mode. The registered variables are logged in TensorBoard according to the return value of _forward_learn.

_process_transition(obs, policy_output, timestep)[source]
Overview:

Process and pack the transition data of one timestep into a dict, such as <s, a, r, s’, done>. Some policies need special processing to pack their own necessary attributes (e.g. hidden state and logit), so this method is left to be implemented by the subclass.

Parameters:
  • obs (-) – The observation of the current timestep.

  • policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.

  • timestep (-) – The execution result namedtuple returned by the environment step method, with all elements converted to tensors. Usually, it contains the next obs, reward, done, info, etc.

Returns:

The processed transition data of the current timestep.

Return type:

  • transition (Dict[str, torch.Tensor])

_reset_collect(data_id: List[int] | None = None) None
Overview:

Reset the observation and action for the collector environment.

Parameters:

data_id (-) – List of data ids to reset (not used in this implementation).

_reset_eval(data_id: List[int] | None = None) None
Overview:

Reset the observation and action for the evaluator environment.

Parameters:

data_id (-) – List of data ids to reset (not used in this implementation).

_reset_learn(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset; for example, different trajectories in data_id have different RNN hidden states.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

Implementing this method is optional. Subclasses can override it if necessary.

_set_attribute(name: str, value: Any) None
Overview:

To control access to the policy attributes, we expose different modes to the outside rather than using the policy instance directly. We also provide a method to set an attribute of the policy in different modes; the new attribute is named _{name}.

Parameters:
  • name (-) – The name of the attribute.

  • value (-) – The value of the attribute.

_state_dict_collect() Dict[str, Any]
Overview:

Return the state_dict of collect mode, which usually includes only the model. It is needed in distributed training scenarios to auto-recover collectors.

Returns:

The dict of current policy collect state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover collectors; sometimes we can simply shut down the crashed collector and start a new one.

_state_dict_eval() Dict[str, Any]
Overview:

Return the state_dict of eval mode, which usually includes only the model. It is needed in distributed training scenarios to auto-recover evaluators.

Returns:

The dict of current policy eval state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover evaluators; sometimes we can simply shut down the crashed evaluator and start a new one.

_state_dict_learn() Dict[str, Any][source]
Overview:

Return the state_dict of learn mode, usually including model and optimizer.

Returns:

The dict of current policy learn state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

property cfg: EasyDict
class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new collect_function object from a sequence or iterable

_replace(**kwds)

Return a new collect_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

get_train_sample

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 7

process_transition

Alias for field number 1

reset

Alias for field number 3

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property collect_mode: collect_function
Overview:

Return the interfaces of collect mode of policy, which is used to collect training data. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.

Returns:

The interfaces of collect mode of policy. It is a namedtuple whose fields are different internal methods.

Return type:

  • interfaces (Policy.collect_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
config = {'action_type': 'fixed_action_space', 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'collect_with_pure_policy': False, 'collector_env_num': 8, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_offline': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 200, 'grad_clip_value': 10, 'gray_scale': False, 'gumbel_algo': False, 'ignore_done': False, 'learning_rate': 0.2, 'lr_piecewise_constant_decay': True, 'lstm_horizon_len': 5, 'manual_temperature_decay': False, 'mcts_ctree': True, 'model': {'bias': True, 'categorical_distribution': True, 'continuous_action_space': False, 'discrete_action_encoding_type': 'one_hot', 'frame_stack_num': 1, 'image_channel': 1, 'lstm_hidden_size': 512, 'model_type': 'conv', 'norm_type': 'BN', 'observation_shape': (4, 96, 96), 'res_connection_in_dynamics': True, 'self_supervised_learning_loss': True, 'support_scale': 300}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'multi_gpu': False, 'n_episode': 8, 'num_simulations': 50, 'num_unroll_steps': 5, 'optim_type': 'SGD', 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'reanalyze_noise': True, 'replay_ratio': 0.25, 'reuse_search': False, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sampled_algo': False, 'ssl_loss_weight': 2, 'target_update_freq': 100, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 100000, 'transform2string': False, 'update_per_collect': None, 'use_augmentation': False, 'use_priority': False, 'use_ture_chance_label_in_chance_encoder': False, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
classmethod default_config() EasyDict
Overview:

Return the default config of the policy class.

Returns:

The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.

Return type:

  • cfg (EasyDict)

Tip

This method deep-copies the config attribute of the class and returns the result, so users can modify the returned config without affecting the class attribute.

default_model() Tuple[str, List[str]][source]
Overview:

Return the default model setting of this algorithm.

Returns:

The model name and model import names.
  • model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.

  • import_names (List[str]): The model class path list used in this algorithm.

Return type:

  • model_info (Tuple[str, List[str]])

Note

The user can define and use a customized network model, but it must follow the same interface definition as the model indicated by the import_names path. For EfficientZero, this is lzero.model.efficientzero_model.EfficientZeroModel.

class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new eval_function object from a sequence or iterable

_replace(**kwds)

Return a new eval_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 5

reset

Alias for field number 1

set_attribute

Alias for field number 3

state_dict

Alias for field number 4

property eval_mode: eval_function
Overview:

Return the interfaces of the eval mode of the policy, which is used to evaluate the policy. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, a derived subclass can override the interfaces to customize its own eval mode.

Returns:

The interfaces of eval mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.eval_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
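To illustrate how a namedtuple can serve as an immutable mode interface, the sketch below builds a hypothetical ToyPolicy whose eval_mode property returns an eval_function tuple with the same field order as documented above. The toy forward rule, the obs dict keyed by env_id, and all class names are assumptions for illustration, not LightZero code.

```python
from collections import namedtuple

# Same field order as the documented eval_function.
EvalFunction = namedtuple(
    'eval_function',
    ['forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict'],
)


class ToyPolicy:
    """Hypothetical policy that only implements eval-mode methods."""

    def __init__(self):
        self._eval_steps = 0

    def _forward_eval(self, obs):
        # obs maps env_id -> observation; the output is keyed the same way.
        self._eval_steps += 1
        return {env_id: {'action': 0} for env_id in obs}  # toy rule: always action 0

    def _reset_eval(self, data_id=None):
        self._eval_steps = 0

    def _get_attribute(self, name):
        return getattr(self, '_' + name)

    def _set_attribute(self, name, value):
        setattr(self, '_' + name, value)

    def _state_dict_eval(self):
        return {'eval_steps': self._eval_steps}

    def _load_state_dict_eval(self, state_dict):
        self._eval_steps = state_dict['eval_steps']

    @property
    def eval_mode(self):
        # Bundle bound methods into an immutable tuple of interfaces.
        return EvalFunction(
            self._forward_eval, self._reset_eval, self._get_attribute,
            self._set_attribute, self._state_dict_eval, self._load_state_dict_eval,
        )


policy_eval = ToyPolicy().eval_mode
output = policy_eval.forward({0: 'obs_0', 1: 'obs_1'})
print(output)  # {0: {'action': 0}, 1: {'action': 0}}
print(policy_eval.get_attribute('eval_steps'))  # 1
print(EvalFunction._fields.index('state_dict'))  # 4, matching "Alias for field number 4"
try:
    policy_eval.forward = None  # namedtuple fields cannot be reassigned
except AttributeError:
    print('interfaces are read-only')
```

Callers only see the six bound methods, so private policy state stays reachable solely through get_attribute and set_attribute.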
class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new learn_function object from a sequence or iterable

_replace(**kwds)

Return a new learn_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

info

Alias for field number 2

load_state_dict

Alias for field number 7

monitor_vars

Alias for field number 3

reset

Alias for field number 1

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property learn_mode: learn_function
Overview:

Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own learn mode.

Returns:

The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.learn_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
sync_gradients(model: Module) None
Overview:

Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.

Parameters:

model (-) – The model to synchronize gradients.

Note

This method is only used in multi-GPU training, and it should be called after the backward pass and before the optimizer step. The bp_update_sync config can also be used to control whether the gradient allreduce and the optimizer update are synchronized.
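The effect of synchronizing gradients between the backward pass and the optimizer step can be simulated without any GPUs. The sketch assumes the all-reduce averages gradients across ranks (sum, then divide by the world size); simulated_allreduce_mean is a hypothetical helper operating on plain Python lists, not the DI-engine API.

```python
from typing import List


def simulated_allreduce_mean(per_rank_grads: List[List[float]]) -> List[List[float]]:
    """Average each gradient element across ranks and give every rank the result."""
    world_size = len(per_rank_grads)
    summed = [sum(values) for values in zip(*per_rank_grads)]
    mean = [s / world_size for s in summed]
    # After an all-reduce, every rank holds its own identical copy.
    return [list(mean) for _ in range(world_size)]


# Two ranks computed different gradients on different mini-batches (after backward()).
grads = [[1.0, 2.0], [3.0, 6.0]]
synced = simulated_allreduce_mean(grads)
print(synced)  # [[2.0, 4.0], [2.0, 4.0]]
# Only now would each rank call optimizer.step(), so all replicas apply the same update.
```

Calling the optimizer step before synchronization would let the replicas drift apart, which is why the ordering in the note matters.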

total_field = {'collect', 'eval', 'learn'}

Gumbel AlphaZeroPolicy

class lzero.policy.gumbel_alphazero.GumbelAlphaZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]

Bases: Policy

Overview:

The policy class for GumbelAlphaZero.

__init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
Overview:

Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy, including learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. The enable_field argument specifies which fields to initialize; if it is None, all fields are initialized.

Parameters:
  • cfg (-) – The final merged config used to initialize policy. For the default config, see the config attribute and its comments of policy class.

  • model (-) – The neural network model used to initialize policy. If it is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be set to the model instance created by outside caller.

  • enable_field (-) – The field list to initialize. If it is None, then all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which is beneficial to save resources.

Note

For the derived policy class, it should implement the _init_learn, _init_collect, _init_eval method to initialize the corresponding field.

_abc_impl = <_abc._abc_data object>
_create_model(cfg: EasyDict, model: Module | None = None) Module
Overview:

Create or validate the neural network model according to the input configuration and model. If the input model is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be verified as an instance of torch.nn.Module and set to the model instance created by outside caller.

Parameters:
  • cfg (-) – The final merged config used to initialize policy.

  • model (-) – The neural network model used to initialize the policy. Users can refer to the default model defined in the corresponding policy to customize their own model.

Returns:

The created neural network model. The different modes of the policy add distinct wrappers and plugins to the model for training, collecting and evaluating.

Return type:

  • model (torch.nn.Module)

Raises:

- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.

_forward_collect(obs: Dict, temperature: float = 1) Dict[str, Tensor][source]
Overview:

The forward function for collecting data in collect mode. It uses a real environment as the simulator for the MCTS search.

Parameters:
  • obs (-) – The dict of observations, where each key is an env_id and each value is the corresponding observation at this timestep.

  • temperature (-) – The temperature used to select actions from the MCTS visit-count distribution.

Returns:

The dict of output, where each key is an env_id and each value is the corresponding policy output at this timestep, including action, probs and so on.

Return type:

  • output (Dict[str, torch.Tensor])

_forward_eval(obs: Dict) Dict[str, Tensor][source]
Overview:

The forward function for evaluating the current policy in eval mode, similar to self._forward_collect.

Parameters:

obs (-) – The dict of obs, the key is env_id and the value is the corresponding obs in this timestep.

Returns:

The dict of output, where each key is an env_id and each value is the corresponding policy output at this timestep, including action, probs and so on.

Return type:

  • output (Dict[str, torch.Tensor])

_forward_learn(inputs: Dict[str, Tensor]) Dict[str, float][source]
Overview:

The forward function of learn mode, which trains the policy and updates its parameters. Here, forward means that the policy takes a batch of training data sampled from the replay buffer and returns the output, including various training information such as loss values, policy entropy, Q values and priorities. In the base Policy class this method is left to subclasses, and more items can be added to the input data if necessary.

Parameters:

inputs (-) – The input data used for policy forward, i.e. a batch of training samples. The keys of the dict are the names of the data items and the values are the corresponding data, usually stacked in the batch dimension by a utility function such as default_preprocess_learn.

Returns:

The training information of policy forward, including metrics for monitoring training such as loss, priority, Q value and policy entropy, as well as data for the next training step such as priority. Note that the output data items should be native Python scalars rather than PyTorch tensors, which makes them convenient for external callers to use.

Return type:

  • output (Dict[str, float])

_get_attribute(name: str) Any
Overview:

To control access to the policy attributes, the policy exposes different modes to the outside instead of the policy instance itself. This method gets an attribute of the policy from within these modes.

Parameters:

name (-) – The name of the attribute.

Returns:

The value of the attribute.

Return type:

  • value (Any)

Note

DI-engine’s policy first tries to access the _get_{name} method and then the _{name} attribute. If neither is found, it raises a NotImplementedError.
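The lookup order in this note can be sketched as follows. AttributeAccessDemo and its attributes are hypothetical; the sketch only mirrors the documented order (getter method first, then underscore attribute, then NotImplementedError) and is not DI-engine's source.

```python
from typing import Any


class AttributeAccessDemo:
    """Hypothetical class mimicking the documented lookup order."""

    def __init__(self):
        self._batch_size = 256  # reachable only through the _{name} fallback

    def _get_n_episode(self) -> int:
        return 8  # reachable through the _get_{name} method

    def _get_attribute(self, name: str) -> Any:
        # 1) Prefer a dedicated getter method named _get_{name}.
        getter = getattr(self, '_get_' + name, None)
        if callable(getter):
            return getter()
        # 2) Fall back to a plain attribute named _{name}.
        if hasattr(self, '_' + name):
            return getattr(self, '_' + name)
        # 3) Neither exists.
        raise NotImplementedError(name)


demo = AttributeAccessDemo()
print(demo._get_attribute('n_episode'))   # 8 (from _get_n_episode)
print(demo._get_attribute('batch_size'))  # 256 (from _batch_size)
```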

_get_batch_size() int | Dict[str, int]
_get_n_episode() int | None
_get_n_sample() int | None
_get_simulation_env()[source]
_get_train_sample(data)[source]
Overview:

For a given trajectory (transitions, a list of transitions), process it into a list of samples that can be used directly for training. A train sample can be a processed transition (DQN with n-step TD) or some multi-timestep transitions (DRQN). This method is usually used in collectors to execute the necessary RL data preprocessing before training, which helps the learner amortize the relevant time consumption. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.

Parameters:

data (-) – The trajectory data (a list of transitions); each element has the same format as the return value of the self._process_transition method.

Returns:

The processed train samples; each element has a format similar to the input transitions but may contain more data for training, such as n-step reward, advantage, etc.

Return type:

  • samples (List[Dict[str, Any]])

Note

We plan to vectorize the process_transition and get_train_sample methods in a future release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
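As an illustration of the "processed transition (DQN with n-step TD)" case mentioned above, here is a minimal sketch that turns a trajectory into n-step samples. The function name, the transition keys (reward, next_obs, done) and the nstep_* output keys are all hypothetical; MCTS-based policies such as this one do not necessarily process data this way.

```python
from typing import Any, Dict, List


def get_train_sample_nstep(transitions: List[Dict[str, Any]], nstep: int = 3,
                           gamma: float = 0.99) -> List[Dict[str, Any]]:
    """Turn each transition into a sample carrying an n-step discounted reward."""
    samples = []
    for t in range(len(transitions)):
        nstep_reward, discount, last = 0.0, 1.0, t
        # Accumulate up to nstep rewards, stopping early at the trajectory end or at done.
        for k in range(t, min(t + nstep, len(transitions))):
            nstep_reward += discount * transitions[k]['reward']
            discount *= gamma
            last = k
            if transitions[k]['done']:
                break
        sample = dict(transitions[t])  # shallow copy keeps the original transition intact
        sample['nstep_reward'] = nstep_reward
        sample['nstep_next_obs'] = transitions[last]['next_obs']
        sample['nstep_done'] = transitions[last]['done']
        sample['nstep_gamma'] = discount  # gamma ** (steps actually used), for bootstrapping
        samples.append(sample)
    return samples


trajectory = [
    {'obs': 0, 'reward': 1.0, 'next_obs': 1, 'done': False},
    {'obs': 1, 'reward': 1.0, 'next_obs': 2, 'done': False},
    {'obs': 2, 'reward': 1.0, 'next_obs': 3, 'done': True},
]
samples = get_train_sample_nstep(trajectory, nstep=2, gamma=0.5)
print([s['nstep_reward'] for s in samples])  # [1.5, 1.5, 1.0]
```

Doing this work in the collector means the learner receives ready-to-use targets; an identity implementation would instead push the same computation into _forward_learn.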

_init_collect() None[source]
Overview:

Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.

_init_eval() None[source]
Overview:

Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.

_init_learn() None[source]
Overview:

Initialize the learn mode of the policy, including related attributes and modules. This method is called in the __init__ method if learn is in enable_field. Almost every policy has its own learn mode, so this method must be overridden in the subclass.

Note

For the member variables that need to be saved and loaded, please refer to the _state_dict_learn and _load_state_dict_learn methods.

Note

For the member variables that need to be monitored, please refer to the _monitor_vars_learn method.

Note

If you want to set some special member variables in the _init_learn method, it is best to name them with the prefix _learn_ to avoid conflicts with other modes, such as self._learn_attr1.

_init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
Overview:

Initialize the multi-GPU data-parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function that allreduces the gradients of the model parameters.

Parameters:
  • model (-) – The neural network model to be trained.

  • bp_update_sync (-) – Whether to update the model parameters synchronously after the gradients of the model parameters are allreduced. Asynchronous updates can proceed in parallel across network layers, like a pipeline, which can save time.

_load_state_dict_collect(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy collect mode, for example when loading a pretrained state_dict, auto-recovering from a checkpoint, or receiving a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy collect state saved before.

Tip

To load only part of the model, set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_eval(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy eval mode, for example when auto-recovering from a checkpoint or receiving a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy eval state saved before.

Tip

To load only part of the model, set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_learn(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy learn mode.

Parameters:

state_dict (-) – The dict of policy learn state saved before.

Tip

To load only part of the model, set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_monitor_vars_learn() List[str][source]
Overview:

Register the variables to be monitored in learn mode. The registered variables are logged in TensorBoard according to the return value of _forward_learn.

_policy_value_fn(env: Env) Tuple[Dict[int, ndarray], float][source]
_process_transition(obs: Dict, model_output: Dict[str, Tensor], timestep: namedtuple) Dict[source]
Overview:

Generate the dict-type transition (one timestep) data from the observation, the policy output, and the environment timestep.

_reset_collect(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for collect mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, the different environments/episodes being collected in data_id each have their own RNN hidden state.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

Implementing this method is optional. A subclass can override it if necessary.
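The data_id semantics described for the reset methods can be illustrated with a toy per-environment state holder. PerEnvHiddenState, its list-based hidden state, and the zero initialization are hypothetical stand-ins for an RNN hidden state, not part of LightZero.

```python
from typing import Dict, List, Optional


class PerEnvHiddenState:
    """Hypothetical holder of one RNN-like hidden state per environment."""

    def __init__(self, env_ids: List[int], hidden_size: int = 2):
        self.hidden_size = hidden_size
        self.hidden: Dict[int, List[float]] = {env_id: self._zeros() for env_id in env_ids}

    def _zeros(self) -> List[float]:
        return [0.0] * self.hidden_size

    def reset(self, data_id: Optional[List[int]] = None) -> None:
        # data_id=None resets every environment; otherwise only the listed ones.
        targets = list(self.hidden) if data_id is None else data_id
        for env_id in targets:
            self.hidden[env_id] = self._zeros()


state = PerEnvHiddenState(env_ids=[0, 1, 2])
for env_id in state.hidden:
    state.hidden[env_id] = [1.0, 1.0]  # pretend each env has accumulated some state
state.reset(data_id=[1])  # e.g. only env 1 finished its episode
print(state.hidden)  # {0: [1.0, 1.0], 1: [0.0, 0.0], 2: [1.0, 1.0]}
state.reset()
print(state.hidden)  # {0: [0.0, 0.0], 1: [0.0, 0.0], 2: [0.0, 0.0]}
```

Resetting by id lets one environment start a new episode without discarding the state of environments that are still mid-episode.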

_reset_eval(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for eval mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, the different environments/episodes being evaluated in data_id each have their own RNN hidden state.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

Implementing this method is optional. A subclass can override it if necessary.

_reset_learn(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, the different trajectories in data_id each have their own RNN hidden state.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

Implementing this method is optional. A subclass can override it if necessary.

_set_attribute(name: str, value: Any) None
Overview:

To control access to the policy attributes, the policy exposes different modes to the outside instead of the policy instance itself. This method sets an attribute of the policy from within these modes. The new attribute is named _{name}.

Parameters:
  • name (-) – The name of the attribute.

  • value (-) – The value of the attribute.

_state_dict_collect() Dict[str, Any]
Overview:

Return the state_dict of collect mode, which usually only includes the model. It is necessary for auto-recovering collectors in distributed training scenarios.

Returns:

The dict of current policy collect state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover collectors; sometimes it is simpler to shut down the crashed collector and launch a new one.

_state_dict_eval() Dict[str, Any]
Overview:

Return the state_dict of eval mode, which usually only includes the model. It is necessary for auto-recovering evaluators in distributed training scenarios.

Returns:

The dict of current policy eval state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover evaluators; sometimes it is simpler to shut down the crashed evaluator and launch a new one.

_state_dict_learn() Dict[str, Any]
Overview:

Return the state_dict of learn mode, usually including model and optimizer.

Returns:

The dict of current policy learn state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

property cfg: EasyDict
class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new collect_function object from a sequence or iterable

_replace(**kwds)

Return a new collect_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

get_train_sample

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 7

process_transition

Alias for field number 1

reset

Alias for field number 3

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property collect_mode: collect_function
Overview:

Return the interfaces of the collect mode of the policy, which is used to collect data for training. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, a derived subclass can override the interfaces to customize its own collect mode.

Returns:

The interfaces of collect mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.collect_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
config = {'batch_size': 256, 'collector_env_num': 8, 'cuda': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'grad_clip_value': 10, 'learning_rate': 0.2, 'lr_piecewise_constant_decay': True, 'manual_temperature_decay': False, 'mcts': {'action_space_size': 9, 'continuous_action_space': False, 'gumbel_rng': 0.0, 'gumbel_scale': 10.0, 'legal_actions': None, 'max_moves': 512, 'max_num_considered_actions': 6, 'maxvisit_init': 50, 'num_of_sampled_actions': 2, 'num_simulations': 50, 'pb_c_base': 19652, 'pb_c_init': 1.25, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'value_scale': 0.1}, 'mcts_ctree': True, 'model': {'num_channels': 32, 'num_res_blocks': 1, 'observation_shape': (3, 6, 6)}, 'momentum': 0.9, 'optim_type': 'SGD', 'other': {'replay_buffer': {'replay_buffer_size': 1000000, 'save_episode': False}}, 'replay_ratio': 0.25, 'sampled_algo': False, 'tensor_float_32': False, 'threshold_training_steps_for_final_lr': 500000, 'threshold_training_steps_for_final_temperature': 100000, 'torch_compile': False, 'type': 'gumbel_alphazero', 'update_per_collect': None, 'value_weight': 1.0, 'weight_decay': 0.0001}
classmethod default_config() EasyDict
Overview:

Get the default config of policy. This method is used to create the default config of policy.

Returns:

The default config of the corresponding policy. For a derived policy class, it recursively merges the default config of the base class with its own default config.

Return type:

  • cfg (EasyDict)

Tip

This method deep-copies the config attribute of the class and returns the result, so users don’t need to worry about modifying the returned config.

default_model() Tuple[str, List[str]][source]
Overview:

Return the default model setting of this algorithm for demonstration.

Returns:

The model name and the model import names:
  • model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.

  • import_names (List[str]): The model class path list used in this algorithm.

Return type:

  • model_info (Tuple[str, List[str]])

class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new eval_function object from a sequence or iterable

_replace(**kwds)

Return a new eval_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 5

reset

Alias for field number 1

set_attribute

Alias for field number 3

state_dict

Alias for field number 4

property eval_mode: eval_function
Overview:

Return the interfaces of the eval mode of the policy, which is used to evaluate the policy. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, a derived subclass can override the interfaces to customize its own eval mode.

Returns:

The interfaces of eval mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.eval_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new learn_function object from a sequence or iterable

_replace(**kwds)

Return a new learn_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

info

Alias for field number 2

load_state_dict

Alias for field number 7

monitor_vars

Alias for field number 3

reset

Alias for field number 1

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property learn_mode: learn_function
Overview:

Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own learn mode.

Returns:

The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.learn_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
sync_gradients(model: Module) None
Overview:

Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.

Parameters:

model (-) – The model to synchronize gradients.

Note

This method is only used in multi-GPU training, and it should be called after the backward pass and before the optimizer step. The bp_update_sync config can also be used to control whether the gradient allreduce and the optimizer update are synchronized.

total_field = {'collect', 'eval', 'learn'}

Gumbel MuZeroPolicy

class lzero.policy.gumbel_muzero.GumbelMuZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]

Bases: MuZeroPolicy

Overview:

The policy class for Gumbel MuZero proposed in the paper https://openreview.net/forum?id=bERaNdoegnO.
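At the root, the Gumbel MuZero paper samples the set of actions to consider without replacement using the Gumbel-Top-k trick: add independent Gumbel(0, 1) noise to each logit and keep the k largest. The sketch below shows only that trick in plain Python; gumbel_top_k is a hypothetical helper, not LightZero's implementation, and the max_num_considered_actions key in the Gumbel AlphaZero config above appears to play the role of k (an assumption).

```python
import math
import random
from typing import List, Optional


def gumbel_top_k(logits: List[float], k: int, legal: Optional[List[bool]] = None,
                 rng: Optional[random.Random] = None) -> List[int]:
    """Sample k distinct actions without replacement using the Gumbel-Top-k trick."""
    rng = rng or random.Random()
    perturbed = []
    for action, logit in enumerate(logits):
        if legal is not None and not legal[action]:
            continue  # masked actions are never considered
        # Clamp U away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbel = -math.log(-math.log(u))  # a Gumbel(0, 1) sample
        perturbed.append((logit + gumbel, action))
    # The k largest perturbed logits form a sample without replacement from softmax(logits).
    perturbed.sort(reverse=True)
    return [action for _, action in perturbed[:k]]


logits = [2.0, 0.5, 1.0, -1.0, 0.0]
legal = [True, True, False, True, True]
considered = gumbel_top_k(logits, k=3, legal=legal, rng=random.Random(0))
print(considered)  # three distinct legal actions; index 2 never appears
```

With k = 1 this reduces to the Gumbel-max trick, i.e. a single sample from softmax(logits); in the paper the considered actions are then narrowed down further by Sequential Halving.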

__init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
Overview:

Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy, including learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. The enable_field argument specifies which fields to initialize; if it is None, all fields are initialized.

Parameters:
  • cfg (-) – The final merged config used to initialize policy. For the default config, see the config attribute and its comments of policy class.

  • model (-) – The neural network model used to initialize policy. If it is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be set to the model instance created by outside caller.

  • enable_field (-) – The field list to initialize. If it is None, then all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which is beneficial to save resources.

Note

For the derived policy class, it should implement the _init_learn, _init_collect, _init_eval method to initialize the corresponding field.

_abc_impl = <_abc._abc_data object>
_create_model(cfg: EasyDict, model: Module | None = None) Module
Overview:

Create or validate the neural network model according to the input configuration and model. If the input model is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be verified as an instance of torch.nn.Module and set to the model instance created by outside caller.

Parameters:
  • cfg (-) – The final merged config used to initialize policy.

  • model (-) – The neural network model used to initialize the policy. Users can refer to the default model defined in the corresponding policy to customize their own model.

Returns:

The created neural network model. The different modes of the policy add distinct wrappers and plugins to the model for training, collecting and evaluating.

Return type:

  • model (torch.nn.Module)

Raises:

- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.

_forward_collect(data: Tensor, action_mask: list = None, temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, ready_env_id: array = None) Dict[source]
Overview:

The forward function for collecting data in collect mode. It uses the model to execute the MCTS search, and in collect mode the action is chosen by sampling.

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask indicating which actions are legal and can be selected.

  • temperature (-) – The temperature of the policy.

  • to_play (-) – The player to play.

  • ready_env_id (-) – The id of the env that is ready to collect.

Shape:
  • data (torch.Tensor):
    • For Atari, \((N, C*S, H, W)\), where N is the number of collect_env, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For lunarlander, \((N, O)\), where N is the number of collect_env, O is the observation space size.

  • action_mask: \((N, \text{action\_space\_size})\), where N is the number of collect_env.

  • temperature: \((1, )\).

  • to_play: \((N, 1)\), where N is the number of collect_env.

  • ready_env_id: None
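The temperature and action_mask parameters above can be tied together for a single environment (one row of the mask) with the sketch below. It uses the common AlphaZero/MuZero rule of sampling in proportion to N(a) ** (1 / T) over legal actions and taking the most-visited legal action when deterministic; select_action is a hypothetical helper, the convention that 1 marks a legal action is an assumption, and the exact rule GumbelMuZeroPolicy applies is not confirmed here.

```python
import random
from typing import List, Optional


def select_action(visit_counts: List[int], action_mask: List[int], temperature: float = 1.0,
                  deterministic: bool = False, rng: Optional[random.Random] = None) -> int:
    """Pick an action for one environment from root visit counts.

    Assumes action_mask[a] == 1 marks a legal action, temperature > 0, and that
    at least one legal action has a positive visit count.
    """
    legal = [a for a, m in enumerate(action_mask) if m == 1]
    if deterministic:
        # Greedy choice among legal actions (eval-style).
        return max(legal, key=lambda a: visit_counts[a])
    # Sample with probability proportional to N(a) ** (1 / temperature):
    # temperature > 1 flattens the distribution, temperature < 1 sharpens it.
    weights = [visit_counts[a] ** (1.0 / temperature) for a in legal]
    rng = rng or random.Random()
    return rng.choices(legal, weights=weights, k=1)[0]


counts = [10, 40, 5, 0]
mask = [1, 0, 1, 1]  # action 1 is illegal despite its high visit count
print(select_action(counts, mask, deterministic=True))  # 0
print(select_action(counts, mask, temperature=0.25, rng=random.Random(0)))
```

In the batched case this would be applied per row, i.e. once for each of the N ready environments.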

Returns:

A dict whose keys include action, distributions, visit_count_distribution_entropy, value, roots_completed_value, improved_policy_probs, pred_value and policy_logits.

Return type:

  • output (Dict[int, Any])

_forward_eval(data: Tensor, action_mask: list, to_play: int = -1, ready_env_id: array = None) Dict[source]
Overview:

The forward function for evaluating the current policy in eval mode. It uses the model to execute the MCTS search, and in eval mode the action with the highest value is chosen deterministically (argmax) rather than by sampling.

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask indicating which actions are legal and can be selected.

  • to_play (-) – The player to play.

  • ready_env_id (-) – The id of the env that is ready to be evaluated.

Shape:
  • data (torch.Tensor):
    • For Atari, \((N, C*S, H, W)\), where N is the number of evaluator envs, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For lunarlander, \((N, O)\), where N is the number of evaluator envs, O is the observation space size.

  • action_mask: \((N, \text{action\_space\_size})\), where N is the number of evaluator envs.

  • to_play: \((N, 1)\), where N is the number of evaluator envs.

  • ready_env_id: None

Returns:

A dict whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.

Return type:

  • output (Dict[int, Any])

_forward_learn(data: Tensor) Dict[str, float | int][source]
Overview:

The forward function for learning policy in learn mode, which is the core of the learning process. The data is sampled from replay buffer. The loss is calculated by the loss function and the loss is backpropagated to update the model.

Parameters:

data (-) – The data sampled from replay buffer, which is a tuple of tensors. The first tensor is the current_batch, the second tensor is the target_batch.

Returns:

The information dict to be logged, which contains current learning loss and learning statistics.

Return type:

  • info_dict (Dict[str, Union[float, int]])

_get_attribute(name: str) Any
Overview:

To control access to the policy attributes, the policy exposes different modes to the outside instead of the policy instance itself. This method gets an attribute of the policy from within these modes.

Parameters:

name (-) – The name of the attribute.

Returns:

The value of the attribute.

Return type:

  • value (Any)

Note

DI-engine’s policy first tries to access the _get_{name} method and then the _{name} attribute. If neither is found, it raises a NotImplementedError.

_get_batch_size() int | Dict[str, int]
_get_n_episode() int | None
_get_n_sample() int | None
_get_target_obs_index_in_step_k(step)
Overview:

Get the begin index and end index of the target obs in step k.

Parameters:

step (-) – The current step k.

Returns:

The begin index and end index of the target obs in step k:
  • beg_index (int): The begin index of the target obs in step k.

  • end_index (int): The end index of the target obs in step k.

Return type:

  • Tuple[int, int]

Examples

>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
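Assuming the conv case stacks observation frames along the channel axis, the indices in the example can be reproduced with the standalone sketch below. The function name and its explicit arguments are hypothetical; the policy method reads image_channel and frame_stack_num from self._cfg.model instead.

```python
from typing import Tuple


def target_obs_index_in_step_k(step: int, image_channel: int, frame_stack_num: int) -> Tuple[int, int]:
    """Hypothetical helper for the conv case.

    Each unroll step shifts the window by `image_channel` channels, and the
    window covers `frame_stack_num` stacked frames.
    """
    beg_index = image_channel * step
    end_index = image_channel * (step + frame_stack_num)
    return beg_index, end_index


print(target_obs_index_in_step_k(0, image_channel=3, frame_stack_num=4))  # (0, 12)
print(target_obs_index_in_step_k(1, image_channel=3, frame_stack_num=4))  # (3, 15)
```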
_get_train_sample(data)[source]
Overview:

For a given trajectory (transitions, a list of transitions), process it into a list of samples that can be used directly for training. A train sample can be a processed transition (DQN with n-step TD) or a multi-timestep transition (DRQN). This method is usually called in collectors to perform the necessary RL data preprocessing before training, which helps the learner amortize the relevant time cost. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.

Parameters:

transitions (-) – The trajectory data (a list of transitions); each element has the same format as the return value of the self._process_transition method.

Returns:

The processed train samples. Each element has a format similar to the input transitions, but may contain extra training data such as the n-step reward or advantage.

Return type:

  • samples (List[Dict[str, Any]])

Note

The process_transition and get_train_sample methods will be vectorized in a future release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
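As an illustration of the "DQN with n-step TD" case mentioned in the overview, the following sketch turns a list of transitions into training samples that carry an n-step discounted reward sum. The transition keys (obs, action, reward, done) and the helper name are assumptions for illustration; the AlphaZero-style policies in this module pack their own fields. The bootstrapped value term of a full n-step target is omitted for brevity.

```python
from typing import Any, Dict, List


def nstep_train_samples(transitions: List[Dict[str, Any]], nstep: int = 3,
                        gamma: float = 0.99) -> List[Dict[str, Any]]:
    """Hypothetical get_train_sample: attach an n-step reward sum to each transition."""
    samples = []
    for t in range(len(transitions)):
        ret, discount = 0.0, 1.0
        # Accumulate up to `nstep` discounted rewards, stopping at episode end.
        for k in range(t, min(t + nstep, len(transitions))):
            ret += discount * transitions[k]['reward']
            discount *= gamma
            if transitions[k]['done']:
                break
        sample = dict(transitions[t])  # copy so the input trajectory is untouched
        sample['nstep_return'] = ret
        samples.append(sample)
    return samples


traj = [
    {'obs': 0, 'action': 1, 'reward': 1.0, 'done': False},
    {'obs': 1, 'action': 0, 'reward': 1.0, 'done': False},
    {'obs': 2, 'action': 1, 'reward': 1.0, 'done': True},
]
for s in nstep_train_samples(traj, nstep=2, gamma=0.5):
    print(s['obs'], s['nstep_return'])  # 0 1.5 / 1 1.5 / 2 1.0
```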

_init_collect() None[source]
Overview:

Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.

_init_eval() None[source]
Overview:

Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.

_init_learn() None[source]
Overview:

Learn mode init method. Called by self.__init__. Initialize the learn model, optimizer and MCTS utils.

_init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
Overview:

Initialize the multi-GPU data-parallel training setting: broadcast the model parameters at the beginning of training, and register the hook functions that allreduce the gradients of the model parameters.

Parameters:
  • model (-) – The neural network model to be trained.

  • bp_update_sync (-) – Whether to update the model parameters synchronously after the gradients have been allreduced. Asynchronous updates can run in parallel across network layers, like a pipeline, which saves time.

_load_state_dict_collect(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy collect mode, for example when loading a pretrained state_dict, auto-recovering from a checkpoint, or receiving a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy collect state saved before.

Tip

To load only some parts of the model, set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_eval(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy eval mode, for example when auto-recovering from a checkpoint or receiving a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy eval state saved before.

Tip

To load only some parts of the model, set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_learn(state_dict: Dict[str, Any]) None[source]
Overview:

Load the state_dict variable into policy learn mode.

Parameters:

state_dict (-) – The dict of policy learn state saved before.

_monitor_vars_learn() List[str][source]
Overview:

Register the variables to be monitored in learn mode. The registered variables are logged in TensorBoard according to the return value of _forward_learn.
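The relationship between the registered names and the dict returned by _forward_learn can be sketched as a simple filter, assuming a logger that records only registered keys that are actually present. The variable names and both helpers are hypothetical; in DI-engine the logging itself is handled by the learner.

```python
from typing import Dict, List


def monitor_vars_learn() -> List[str]:
    # Hypothetical names; the real list is policy-specific.
    return ['cur_lr', 'total_loss', 'policy_loss', 'value_loss']


def select_logged_vars(info_dict: Dict[str, float], monitored: List[str]) -> Dict[str, float]:
    """Keep only the registered variables that _forward_learn actually returned."""
    return {k: info_dict[k] for k in monitored if k in info_dict}


info = {'cur_lr': 0.2, 'total_loss': 1.3, 'policy_loss': 0.9,
        'value_loss': 0.4, 'debug_norm': 7.0}
print(select_logged_vars(info, monitor_vars_learn()))
# {'cur_lr': 0.2, 'total_loss': 1.3, 'policy_loss': 0.9, 'value_loss': 0.4}
```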

_process_transition(obs, policy_output, timestep)[source]
Overview:

Process and pack the transition data of one timestep into a dict, such as <s, a, r, s’, done>. Some policies need special processing to pack their own attributes (e.g. hidden state and logit), so this method is left to be implemented by subclasses.

Parameters:
  • obs (-) – The observation of the current timestep.

  • policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.

  • timestep (-) – The execution result namedtuple returned by the environment step method, except that all elements have been converted to tensors. It usually contains the next obs, reward, done, info, etc.

Returns:

The processed transition data of the current timestep.

Return type:

  • transition (Dict[str, torch.Tensor])
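A minimal sketch of packing one <s, a, r, s’, done> transition, assuming a timestep namedtuple with obs, reward, done and info fields. The Timestep type, the function and the dict keys are illustrative only, and plain Python lists stand in for tensors.

```python
from collections import namedtuple
from typing import Any, Dict

# Hypothetical timestep container mirroring an env step result.
Timestep = namedtuple('Timestep', ['obs', 'reward', 'done', 'info'])


def process_transition(obs: Any, policy_output: Dict[str, Any], timestep: Timestep) -> Dict[str, Any]:
    """Pack one transition into a dict (illustrative keys)."""
    return {
        'obs': obs,                          # s
        'action': policy_output['action'],   # a
        'reward': timestep.reward,           # r
        'next_obs': timestep.obs,            # s'
        'done': timestep.done,
    }


ts = Timestep(obs=[0.1, 0.2], reward=1.0, done=False, info={})
print(process_transition([0.0, 0.1], {'action': 3}, ts))
```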

_reset_collect(data_id: List[int] | None = None) None
Overview:

Reset the observation and action for the collector environment.

Parameters:

data_id (-) – List of data ids to reset (not used in this implementation).

_reset_eval(data_id: List[int] | None = None) None
Overview:

Reset the observation and action for the evaluator environment.

Parameters:

data_id (-) – List of data ids to reset (not used in this implementation).

_reset_learn(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, the stateful variables are reset according to data_id; for example, different trajectories in data_id have different RNN hidden states.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

Implementing this method is optional. Subclasses can override it if necessary.

_set_attribute(name: str, value: Any) None
Overview:

To control access to policy attributes, the policy exposes different modes to the outside instead of the policy instance itself. This method sets an attribute of the policy from within those modes. The new attribute is stored as _{name}.

Parameters:
  • name (-) – The name of the attribute.

  • value (-) – The value of the attribute.

_state_dict_collect() Dict[str, Any]
Overview:

Return the state_dict of collect mode, which usually contains only the model and is needed in distributed training scenarios to auto-recover collectors.

Returns:

The dict of current policy collect state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover collectors; sometimes the crashed collector can simply be shut down and replaced with a new one.

_state_dict_eval() Dict[str, Any]
Overview:

Return the state_dict of eval mode, which usually contains only the model and is needed in distributed training scenarios to auto-recover evaluators.

Returns:

The dict of current policy eval state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover evaluators; sometimes the crashed evaluator can simply be shut down and replaced with a new one.

_state_dict_learn() Dict[str, Any][source]
Overview:

Return the state_dict of learn mode, usually including model, target_model and optimizer.

Returns:

The dict of current policy learn state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

property cfg: EasyDict
class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new collect_function object from a sequence or iterable

_replace(**kwds)

Return a new collect_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

get_train_sample

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 7

process_transition

Alias for field number 1

reset

Alias for field number 3

set_attribute

Alias for field number 5

state_dict

Alias for field number 6
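The sketch below rebuilds the collect_function namedtuple from the documented field list to show how the field aliases and immutability work. The stub callables are hypothetical placeholders for the policy's internal methods.

```python
from collections import namedtuple

# Field order copied from the _fields tuple documented above.
collect_function = namedtuple(
    'collect_function',
    ['forward', 'process_transition', 'get_train_sample', 'reset',
     'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict'],
)


def make_stub(name):
    """Hypothetical placeholder: a callable that returns its own field name."""
    def stub(*args, **kwargs):
        return name
    return stub


interfaces = collect_function._make(make_stub(n) for n in collect_function._fields)
print(interfaces.forward())   # 'forward'
print(interfaces[4]())        # 'get_attribute' (field number 4)
try:
    interfaces.forward = None  # namedtuple fields cannot be reassigned
except AttributeError:
    print('immutable')
```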

property collect_mode: collect_function
Overview:

Return the collect-mode interfaces of the policy, which are used to collect training data. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.

Returns:

The collect-mode interfaces of the policy: a namedtuple whose fields map to different internal methods.

Return type:

  • interfaces (Policy.collect_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
config = {'action_type': 'fixed_action_space', 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'collector_env_num': 8, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_offline': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 200, 'grad_clip_value': 10, 'gray_scale': False, 'gumbel_algo': True, 'ignore_done': False, 'learning_rate': 0.2, 'lr_piecewise_constant_decay': True, 'manual_temperature_decay': False, 'max_num_considered_actions': 4, 'mcts_ctree': True, 'model': {'bias': True, 'categorical_distribution': True, 'continuous_action_space': False, 'discrete_action_encoding_type': 'one_hot', 'frame_stack_num': 1, 'image_channel': 1, 'model_type': 'conv', 'norm_type': 'BN', 'num_channels': 64, 'num_res_blocks': 1, 'observation_shape': (4, 96, 96), 'res_connection_in_dynamics': True, 'self_supervised_learning_loss': False, 'support_scale': 300}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'multi_gpu': False, 'n_episode': 8, 'num_simulations': 50, 'num_unroll_steps': 5, 'optim_type': 'SGD', 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'reanalyze_noise': True, 'replay_ratio': 0.25, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sampled_algo': False, 'ssl_loss_weight': 0, 'target_update_freq': 100, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 100000, 'transform2string': False, 'update_per_collect': None, 'use_augmentation': False, 'use_priority': True, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
classmethod default_config() EasyDict
Overview:

Get the default config of the policy.

Returns:

The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.

Return type:

  • cfg (EasyDict)

Tip

This method deep-copies the config attribute of the class and returns the copy, so modifying the returned config does not affect the class default.
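The deep-copy behavior described in the tip, together with the recursive merge of base-class defaults mentioned above, can be sketched with plain dicts. deep_merge, BasePolicy and DerivedPolicy are hypothetical; the real classes use EasyDict and DI-engine's own merge utilities, which may differ in detail.

```python
import copy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively merge `override` into a copy of `base` (hypothetical helper)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


class BasePolicy:
    config = {'cuda': False, 'model': {'num_channels': 64}}

    @classmethod
    def default_config(cls) -> Dict[str, Any]:
        # Return a deep copy so callers cannot mutate the class-level default.
        return copy.deepcopy(cls.config)


class DerivedPolicy(BasePolicy):
    # Derived defaults = base defaults recursively updated with its own entries.
    config = deep_merge(BasePolicy.config, {'model': {'num_res_blocks': 1}})


cfg = DerivedPolicy.default_config()
cfg['model']['num_channels'] = 32
print(cfg)                           # {'cuda': False, 'model': {'num_channels': 32, 'num_res_blocks': 1}}
print(DerivedPolicy.config['model'])  # {'num_channels': 64, 'num_res_blocks': 1}
```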

default_model() Tuple[str, List[str]][source]
Overview:

Return the default model setting of this algorithm.

Returns:

The model name and the model import_names:
  • model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.

  • import_names (List[str]): The model class path list used in this algorithm.

Return type:

  • model_info (Tuple[str, List[str]])

Note

Users can define and use a customized network model, but it must follow the same interface as the class at the import_names path. For MuZero, this is lzero.model.muzero_model.MuZeroModel.

class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new eval_function object from a sequence or iterable

_replace(**kwds)

Return a new eval_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 5

reset

Alias for field number 1

set_attribute

Alias for field number 3

state_dict

Alias for field number 4

property eval_mode: eval_function
Overview:

Return the eval-mode interfaces of the policy, which are used to evaluate the model. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.

Returns:

The eval-mode interfaces of the policy: a namedtuple whose fields map to different internal methods.

Return type:

  • interfaces (Policy.eval_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new learn_function object from a sequence or iterable

_replace(**kwds)

Return a new learn_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

info

Alias for field number 2

load_state_dict

Alias for field number 7

monitor_vars

Alias for field number 3

reset

Alias for field number 1

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property learn_mode: learn_function
Overview:

Return the learn-mode interfaces of the policy, which are used to train the model. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own learn mode.

Returns:

The learn-mode interfaces of the policy: a namedtuple whose fields map to different internal methods.

Return type:

  • interfaces (Policy.learn_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
sync_gradients(model: Module) None
Overview:

Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.

Parameters:

model (-) – The model to synchronize gradients.

Note

This method is only used in multi-GPU training, and it should be called after the backward call and before the optimizer step. The bp_update_sync config controls whether the gradient allreduce and the optimizer update are synchronized.
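Assuming the allreduce averages gradients across workers (a sum followed by division by the world size), its effect can be simulated in plain Python without any process group. The function name is hypothetical; a real implementation would call torch.distributed.all_reduce on each parameter's gradient.

```python
from typing import List


def simulated_allreduce_mean(worker_grads: List[List[float]]) -> List[List[float]]:
    """Simulate an allreduce that averages per-parameter gradients across workers.

    Each inner list holds one worker's flattened gradients. After the call,
    every worker holds the same element-wise mean, which keeps replicas in sync.
    """
    world_size = len(worker_grads)
    summed = [sum(vals) for vals in zip(*worker_grads)]
    mean = [s / world_size for s in summed]
    return [list(mean) for _ in range(world_size)]


grads = [[1.0, 2.0], [3.0, 6.0]]
print(simulated_allreduce_mean(grads))  # [[2.0, 4.0], [2.0, 4.0]]
```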

total_field = {'collect', 'eval', 'learn'}

Sampled AlphaZeroPolicy

class lzero.policy.sampled_alphazero.SampledAlphaZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]

Bases: Policy

Overview:

The policy class for Sampled AlphaZero.

__init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
Overview:

Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy: learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect training data, and the eval field is used to evaluate the policy. enable_field specifies which fields to initialize; if it is None, all fields are initialized.

Parameters:
  • cfg (-) – The final merged config used to initialize policy. For the default config, see the config attribute and its comments of policy class.

  • model (-) – The neural network model used to initialize policy. If it is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be set to the model instance created by outside caller.

  • enable_field (-) – The field list to initialize. If it is None, then all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which is beneficial to save resources.

Note

A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.

_abc_impl = <_abc._abc_data object>
_calculate_policy_loss_disc(policy_probs: Tensor, target_policy: Tensor, target_sampled_actions: Tensor, valid_action_lengths: Tensor) Tensor[source]
_create_model(cfg: EasyDict, model: Module | None = None) Module
Overview:

Create or validate the neural network model according to the input configuration and model. If the input model is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be verified as an instance of torch.nn.Module and set to the model instance created by outside caller.

Parameters:
  • cfg (-) – The final merged config used to initialize policy.

  • model (-) – The neural network model used to initialize policy. User can refer to the default model defined in the corresponding policy to customize its own model.

Returns:

The created neural network model. The different modes of policy will add distinct wrappers and plugins to the model, which is used to train, collect and evaluate.

Return type:

  • model (torch.nn.Module)

Raises:

- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.

_forward_collect(obs: Dict, temperature: float = 1) Dict[str, Tensor][source]
Overview:

The forward function for collecting data in collect mode. MCTS search is executed with the real environment.

Parameters:
  • obs (-) – The dict of obs, the key is env_id and the value is the corresponding obs in this timestep.

  • temperature (-) – The temperature used when deriving actions from the MCTS search.

Returns:

The dict of outputs: the key is env_id and the value is the corresponding policy output in this timestep, including action, probs and so on.

Return type:

  • output (Dict[str, torch.Tensor])
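AlphaZero-style collection commonly turns root visit counts into an action distribution with probability proportional to N(a)^(1/T), where T is the temperature. The sketch below assumes that is what the temperature argument controls; both helper names are hypothetical and this is not the LightZero implementation.

```python
import random
from typing import Dict, List


def visit_count_distribution(visit_counts: Dict[int, int], temperature: float) -> List[float]:
    """Return probabilities proportional to N(a) ** (1 / temperature); requires temperature > 0."""
    weights = [n ** (1.0 / temperature) for n in visit_counts.values()]
    total = sum(weights)
    return [w / total for w in weights]


def select_action(visit_counts: Dict[int, int], temperature: float, rng: random.Random) -> int:
    if temperature == 0:
        # Greedy limit: pick the most-visited action.
        return max(visit_counts, key=visit_counts.get)
    probs = visit_count_distribution(visit_counts, temperature)
    return rng.choices(list(visit_counts), weights=probs, k=1)[0]


counts = {0: 10, 1: 30, 2: 60}
print(visit_count_distribution(counts, temperature=1.0))  # ≈ [0.1, 0.3, 0.6]
print(visit_count_distribution(counts, temperature=0.5))  # sharper: ≈ [0.022, 0.196, 0.783]
print(select_action(counts, temperature=0.0, rng=random.Random(0)))  # 2
```

Lower temperatures concentrate probability on the most-visited actions, which is why collection often anneals T toward a small value while evaluation acts greedily.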

_forward_eval(obs: Dict) Dict[str, Tensor][source]
Overview:

The forward function for evaluating the current policy in eval mode, similar to self._forward_collect.

Parameters:

obs (-) – The dict of obs, the key is env_id and the value is the corresponding obs in this timestep.

Returns:

The dict of outputs: the key is env_id and the value is the corresponding policy output in this timestep, including action, probs and so on.

Return type:

  • output (Dict[str, torch.Tensor])

_forward_learn(inputs: Dict[str, Tensor]) Dict[str, float][source]
Overview:

Policy forward function of learn mode (training the policy and updating parameters). Here, forward means that the policy takes a batch of training data from the replay buffer and returns the result, including training information such as loss value, policy entropy, q value and priority. This method is left to be implemented by subclasses, and more arguments can be added to the data items if necessary.

Parameters:

data (-) – The input data used for policy forward, i.e. a batch of training samples. For each element in the list, the dict keys are the names of the data items and the values are the corresponding data. In the _forward_learn method, data is usually stacked along the batch dimension by utility functions such as default_preprocess_learn.

Returns:

The training information of policy forward, including metrics for monitoring training such as loss, q value and policy entropy, and data for the next training step such as priority. Note that output data items should be native Python scalars rather than PyTorch tensors, which is more convenient for external use.

Return type:

  • output (Dict[int, Any])

_get_attribute(name: str) Any
Overview:

To control access to policy attributes, the policy exposes different modes to the outside instead of the policy instance itself. This method gets an attribute of the policy from within those modes.

Parameters:

name (-) – The name of the attribute.

Returns:

The value of the attribute.

Return type:

  • value (Any)

Note

DI-engine’s policy first tries to call the _get_{name} method and then tries to access the _{name} attribute. If neither is found, it raises a NotImplementedError.

_get_batch_size() int | Dict[str, int]
_get_n_episode() int | None
_get_n_sample() int | None
_get_simulation_env()[source]
_get_train_sample(data)[source]
Overview:

For a given trajectory (transitions, a list of transitions), process it into a list of samples that can be used directly for training. A train sample can be a processed transition (DQN with n-step TD) or a multi-timestep transition (DRQN). This method is usually called in collectors to perform the necessary RL data preprocessing before training, which helps the learner amortize the relevant time cost. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.

Parameters:

transitions (-) – The trajectory data (a list of transitions); each element has the same format as the return value of the self._process_transition method.

Returns:

The processed train samples. Each element has a format similar to the input transitions, but may contain extra training data such as the n-step reward or advantage.

Return type:

  • samples (List[Dict[str, Any]])

Note

The process_transition and get_train_sample methods will be vectorized in a future release. Users can customize this data processing procedure by overriding these two methods and the collector itself.

_init_collect() None[source]
Overview:

Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.

_init_eval() None[source]
Overview:

Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.

_init_learn() None[source]
Overview:

Initialize the learn mode of the policy, including the related attributes and modules. This method is called in the __init__ method if learn is in enable_field. Almost every policy has its own learn mode, so this method must be overridden in subclasses.

Note

For the member variables that need to be saved and loaded, please refer to the _state_dict_learn and _load_state_dict_learn methods.

Note

For the member variables that need to be monitored, please refer to the _monitor_vars_learn method.

Note

If you want to set some special member variables in the _init_learn method, it is best to prefix them with _learn_ to avoid conflicts with other modes, e.g. self._learn_attr1.

_init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
Overview:

Initialize the multi-GPU data-parallel training setting: broadcast the model parameters at the beginning of training, and register the hook functions that allreduce the gradients of the model parameters.

Parameters:
  • model (-) – The neural network model to be trained.

  • bp_update_sync (-) – Whether to update the model parameters synchronously after the gradients have been allreduced. Asynchronous updates can run in parallel across network layers, like a pipeline, which saves time.

_load_state_dict_collect(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy collect mode, for example when loading a pretrained state_dict, auto-recovering from a checkpoint, or receiving a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy collect state saved before.

Tip

To load only some parts of the model, set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_eval(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy eval mode, for example when auto-recovering from a checkpoint or receiving a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy eval state saved before.

Tip

To load only some parts of the model, set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_learn(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy learn mode.

Parameters:

state_dict (-) – The dict of policy learn state saved before.

Tip

To load only some parts of the model, set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_monitor_vars_learn() List[str][source]
Overview:

Register the variables to be monitored in learn mode. The registered variables are logged in TensorBoard according to the return value of _forward_learn.

_policy_value_func(environment: Environment) Tuple[Dict[int, ndarray], float][source]
_process_transition(obs: Dict, model_output: Dict[str, Tensor], timestep: namedtuple) Dict[source]
Overview:

Generate the dict-type transition data of one timestep for policy learning.

_reset_collect(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for collect mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, the stateful variables are reset according to data_id; for example, different collecting environments/episodes in data_id have different RNN hidden states.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

Implementing this method is optional. Subclasses can override it if necessary.
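The data_id semantics of the reset methods (None resets everything, otherwise only the listed entries) can be sketched with a per-environment state store. ToyCollectState and its hidden field are hypothetical; a float stands in for an RNN hidden state.

```python
from typing import List, Optional


class ToyCollectState:
    """Hypothetical per-environment state store for collect mode."""

    def __init__(self, env_ids: List[int]) -> None:
        self.hidden = {env_id: 0.0 for env_id in env_ids}

    def reset(self, data_id: Optional[List[int]] = None) -> None:
        # data_id=None resets every environment; otherwise only the listed ones.
        targets = self.hidden.keys() if data_id is None else data_id
        for env_id in list(targets):
            self.hidden[env_id] = 0.0


state = ToyCollectState([0, 1, 2])
state.hidden.update({0: 1.0, 1: 1.0, 2: 1.0})
state.reset(data_id=[1])
print(state.hidden)  # {0: 1.0, 1: 0.0, 2: 1.0}
state.reset()
print(state.hidden)  # {0: 0.0, 1: 0.0, 2: 0.0}
```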

_reset_eval(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for eval mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, the stateful variables are reset according to data_id; for example, different evaluation environments/episodes in data_id have different RNN hidden states.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

Implementing this method is optional. Subclasses can override it if necessary.

_reset_learn(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, the stateful variables are reset according to data_id; for example, different trajectories in data_id have different RNN hidden states.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

Implementing this method is optional. Subclasses can override it if necessary.

_set_attribute(name: str, value: Any) None
Overview:

To control access to policy attributes, the policy exposes different modes to the outside instead of the policy instance itself. This method sets an attribute of the policy from within those modes. The new attribute is stored as _{name}.

Parameters:
  • name (-) – The name of the attribute.

  • value (-) – The value of the attribute.

_state_dict_collect() Dict[str, Any]
Overview:

Return the state_dict of collect mode, which usually contains only the model and is needed in distributed training scenarios to auto-recover collectors.

Returns:

The dict of current policy collect state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover collectors; sometimes the crashed collector can simply be shut down and replaced with a new one.

_state_dict_eval() Dict[str, Any]
Overview:

Return the state_dict of eval mode, which usually contains only the model and is needed in distributed training scenarios to auto-recover evaluators.

Returns:

The dict of current policy eval state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover evaluators; sometimes the crashed evaluator can simply be shut down and replaced with a new one.

_state_dict_learn() Dict[str, Any]
Overview:

Return the state_dict of learn mode, usually including model and optimizer.

Returns:

The dict of current policy learn state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

property cfg: EasyDict
class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new collect_function object from a sequence or iterable

_replace(**kwds)

Return a new collect_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

get_train_sample

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 7

process_transition

Alias for field number 1

reset

Alias for field number 3

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property collect_mode: collect_function
Overview:

Return the collect-mode interfaces of the policy, which are used to collect training data. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.

Returns:

The collect-mode interfaces of the policy: a namedtuple whose fields map to different internal methods.

Return type:

  • interfaces (Policy.collect_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
config = {'batch_size': 256, 'collector_env_num': 8, 'cuda': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'grad_clip_value': 10, 'learning_rate': 0.2, 'lr_piecewise_constant_decay': True, 'manual_temperature_decay': False, 'mcts': {'action_space_size': 9, 'continuous_action_space': False, 'legal_actions': None, 'max_moves': 512, 'num_of_sampled_actions': 2, 'num_simulations': 50, 'pb_c_base': 19652, 'pb_c_init': 1.25, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25}, 'mcts_ctree': True, 'model': {'num_channels': 32, 'num_res_blocks': 1, 'observation_shape': (3, 6, 6)}, 'momentum': 0.9, 'normalize_prob_of_sampled_actions': False, 'optim_type': 'SGD', 'other': {'replay_buffer': {'replay_buffer_size': 1000000, 'save_episode': False}}, 'policy_loss_type': 'cross_entropy', 'replay_ratio': 0.25, 'sampled_algo': False, 'tensor_float_32': False, 'threshold_training_steps_for_final_lr': 500000, 'threshold_training_steps_for_final_temperature': 100000, 'torch_compile': False, 'type': 'alphazero', 'update_per_collect': None, 'value_weight': 1.0, 'weight_decay': 0.0001}
classmethod default_config() EasyDict
Overview:

Get the default config of the policy.

Returns:

The default config of the corresponding policy. For a derived policy class, the default config of the base class is recursively merged with the derived class's own default config.

Return type:

  • cfg (EasyDict)

Tip

This method deep-copies the config attribute of the class and returns the copy, so modifying the returned config does not affect the class defaults.
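The recursive merge and the deep copy described above can be sketched in plain Python. deep_merge, BasePolicy and DerivedPolicy are hypothetical names, and walking the MRO is an assumption made for illustration; DI-engine itself works with EasyDict and its own merge utilities.

```python
import copy


def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge ``override`` into a deep copy of ``base``."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Nested sections are merged key by key instead of being replaced.
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


class BasePolicy:
    config = {'cuda': False, 'learn': {'batch_size': 64, 'learning_rate': 0.001}}

    @classmethod
    def default_config(cls) -> dict:
        cfg = {}
        # Merge from the most basic class to the most derived one, so subclasses win.
        for klass in reversed(cls.__mro__):
            own = vars(klass).get('config')
            if own is not None:
                cfg = deep_merge(cfg, own)
        return cfg  # already a deep copy, safe for the caller to modify


class DerivedPolicy(BasePolicy):
    config = {'learn': {'batch_size': 256}}


cfg = DerivedPolicy.default_config()
cfg['learn']['learning_rate'] = 0.1  # does not touch BasePolicy.config
```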

default_model() Tuple[str, List[str]][source]
Overview:

Return the default model setting of this algorithm, for demonstration.

Returns:

The model name and the model import_names.
  • model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.

  • import_names (List[str]): The model class path list used in this algorithm.

Return type:

  • model_info (Tuple[str, List[str]])

class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new eval_function object from a sequence or iterable

_replace(**kwds)

Return a new eval_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 5

reset

Alias for field number 1

set_attribute

Alias for field number 3

state_dict

Alias for field number 4

property eval_mode: eval_function
Overview:

Return the interfaces of the policy's eval mode, which is used to evaluate the policy. Here we use namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.

Returns:

The interfaces of eval mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.eval_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new learn_function object from a sequence or iterable

_replace(**kwds)

Return a new learn_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

info

Alias for field number 2

load_state_dict

Alias for field number 7

monitor_vars

Alias for field number 3

reset

Alias for field number 1

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property learn_mode: learn_function
Overview:

Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own learn mode.

Returns:

The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.learn_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
sync_gradients(model: Module) None
Overview:

Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.

Parameters:

model (-) – The model to synchronize gradients.

Note

This method is only used in multi-GPU training, and it should be called after the backward pass and before the optimizer step. The bp_update_sync config controls whether the gradient allreduce and the optimizer update are performed synchronously.
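The arithmetic of the gradient allreduce can be illustrated without GPUs. The sketch below simulates several data-parallel workers in pure Python and assumes the reduction is an average (the usual choice for data-parallel training); allreduce_mean is a hypothetical helper, not DI-engine code, which would instead call torch.distributed on each parameter's gradient.

```python
from typing import Dict, List


def allreduce_mean(worker_grads: List[Dict[str, List[float]]]) -> List[Dict[str, List[float]]]:
    """Simulate an allreduce that averages each parameter's gradient across workers."""
    world_size = len(worker_grads)
    averaged: Dict[str, List[float]] = {}
    for name in worker_grads[0]:
        # Element-wise sum of the same parameter's gradient over all workers.
        summed = [sum(vals) for vals in zip(*(g[name] for g in worker_grads))]
        averaged[name] = [v / world_size for v in summed]
    # After the allreduce every worker holds its own copy of the same averaged gradient.
    return [{name: list(vals) for name, vals in averaged.items()} for _ in range(world_size)]


synced = allreduce_mean([{'w': [1.0, 2.0]}, {'w': [3.0, 4.0]}])
```

Since every worker ends up with identical gradients, the subsequent optimizer step keeps the model replicas in sync.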

total_field = {'collect', 'eval', 'learn'}

Sampled MuZeroPolicy

class lzero.policy.sampled_muzero.SampledMuZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]

Bases: MuZeroPolicy

Overview:

The policy class for Sampled MuZero proposed in the paper https://arxiv.org/abs/2104.06303.

__init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
Overview:

Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy: learn, collect, and eval. The learn field is used to train the policy, the collect field is used to collect training data, and the eval field is used to evaluate the policy. enable_field specifies which fields to initialize; if it is None, all fields are initialized.

Parameters:
  • cfg (-) – The final merged config used to initialize policy. For the default config, see the config attribute and its comments of policy class.

  • model (-) – The neural network model used to initialize policy. If it is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be set to the model instance created by outside caller.

  • enable_field (-) – The field list to initialize. If it is None, then all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which is beneficial to save resources.

Note

A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.

_abc_impl = <_abc._abc_data object>
_calculate_policy_loss_cont(policy_loss: Tensor, policy_logits: Tensor, target_policy: Tensor, mask_batch: Tensor, child_sampled_actions_batch: Tensor, unroll_step: int) Tuple[Tensor][source]
Overview:

Calculate the policy loss for continuous action space.

Parameters:
  • policy_loss (-) – The policy loss tensor.

  • policy_logits (-) – The policy logits tensor.

  • target_policy (-) – The target policy tensor.

  • mask_batch (-) – The mask tensor.

  • child_sampled_actions_batch (-) – The child sampled actions tensor.

  • unroll_step (-) – The unroll step.

Returns:

A tuple of tensors:
  • policy_loss (torch.Tensor): The policy loss tensor.

  • policy_entropy (torch.Tensor): The policy entropy tensor.

  • policy_entropy_loss (torch.Tensor): The policy entropy loss tensor.

  • target_policy_entropy (torch.Tensor): The target policy entropy tensor.

  • target_sampled_actions (torch.Tensor): The target sampled actions tensor.

  • mu (torch.Tensor): The mu (mean) tensor of the Gaussian policy.

  • sigma (torch.Tensor): The sigma (standard deviation) tensor of the Gaussian policy.

Return type:

  • Tuple[torch.Tensor]
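A simplified, single-batch-element sketch of the kind of loss this method computes: the MCTS target probabilities over the K sampled actions weight the log-density of each action under a diagonal Gaussian policy (mu, sigma), and the result is multiplied by the mask. This is an illustration under stated assumptions in pure Python; the actual method operates on batched torch tensors, indexes the targets by unroll_step and also returns entropy terms, so details may differ. The helper names are hypothetical.

```python
import math
from typing import List


def gaussian_log_prob(action: List[float], mu: List[float], sigma: List[float]) -> float:
    """Log-density of a diagonal Gaussian N(mu, sigma^2) evaluated at ``action``."""
    return sum(
        -0.5 * ((a - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
        for a, m, s in zip(action, mu, sigma)
    )


def sampled_policy_loss_cont(mu, sigma, sampled_actions, target_probs, mask):
    """Masked cross-entropy between MCTS targets over sampled actions and the Gaussian policy."""
    log_probs = [gaussian_log_prob(a, mu, sigma) for a in sampled_actions]
    loss = -sum(p * lp for p, lp in zip(target_probs, log_probs))
    # mask = 0 removes padded unroll steps from the loss.
    return loss * mask


loss_near = sampled_policy_loss_cont([0.0], [1.0], [[0.0], [2.0]], [1.0, 0.0], 1.0)
loss_far = sampled_policy_loss_cont([0.0], [1.0], [[0.0], [2.0]], [0.0, 1.0], 1.0)
```

When the target puts its mass on an action far from mu, the loss grows, which pushes the Gaussian towards the actions MCTS preferred.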

_calculate_policy_loss_disc(policy_loss: Tensor, policy_logits: Tensor, target_policy: Tensor, mask_batch: Tensor, child_sampled_actions_batch: Tensor, unroll_step: int) Tuple[Tensor][source]
Overview:

Calculate the policy loss for discrete action space.

Parameters:
  • policy_loss (-) – The policy loss tensor.

  • policy_logits (-) – The policy logits tensor.

  • target_policy (-) – The target policy tensor.

  • mask_batch (-) – The mask tensor.

  • child_sampled_actions_batch (-) – The child sampled actions tensor.

  • unroll_step (-) – The unroll step.

Returns:

A tuple of tensors:
  • policy_loss (torch.Tensor): The policy loss tensor.

  • policy_entropy (torch.Tensor): The policy entropy tensor.

  • policy_entropy_loss (torch.Tensor): The policy entropy loss tensor.

  • target_policy_entropy (torch.Tensor): The target policy entropy tensor.

  • target_sampled_actions (torch.Tensor): The target sampled actions tensor.

Return type:

  • Tuple[torch.Tensor]
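For the discrete case, a comparable single-batch-element sketch assumes the policy logits define a categorical distribution over the whole action space and that only the K actions sampled at the root enter a masked cross-entropy with the MCTS target probabilities. The helper names are hypothetical and the real method works on batched torch tensors, so treat this as an illustration, not the LightZero implementation.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sampled_policy_loss_disc(logits, sampled_action_ids, target_probs, mask):
    """Masked cross-entropy restricted to the sampled actions of one batch element."""
    log_probs = log_softmax(logits)
    # Only the log-probabilities of the actions sampled at the root are used.
    loss = -sum(p * log_probs[a] for a, p in zip(sampled_action_ids, target_probs))
    return loss * mask


loss = sampled_policy_loss_disc([0.0, 0.0], [0], [1.0], 1.0)
```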

_create_model(cfg: EasyDict, model: Module | None = None) Module
Overview:

Create or validate the neural network model according to the input configuration and model. If the input model is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be verified as an instance of torch.nn.Module and set to the model instance created by outside caller.

Parameters:
  • cfg (-) – The final merged config used to initialize policy.

  • model (-) – The neural network model used to initialize policy. Users can refer to the default model of the corresponding policy when customizing their own model.

Returns:

The created neural network model. The different modes of policy will add distinct wrappers and plugins to the model, which is used to train, collect and evaluate.

Return type:

  • model (torch.nn.Module)

Raises:

- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.

_forward_collect(data: Tensor, action_mask: list = None, temperature: ndarray = 1, to_play=-1, epsilon: float = 0.25, ready_env_id: array = None)[source]
Overview:

The forward function for collecting data in collect mode. It uses the model to run MCTS and then chooses the action by sampling from the resulting visit-count distribution.

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask, indicating which actions are legal and can be selected.

  • temperature (-) – The temperature of the policy.

  • to_play (-) – The player to play.

  • ready_env_id (-) – The id of the env that is ready to collect.

Shape:
  • data (torch.Tensor):
    • For Atari, \((N, C*S, H, W)\), where N is the number of collect_env, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For lunarlander, \((N, O)\), where N is the number of collect_env, O is the observation space size.

  • action_mask: \((N, action_space_size)\), where N is the number of collect_env.

  • temperature: \((1, )\).

  • to_play: \((N, 1)\), where N is the number of collect_env.

  • ready_env_id: None

Returns:

A dict whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.

Return type:

  • output (Dict[int, Any])
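The temperature-controlled sampling used in collect mode can be sketched as follows, assuming each root action is chosen with probability proportional to visit_count ** (1 / temperature) and that an action_mask entry of 1 marks a legal action. Both helpers are hypothetical names; the LightZero helper that performs this step may differ in details.

```python
import random
from typing import List, Sequence


def legal_actions_from_mask(action_mask: Sequence[int]) -> List[int]:
    # Assumption: 1 marks a legal action, 0 an illegal one.
    return [i for i, legal in enumerate(action_mask) if legal == 1]


def sample_action(visit_counts: Sequence[int], root_actions: Sequence[int],
                  temperature: float, rng: random.Random) -> int:
    """Sample a root action with probability proportional to visit_count ** (1 / temperature)."""
    weights = [c ** (1.0 / temperature) for c in visit_counts]
    total = sum(weights)
    probs = [w / total for w in weights]
    index = rng.choices(range(len(root_actions)), weights=probs, k=1)[0]
    return root_actions[index]


rng = random.Random(0)
legal = legal_actions_from_mask([1, 0, 1, 1])
action = sample_action([0, 10, 0], legal, temperature=0.25, rng=rng)
```

A lower temperature sharpens the distribution towards the most visited action, while temperature 1 samples directly in proportion to the visit counts.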

_forward_eval(data: Tensor, action_mask: list, to_play: -1, ready_env_id: array = None)[source]
Overview:

The forward function for evaluating the current policy in eval mode. It uses the model to run MCTS and, instead of sampling, deterministically chooses the action with the highest visit count (argmax).

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask, indicating which actions are legal and can be selected.

  • to_play (-) – The player to play.

  • ready_env_id (-) – The id of the env that is ready to be evaluated.

Shape:
  • data (torch.Tensor):
    • For Atari, \((N, C*S, H, W)\), where N is the number of evaluator_env, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For lunarlander, \((N, O)\), where N is the number of evaluator_env, O is the observation space size.

  • action_mask: \((N, action_space_size)\), where N is the number of evaluator_env.

  • to_play: \((N, 1)\), where N is the number of evaluator_env.

  • ready_env_id: None

Returns:

A dict whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.

Return type:

  • output (Dict[int, Any])
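The deterministic choice made in eval mode corresponds to taking the argmax of the root visit counts. The sketch below assumes visit_counts and root_actions are aligned lists; greedy_action is a hypothetical helper, and breaking ties towards the first maximum is a choice made here, not necessarily LightZero's behaviour.

```python
from typing import Sequence


def greedy_action(visit_counts: Sequence[int], root_actions: Sequence[int]) -> int:
    """Return the root action with the largest MCTS visit count (ties go to the first one)."""
    best_index = max(range(len(visit_counts)), key=lambda i: visit_counts[i])
    return root_actions[best_index]


action = greedy_action([3, 7, 7], [4, 1, 5])
```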

_forward_learn(data: Tensor) Dict[str, float | int][source]
Overview:

The forward function for learning policy in learn mode, which is the core of the learning process. The data is sampled from replay buffer. The loss is calculated by the loss function and the loss is backpropagated to update the model.

Parameters:

data (-) – The data sampled from replay buffer, which is a tuple of tensors. The first tensor is the current_batch, the second tensor is the target_batch.

Returns:

The information dict to be logged, which contains current learning loss and learning statistics.

Return type:

  • info_dict (Dict[str, Union[float, int]])

_get_attribute(name: str) Any
Overview:

To control access to the policy attributes, different modes are exposed to the outside instead of the policy instance itself, and this method lets each mode read an attribute of the policy.

Parameters:

name (-) – The name of the attribute.

Returns:

The value of the attribute.

Return type:

  • value (Any)

Note

DI-engine’s policy will first try to access _get_{name} method, and then try to access _{name} attribute. If both of them are not found, it will raise a NotImplementedError.
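The lookup order stated in the note can be sketched with a toy class. ToyPolicy and its attributes are hypothetical; only the order (a _get_{name} method first, then an _{name} attribute, otherwise NotImplementedError) comes from the note above.

```python
from typing import Any


class ToyPolicy:
    """Hypothetical policy demonstrating the attribute lookup order."""

    def __init__(self):
        self._batch_size = 256
        self._cuda = False

    def _get_cuda(self) -> str:
        # A dedicated getter takes precedence over the raw ``_cuda`` attribute.
        return 'cuda' if self._cuda else 'cpu'

    def _get_attribute(self, name: str) -> Any:
        getter = getattr(self, '_get_' + name, None)
        if getter is not None:
            return getter()
        if hasattr(self, '_' + name):
            return getattr(self, '_' + name)
        raise NotImplementedError(name)


p = ToyPolicy()
```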

_get_batch_size() int | Dict[str, int]
_get_n_episode() int | None
_get_n_sample() int | None
_get_target_obs_index_in_step_k(step)
Overview:

Get the begin index and end index of the target obs in step k.

Parameters:

step (-) – The current step k.

Returns:

  • beg_index (int): The begin index of the target obs in step k.

  • end_index (int): The end index of the target obs in step k.

Return type:

  • Tuple[int, int]

Examples

>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
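The example above can be reproduced with a small helper for the 'conv' case, assuming the per-step frames are stacked along the channel axis with image_channel channels each, so the target at step k spans channels [image_channel * k, image_channel * (k + frame_stack_num)). The function name and signature are hypothetical; the real method reads these values from self._cfg.model.

```python
from typing import Tuple


def target_obs_index_in_step_k(step: int, image_channel: int, frame_stack_num: int) -> Tuple[int, int]:
    """Channel range [beg, end) of the target observation at unroll step ``step``."""
    # Assumption: frames are concatenated along the channel axis, image_channel channels per frame.
    beg_index = image_channel * step
    end_index = image_channel * (step + frame_stack_num)
    return beg_index, end_index


print(target_obs_index_in_step_k(0, image_channel=3, frame_stack_num=4))  # (0, 12)
```

Each unroll step shifts the window by one frame, i.e. by image_channel channels.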
_get_train_sample(data)
Overview:

For a given trajectory (a list of transitions), process it into a list of samples that can be used for training directly. A training sample can be a processed transition (e.g. DQN with n-step TD) or several multi-timestep transitions (e.g. DRQN). This method is usually called in collectors to perform the necessary RL data preprocessing before training, which helps the learner amortize the relevant time cost. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.

Parameters:

transitions (-) – The trajectory data (a list of transition), each element is the same format as the return value of self._process_transition method.

Returns:

The processed train samples, each element is the similar format as input transitions, but may contain more data for training, such as nstep reward, advantage, etc.

Return type:

  • samples (List[Dict[str, Any]])

Note

We will vectorize the process_transition and get_train_sample methods in a future release. Users can customize this data processing procedure by overriding these two methods and the collector itself.

_init_collect() None[source]
Overview:

Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.

_init_eval() None[source]
Overview:

Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.

_init_learn() None[source]
Overview:

Learn mode init method. Called by self.__init__. Initialize the learn model, optimizer and MCTS utils.

_init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
Overview:

Initialize the multi-GPU data-parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function that allreduces the gradients of the model parameters.

Parameters:
  • model (-) – The neural network model to be trained.

  • bp_update_sync (-) – Whether to update the model parameters synchronously after the gradients of the model parameters are allreduced. Asynchronous updates can proceed in parallel across network layers, like a pipeline, which saves time.

_load_state_dict_collect(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy collect mode, e.g. to load a pretrained state_dict, auto-recover from a checkpoint, or load a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy collect state saved before.

Tip

If you only want to load part of the model, you can set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_eval(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy eval mode, e.g. to auto-recover from a checkpoint or load a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy eval state saved before.

Tip

If you only want to load part of the model, you can set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_learn(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy learn mode.

Parameters:

state_dict (-) – The dict of policy learn state saved before.

_monitor_vars_learn() List[str][source]
Overview:

Register the variables to be monitored in learn mode. The registered variables will be logged in TensorBoard according to the return value of _forward_learn.
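The pairing between the registered names and the info dict returned by _forward_learn can be sketched as a simple key filter. Both functions and the listed names are hypothetical, and skipping names missing from the info dict is an assumption made for this illustration, not a statement about the DI-engine logger.

```python
from typing import Dict, List


def monitor_vars_learn() -> List[str]:
    # Hypothetical subset of monitored names; the real list is longer.
    return ['total_loss', 'policy_loss', 'value_loss', 'cur_lr']


def select_logged(info_dict: Dict[str, float], monitored: List[str]) -> Dict[str, float]:
    """Keep only the registered variables that _forward_learn actually returned."""
    return {k: info_dict[k] for k in monitored if k in info_dict}


logged = select_logged({'total_loss': 1.5, 'policy_loss': 0.5, 'debug': 3.0}, monitor_vars_learn())
```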

_process_transition(obs, policy_output, timestep)
Overview:

Process and pack one timestep transition into a dict, such as <s, a, r, s’, done>. Some policies need special processing to pack their own necessary attributes (e.g. hidden state and logit), so this method is left to be implemented by subclasses.

Parameters:
  • obs (-) – The observation of the current timestep.

  • policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.

  • timestep (-) – The execution result namedtuple returned by the environment step method, except that all elements have been converted to tensors. Usually, it contains the next obs, reward, done, info, etc.

Returns:

The processed transition data of the current timestep.

Return type:

  • transition (Dict[str, torch.Tensor])
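A generic sketch of packing one <s, a, r, s’, done> transition, in the style described above. The Timestep layout and the dict keys are assumptions for illustration; an environment's real timestep namedtuple and a concrete policy's transition format may contain more fields.

```python
from collections import namedtuple
from typing import Any, Dict

# Hypothetical timestep layout returned by an environment step.
Timestep = namedtuple('Timestep', ['obs', 'reward', 'done', 'info'])


def process_transition(obs: Any, policy_output: Dict[str, Any], timestep: Timestep) -> Dict[str, Any]:
    """Pack one <s, a, r, s', done> transition into a dict."""
    return {
        'obs': obs,
        'action': policy_output['action'],
        'reward': timestep.reward,
        'next_obs': timestep.obs,  # the observation after the step
        'done': timestep.done,
    }


transition = process_transition([0], {'action': 2}, Timestep([1], 1.0, False, {}))
```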

_reset_collect(data_id: List[int] | None = None) None
Overview:

Reset the observation and action for the collector environment.

Parameters:

data_id (-) – List of data ids to reset (not used in this implementation).

_reset_eval(data_id: List[int] | None = None) None
Overview:

Reset the observation and action for the evaluator environment.

Parameters:

data_id (-) – List of data ids to reset (not used in this implementation).

_reset_learn(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, the stateful variables are reset according to data_id; for example, different trajectories in data_id have different hidden states in an RNN.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

Implementing this method is optional; subclasses can override it if necessary.
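The data_id semantics described above (None resets everything, otherwise only the listed ids) can be sketched with a hypothetical per-trajectory hidden-state store. HiddenStateBank is an illustrative name, not part of LightZero or DI-engine.

```python
from typing import Dict, List, Optional


class HiddenStateBank:
    """Hypothetical per-trajectory RNN state store for a learn mode that needs one."""

    def __init__(self, num_ids: int, hidden_size: int):
        self._hidden_size = hidden_size
        self._state: Dict[int, List[float]] = {i: self._zeros() for i in range(num_ids)}

    def _zeros(self) -> List[float]:
        return [0.0] * self._hidden_size

    def set(self, data_id: int, state: List[float]) -> None:
        self._state[data_id] = list(state)

    def get(self, data_id: int) -> List[float]:
        return self._state[data_id]

    def reset(self, data_id: Optional[List[int]] = None) -> None:
        # None means "reset everything"; otherwise only the listed trajectories.
        ids = list(self._state) if data_id is None else data_id
        for i in ids:
            self._state[i] = self._zeros()


bank = HiddenStateBank(num_ids=3, hidden_size=2)
bank.set(0, [1.0, 1.0])
bank.set(1, [2.0, 2.0])
bank.reset([0])
```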

_set_attribute(name: str, value: Any) None
Overview:

To control access to the policy attributes, different modes are exposed to the outside instead of the policy instance itself, and this method lets each mode set an attribute of the policy. The new attribute is stored as _{name}.

Parameters:
  • name (-) – The name of the attribute.

  • value (-) – The value of the attribute.

_state_dict_collect() Dict[str, Any]
Overview:

Return the state_dict of collect mode, which usually contains only the model; it is needed in distributed training scenarios to auto-recover collectors.

Returns:

The dict of current policy collect state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover collectors; sometimes we can simply shut down the crashed collector and start a new one.

_state_dict_eval() Dict[str, Any]
Overview:

Return the state_dict of eval mode, which usually contains only the model; it is needed in distributed training scenarios to auto-recover evaluators.

Returns:

The dict of current policy eval state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover evaluators; sometimes we can simply shut down the crashed evaluator and start a new one.

_state_dict_learn() Dict[str, Any]
Overview:

Return the state_dict of learn mode, usually including model, target_model and optimizer.

Returns:

The dict of current policy learn state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

property cfg: EasyDict
class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new collect_function object from a sequence or iterable

_replace(**kwds)

Return a new collect_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

get_train_sample

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 7

process_transition

Alias for field number 1

reset

Alias for field number 3

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property collect_mode: collect_function
Overview:

Return the interfaces of the policy's collect mode, which is used to collect training data. Here we use namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.

Returns:

The interfaces of collect mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.collect_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
config = {'action_type': 'fixed_action_space', 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'collector_env_num': 8, 'cos_lr_scheduler': False, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_offline': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 200, 'grad_clip_value': 10, 'gray_scale': False, 'gumbel_algo': False, 'ignore_done': False, 'init_w': 0.003, 'learning_rate': 0.0001, 'lr_piecewise_constant_decay': False, 'lstm_horizon_len': 5, 'manual_temperature_decay': False, 'mcts_ctree': True, 'model': {'action_space_size': 6, 'bias': True, 'categorical_distribution': True, 'continuous_action_space': False, 'discrete_action_encoding_type': 'one_hot', 'fixed_sigma_value': 0.3, 'frame_stack_num': 1, 'image_channel': 1, 'lstm_hidden_size': 512, 'model_type': 'conv', 'norm_type': 'LN', 'num_res_blocks': 1, 'observation_shape': (4, 96, 96), 'res_connection_in_dynamics': True, 'self_supervised_learning_loss': True, 'sigma_type': 'conditioned', 'support_scale': 300}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'multi_gpu': False, 'n_episode': 8, 'normalize_prob_of_sampled_actions': False, 'num_simulations': 50, 'num_unroll_steps': 5, 'optim_type': 'AdamW', 'policy_entropy_loss_weight': 0, 'policy_loss_type': 'cross_entropy', 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'reanalyze_noise': True, 'replay_ratio': 0.25, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sampled_algo': True, 'ssl_loss_weight': 2, 'target_update_freq': 100, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 100000, 'transform2string': False, 'update_per_collect': None, 'use_augmentation': False, 
'use_priority': False, 'use_ture_chance_label_in_chance_encoder': False, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
classmethod default_config() EasyDict
Overview:

Get the default config of the policy.

Returns:

The default config of the corresponding policy. For a derived policy class, the default config of the base class is recursively merged with the derived class's own default config.

Return type:

  • cfg (EasyDict)

Tip

This method deep-copies the config attribute of the class and returns the copy, so modifying the returned config does not affect the class defaults.

default_model() Tuple[str, List[str]][source]
Overview:

Return the default model setting of this algorithm.

Returns:

model name and model import_names.
  • model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.

  • import_names (List[str]): The model class path list used in this algorithm.

Return type:

  • model_info (Tuple[str, List[str]])

Note

Users can define and use a customized network model, but it must follow the same interface as the model indicated by the import_names path. For Sampled MuZero, the default model is lzero.model.sampled_muzero_model.SampledMuZeroModel.

class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new eval_function object from a sequence or iterable

_replace(**kwds)

Return a new eval_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 5

reset

Alias for field number 1

set_attribute

Alias for field number 3

state_dict

Alias for field number 4

property eval_mode: eval_function
Overview:

Return the interfaces of the policy's eval mode, which is used to evaluate the policy. Here we use namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.

Returns:

The interfaces of eval mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.eval_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new learn_function object from a sequence or iterable

_replace(**kwds)

Return a new learn_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

info

Alias for field number 2

load_state_dict

Alias for field number 7

monitor_vars

Alias for field number 3

reset

Alias for field number 1

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property learn_mode: learn_function
Overview:

Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own learn mode.

Returns:

The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.learn_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
sync_gradients(model: Module) None
Overview:

Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.

Parameters:

model (-) – The model to synchronize gradients.

Note

This method is only used in multi-GPU training, and it should be called after the backward pass and before the optimizer step. The bp_update_sync config controls whether the gradient allreduce and the optimizer update are performed synchronously.

total_field = {'collect', 'eval', 'learn'}

Sampled EfficientZeroPolicy

class lzero.policy.sampled_efficientzero.SampledEfficientZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]

Bases: MuZeroPolicy

Overview:

The policy class for Sampled EfficientZero proposed in the paper https://arxiv.org/abs/2104.06303.

__init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
Overview:

Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy: learn, collect, and eval. The learn field is used to train the policy, the collect field is used to collect training data, and the eval field is used to evaluate the policy. enable_field specifies which fields to initialize; if it is None, all fields are initialized.

Parameters:
  • cfg (-) – The final merged config used to initialize policy. For the default config, see the config attribute and its comments of policy class.

  • model (-) – The neural network model used to initialize policy. If it is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be set to the model instance created by outside caller.

  • enable_field (-) – The field list to initialize. If it is None, then all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which is beneficial to save resources.

Note

A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.

_abc_impl = <_abc._abc_data object>
_calculate_policy_loss_cont(policy_loss: Tensor, policy_logits: Tensor, target_policy: Tensor, mask_batch: Tensor, child_sampled_actions_batch: Tensor, unroll_step: int) Tuple[Tensor][source]
Overview:

Calculate the policy loss for continuous action space.

Parameters:
  • policy_loss (-) – The policy loss tensor.

  • policy_logits (-) – The policy logits tensor.

  • target_policy (-) – The target policy tensor.

  • mask_batch (-) – The mask tensor.

  • child_sampled_actions_batch (-) – The child sampled actions tensor.

  • unroll_step (-) – The unroll step.

Returns:

  • policy_loss (torch.Tensor): The policy loss tensor.

  • policy_entropy (torch.Tensor): The policy entropy tensor.

  • policy_entropy_loss (torch.Tensor): The policy entropy loss tensor.

  • target_policy_entropy (torch.Tensor): The target policy entropy tensor.

  • target_sampled_actions (torch.Tensor): The target sampled actions tensor.

  • mu (torch.Tensor): The mean (mu) of the Gaussian policy.

  • sigma (torch.Tensor): The standard deviation (sigma) of the Gaussian policy.

Return type:

  • Tuple[torch.Tensor]
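For intuition, the continuous-action loss can be pictured as a cross-entropy between the MCTS target weights over the K sampled child actions and the log-likelihood of those actions under a Gaussian with parameters mu and sigma. The pure-Python sketch below handles a single state and assumes a diagonal Gaussian with no action squashing, target weights that already sum to 1, and that mask_batch is 0 for padded unroll steps. The function names are hypothetical, and this is not the LightZero implementation, which operates on batched tensors.

```python
import math
from typing import List, Sequence


def gaussian_log_prob(action: Sequence[float], mu: Sequence[float],
                      sigma: Sequence[float]) -> float:
    """Log-density of `action` under a diagonal Gaussian N(mu, diag(sigma ** 2))."""
    log_p = 0.0
    for a, m, s in zip(action, mu, sigma):
        log_p += -0.5 * ((a - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
    return log_p


def sampled_policy_loss_cont(mu: Sequence[float], sigma: Sequence[float],
                             sampled_actions: List[Sequence[float]],
                             target_policy: Sequence[float],
                             mask: float = 1.0) -> float:
    """Hypothetical single-state sketch (no batching).

    target_policy[k] is the MCTS target weight of sampled_actions[k]; the loss is
    the cross-entropy -sum_k target_policy[k] * log pi(a_k | mu, sigma).
    `mask` is assumed to be 0 for padded unroll steps and 1 otherwise.
    """
    loss = 0.0
    for action, weight in zip(sampled_actions, target_policy):
        loss -= weight * gaussian_log_prob(action, mu, sigma)
    return loss * mask


# Two sampled 1-D actions; MCTS put most of its weight on the one near mu.
loss = sampled_policy_loss_cont(mu=[0.0], sigma=[0.5],
                                sampled_actions=[[0.1], [1.5]],
                                target_policy=[0.8, 0.2])
```

Moving the target weight onto the unlikely action (1.5) increases the loss, which is what pushes mu and sigma toward the actions MCTS preferred.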

_calculate_policy_loss_disc(policy_loss: Tensor, policy_logits: Tensor, target_policy: Tensor, mask_batch: Tensor, child_sampled_actions_batch: Tensor, unroll_step: int) Tuple[Tensor][source]
Overview:

Calculate the policy loss for discrete action space.

Parameters:
  • policy_loss (-) – The policy loss tensor.

  • policy_logits (-) – The policy logits tensor.

  • target_policy (-) – The target policy tensor.

  • mask_batch (-) – The mask tensor.

  • child_sampled_actions_batch (-) – The child sampled actions tensor.

  • unroll_step (-) – The unroll step.

Returns:

  • policy_loss (torch.Tensor): The policy loss tensor.

  • policy_entropy (torch.Tensor): The policy entropy tensor.

  • policy_entropy_loss (torch.Tensor): The policy entropy loss tensor.

  • target_policy_entropy (torch.Tensor): The target policy entropy tensor.

  • target_sampled_actions (torch.Tensor): The target sampled actions tensor.

Return type:

  • Tuple[torch.Tensor]
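In the discrete case the same idea reduces to a cross-entropy between the target policy and the softmax of policy_logits. The pure-Python sketch below handles a single state and assumes the target policy is defined over the full action space (the real method also receives child_sampled_actions_batch, which this sketch ignores) and that mask_batch is 0 for padded unroll steps. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over a 1-D list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def policy_loss_disc(policy_logits: Sequence[float], target_policy: Sequence[float],
                     mask: float = 1.0) -> float:
    """Cross-entropy -sum_a target(a) * log pi(a) for a single state.

    `mask` is assumed to be 0 for padded unroll steps and 1 otherwise.
    """
    log_probs = log_softmax(policy_logits)
    return -sum(t * lp for t, lp in zip(target_policy, log_probs)) * mask


def policy_entropy(policy_logits: Sequence[float]) -> float:
    """Entropy -sum_a pi(a) * log pi(a) of the predicted policy."""
    return -sum(math.exp(lp) * lp for lp in log_softmax(policy_logits))
```

With uniform logits over four actions, both the cross-entropy against any target and the policy entropy equal log 4.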

_create_model(cfg: EasyDict, model: Module | None = None) Module
Overview:

Create or validate the neural network model according to the input configuration and model. If the input model is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be verified as an instance of torch.nn.Module and set to the model instance created by outside caller.

Parameters:
  • cfg (-) – The final merged config used to initialize policy.

  • model (-) – The neural network model used to initialize policy. Users can refer to the default model defined in the corresponding policy when customizing their own model.

Returns:

The created neural network model. The different modes of the policy add distinct wrappers and plugins to this model for training, collection and evaluation.

Return type:

  • model (torch.nn.Module)

Raises:

- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.

_forward_collect(data: Tensor, action_mask: list = None, temperature: ndarray = 1, to_play=-1, epsilon: float = 0.25, ready_env_id: array = None)[source]
Overview:

The forward function for collecting data in collect mode. It uses the model to run MCTS and chooses the action by sampling from the MCTS visit-count distribution.

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask, marking which actions are legal (can be selected).

  • temperature (-) – The temperature of the policy.

  • to_play (-) – The player to play.

  • ready_env_id (-) – The ids of the envs that are ready for data collection.

Shape:
  • data (torch.Tensor):
    • For Atari, \((N, C*S, H, W)\), where N is the number of collect_env, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For LunarLander, \((N, O)\), where N is the number of collect_env, O is the observation space size.

  • action_mask: \((N, action_space_size)\), where N is the number of collect_env.

  • temperature: \((1, )\).

  • to_play: \((N, 1)\), where N is the number of collect_env.

  • ready_env_id: None

Returns:

A dict keyed by env id. Each entry is a dict whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.

Return type:

  • output (Dict[int, Any])
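Sampling during collect mode is commonly done by raising the root visit counts to the power 1 / temperature and normalizing, so a high temperature flattens the distribution and a low one approaches argmax. The sketch below assumes that rule and that action_mask uses 1 for legal actions. The function name is hypothetical, and the code is an illustration rather than LightZero's own implementation.

```python
import random
from typing import List, Optional, Sequence, Tuple


def sample_action_from_visits(visit_counts: Sequence[int], action_mask: Sequence[int],
                              temperature: float = 1.0,
                              rng: Optional[random.Random] = None) -> Tuple[int, List[float]]:
    """Hypothetical sketch: sample an action with probability proportional to N(a) ** (1 / T).

    Assumes action_mask[a] == 1 marks a legal action. Returns the sampled action
    (index in the full action space) and the distribution over legal actions.
    """
    legal = [a for a, m in enumerate(action_mask) if m == 1]
    counts = [visit_counts[a] for a in legal]
    max_count = max(counts)
    if max_count == 0:
        raise ValueError("at least one legal action must have a positive visit count")
    # Dividing by the max count first keeps small temperatures from overflowing.
    weights = [(c / max_count) ** (1.0 / temperature) for c in counts]
    total = sum(weights)
    probs = [w / total for w in weights]
    rng = rng if rng is not None else random.Random()
    index = rng.choices(range(len(legal)), weights=probs, k=1)[0]
    return legal[index], probs


action, probs = sample_action_from_visits([10, 0, 30, 60], [1, 1, 0, 1],
                                          temperature=1.0, rng=random.Random(0))
```

Action 2 is illegal, so it is never sampled despite its 30 visits; lowering the temperature toward 0 concentrates almost all probability on action 3.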

_forward_eval(data: Tensor, action_mask: list, to_play: -1, ready_env_id: array = None)[source]
Overview:

The forward function for evaluating the current policy in eval mode. It uses the model to run MCTS and, instead of sampling, deterministically chooses the action with the highest MCTS visit count (argmax).

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask, marking which actions are legal (can be selected).

  • to_play (-) – The player to play.

  • ready_env_id (-) – The ids of the envs that are ready for evaluation.

Shape:
  • data (torch.Tensor):
    • For Atari, \((N, C*S, H, W)\), where N is the number of evaluator_env, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For LunarLander, \((N, O)\), where N is the number of evaluator_env, O is the observation space size.

  • action_mask: \((N, action_space_size)\), where N is the number of evaluator_env.

  • to_play: \((N, 1)\), where N is the number of evaluator_env.

  • ready_env_id: None

Returns:

A dict keyed by env id. Each entry is a dict whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.

Return type:

  • output (Dict[int, Any])
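In eval mode the sampling step is replaced by an argmax over the root visit counts of the legal actions. A minimal sketch under the same assumption that action_mask uses 1 for legal actions; the function name and the lowest-index tie-breaking are choices of this sketch, not documented behavior.

```python
from typing import Sequence


def argmax_action_from_visits(visit_counts: Sequence[int], action_mask: Sequence[int]) -> int:
    """Hypothetical sketch: return the most-visited legal action.

    Assumes action_mask[a] == 1 marks a legal action. max() returns the first
    maximum it meets, so ties go to the lowest legal index here.
    """
    legal = [a for a, m in enumerate(action_mask) if m == 1]
    return max(legal, key=lambda a: visit_counts[a])


best = argmax_action_from_visits([10, 0, 90, 60], [1, 1, 0, 1])
```

Action 2 has the most visits but is masked out, so action 3 is returned.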

_forward_learn(data: Tensor) Dict[str, float | int][source]
Overview:

The forward function for learning the policy in learn mode, which is the core of the learning process. The data is sampled from the replay buffer, the loss is computed by the loss function, and the loss is backpropagated to update the model.

Parameters:

data (-) – The data sampled from the replay buffer, a tuple of tensors: the first element is the current_batch and the second is the target_batch.

Returns:

The information dict to be logged, which contains current learning loss and learning statistics.

Return type:

  • info_dict (Dict[str, Union[float, int]])

_get_attribute(name: str) Any
Overview:

To control access to the policy attributes, we expose the different modes to the outside rather than the policy instance itself, and we provide a method to get an attribute of the policy in the different modes.

Parameters:

name (-) – The name of the attribute.

Returns:

The value of the attribute.

Return type:

  • value (Any)

Note

DI-engine’s policy first tries to access the _get_{name} method and then the _{name} attribute. If neither is found, it raises a NotImplementedError.
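The lookup order described in this note can be mimicked with hasattr and getattr. This standalone sketch uses a hypothetical class, not DI-engine's Policy; only the order (_get_{name} first, then _{name}, otherwise NotImplementedError) comes from the text.

```python
from typing import Any


class AttributeLookupSketch:
    """Hypothetical class that mimics the documented lookup order."""

    def __init__(self) -> None:
        self._batch_size = 256  # plain attribute, reached through the `_{name}` fallback

    def _get_n_sample(self) -> int:  # getter, reached through `_get_{name}`
        return 8

    def _get_attribute(self, name: str) -> Any:
        if hasattr(self, '_get_' + name):
            return getattr(self, '_get_' + name)()
        if hasattr(self, '_' + name):
            return getattr(self, '_' + name)
        raise NotImplementedError(name)


policy = AttributeLookupSketch()
```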

_get_batch_size() int | Dict[str, int]
_get_n_episode() int | None
_get_n_sample() int | None
_get_target_obs_index_in_step_k(step)
Overview:

Get the begin index and end index of the target obs in step k.

Parameters:

step (-) – The current step k.

Returns:

  • beg_index (int): The begin index of the target obs in step k.

  • end_index (int): The end index of the target obs in step k.

Return type:

  • Tuple[int, int]

Examples

>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
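The example output (0, 12) is consistent with the slice for step k starting at image_channel * k and spanning frame_stack_num frames. The standalone sketch below encodes that arithmetic; treating the 'mlp' case the same way, with observation_shape as the per-frame size, is an assumption, and the function name is hypothetical.

```python
from typing import Tuple


def target_obs_index_in_step_k(step: int, model_type: str, frame_stack_num: int,
                               image_channel: int = 1,
                               observation_shape: int = 1) -> Tuple[int, int]:
    """Hypothetical standalone version of the index arithmetic.

    Frames are assumed to be stacked consecutively along the channel axis (conv)
    or feature axis (mlp); the slice for unroll step k starts k frames in and
    spans `frame_stack_num` frames.
    """
    if model_type == 'conv':
        frame_size = image_channel
    elif model_type == 'mlp':
        frame_size = observation_shape
    else:
        raise ValueError(f"unsupported model_type: {model_type}")
    beg_index = frame_size * step
    end_index = frame_size * (step + frame_stack_num)
    return beg_index, end_index
```

With image_channel = 3 and frame_stack_num = 4, step 0 gives (0, 12) as in the example, and step 1 shifts the window by one frame to (3, 15).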
_get_train_sample(data)[source]
Overview:

For a given trajectory (transitions, a list of transitions), process the data into a list of samples that can be used for training directly. A train sample can be a processed transition (DQN with n-step TD) or several multi-timestep transitions (DRQN). This method is usually called in collectors to perform the necessary RL data preprocessing before training, which helps the learner amortize the relevant time cost. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.

Parameters:

transitions (-) – The trajectory data (a list of transition), each element is the same format as the return value of self._process_transition method.

Returns:

The processed train samples. Each element has a format similar to the input transitions but may contain extra training data, such as n-step rewards, advantages, etc.

Return type:

  • samples (List[Dict[str, Any]])

Note

We will vectorize the process_transition and get_train_sample methods in a following release. Users can customize this data processing procedure by overriding these two methods and the collector itself.

_init_collect() None[source]
Overview:

Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.

_init_eval() None[source]
Overview:

Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.

_init_learn() None[source]
Overview:

Learn mode init method. Called by self.__init__. Initialize the learn model, optimizer and MCTS utils.

_init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
Overview:

Initialize the multi-GPU data-parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function that all-reduces the gradients of the model parameters.

Parameters:
  • model (-) – The neural network model to be trained.

  • bp_update_sync (-) – Whether to update the model parameters synchronously after all-reducing their gradients. Asynchronous updates can proceed in parallel across network layers, like a pipeline, which saves time.

_load_state_dict_collect(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy collect mode, for example when loading a pretrained state_dict, auto-recovering from a checkpoint, or receiving a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy collect state saved before.

Tip

To load only some parts of the model, set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_eval(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy eval mode, for example when auto-recovering from a checkpoint or receiving a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The dict of policy eval state saved before.

Tip

To load only some parts of the model, set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_learn(state_dict: Dict[str, Any]) None[source]
Overview:

Load the state_dict variable into policy learn mode.

Parameters:

state_dict (-) – The dict of policy learn state saved before.

_monitor_vars_learn() List[str][source]
Overview:

Register the variables to be monitored in learn mode. The registered variables are logged to TensorBoard using the values returned by _forward_learn.

_process_transition(obs, policy_output, timestep)[source]
Overview:

Process and pack the transition data of one timestep into a dict, such as <s, a, r, s’, done>. Some policies need special processing and must pack their own attributes (e.g. hidden state and logit), so this method is left to be implemented by the subclass.

Parameters:
  • obs (-) – The observation of the current timestep.

  • policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.

  • timestep (-) – The execution result namedtuple returned by the environment step method, except that all elements have been converted to tensors. It usually contains the next obs, reward, done, info, etc.

Returns:

The processed transition data of the current timestep.

Return type:

  • transition (Dict[str, torch.Tensor])

_reset_collect(data_id: List[int] | None = None) None
Overview:

Reset the observation and action for the collector environment.

Parameters:

data_id (-) – List of data ids to reset (not used in this implementation).

_reset_eval(data_id: List[int] | None = None) None
Overview:

Reset the observation and action for the evaluator environment.

Parameters:

data_id (-) – List of data ids to reset (not used in this implementation).

_reset_learn(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset; for example, different trajectories in data_id have different RNN hidden states.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

This method is optional. Subclasses can override it if necessary.

_set_attribute(name: str, value: Any) None
Overview:

To control access to the policy attributes, we expose the different modes to the outside rather than the policy instance itself, and we provide a method to set an attribute of the policy in the different modes. The new attribute is stored as _{name}.

Parameters:
  • name (-) – The name of the attribute.

  • value (-) – The value of the attribute.

_state_dict_collect() Dict[str, Any]
Overview:

Return the state_dict of collect mode, which usually includes only the model. It is needed in distributed training scenarios to auto-recover collectors.

Returns:

The dict of current policy collect state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not every scenario needs to auto-recover collectors; sometimes the crashed collector can simply be shut down and replaced with a new one.

_state_dict_eval() Dict[str, Any]
Overview:

Return the state_dict of eval mode, which usually includes only the model. It is needed in distributed training scenarios to auto-recover evaluators.

Returns:

The dict of current policy eval state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not every scenario needs to auto-recover evaluators; sometimes the crashed evaluator can simply be shut down and replaced with a new one.

_state_dict_learn() Dict[str, Any][source]
Overview:

Return the state_dict of learn mode, usually including model and optimizer.

Returns:

The dict of current policy learn state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

property cfg: EasyDict
class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new collect_function object from a sequence or iterable

_replace(**kwds)

Return a new collect_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

get_train_sample

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 7

process_transition

Alias for field number 1

reset

Alias for field number 3

set_attribute

Alias for field number 5

state_dict

Alias for field number 6
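Because collect_function is a plain namedtuple (Bases: tuple), each "Alias for field number i" entry just means that attribute access and indexing return the same object, and that fields cannot be reassigned in place. A standalone sketch with collections.namedtuple; the placeholder callables are hypothetical stand-ins for the policy's internal methods. eval_function and learn_function behave the same way with their own field lists.

```python
from collections import namedtuple

# Same field order as the collect_function listing above.
collect_function = namedtuple(
    'collect_function',
    ['forward', 'process_transition', 'get_train_sample', 'reset',
     'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict'],
)

# Hypothetical placeholder callables; each one just returns its own field name.
interfaces = collect_function(*[(lambda name=field: name) for field in collect_function._fields])

print(interfaces.reset())  # attribute access
print(interfaces[3]())     # the same callable: `reset` is field number 3
```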

property collect_mode: collect_function
Overview:

Return the interfaces of the collect mode of the policy, which is used to collect data for training. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.

Returns:

The interfaces of the collect mode of the policy: a namedtuple whose fields are the corresponding internal methods.

Return type:

  • interfaces (Policy.collect_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
config = {'action_type': 'fixed_action_space', 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'collector_env_num': 8, 'cos_lr_scheduler': False, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_offline': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 200, 'grad_clip_value': 10, 'gray_scale': False, 'gumbel_algo': False, 'ignore_done': False, 'init_w': 0.003, 'learning_rate': 0.0001, 'lr_piecewise_constant_decay': False, 'lstm_horizon_len': 5, 'manual_temperature_decay': False, 'mcts_ctree': True, 'model': {'action_space_size': 6, 'bias': True, 'categorical_distribution': True, 'continuous_action_space': False, 'discrete_action_encoding_type': 'one_hot', 'fixed_sigma_value': 0.3, 'frame_stack_num': 1, 'image_channel': 1, 'lstm_hidden_size': 512, 'model_type': 'conv', 'norm_type': 'LN', 'num_res_blocks': 1, 'observation_shape': (4, 96, 96), 'res_connection_in_dynamics': True, 'self_supervised_learning_loss': True, 'sigma_type': 'conditioned', 'support_scale': 300}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'multi_gpu': False, 'n_episode': 8, 'normalize_prob_of_sampled_actions': False, 'num_simulations': 50, 'num_unroll_steps': 5, 'optim_type': 'AdamW', 'policy_entropy_loss_weight': 0.005, 'policy_loss_type': 'cross_entropy', 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'reanalyze_noise': True, 'replay_ratio': 0.25, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sampled_algo': True, 'ssl_loss_weight': 2, 'target_update_freq': 100, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 100000, 'transform2string': False, 'update_per_collect': None, 'use_augmentation': False, 
'use_priority': True, 'use_ture_chance_label_in_chance_encoder': False, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
classmethod default_config() EasyDict
Overview:

Get the default config of the policy.

Returns:

The default config of the corresponding policy. For a derived policy class, the default config of the base class is recursively merged with its own default config.

Return type:

  • cfg (EasyDict)

Tip

This method deep-copies the config attribute of the class and returns the copy, so modifying the returned config does not affect the class attribute.
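The deep-copy behavior in this tip can be illustrated with a hypothetical class that keeps its defaults in a class attribute. This sketch uses a plain dict instead of EasyDict and omits the recursive merge with base-class defaults.

```python
import copy


class PolicySketch:
    """Hypothetical policy whose defaults live in a class attribute."""
    config = {'batch_size': 256, 'model': {'action_space_size': 6}}

    @classmethod
    def default_config(cls) -> dict:
        # Deep copy, so nested dicts in the result are independent of cls.config.
        return copy.deepcopy(cls.config)


cfg = PolicySketch.default_config()
cfg['model']['action_space_size'] = 18  # only the copy changes
```

A shallow copy would not be enough here: the nested 'model' dict would still be shared with the class attribute.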

default_model() Tuple[str, List[str]][source]
Overview:

Return the default model setting of this algorithm.

Returns:

The model name and model import_names:
  • model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.

  • import_names (List[str]): The model class path list used in this algorithm.

Return type:

  • model_info (Tuple[str, List[str]])

Note

Users can define and use a customized network model, but it must follow the same interface as the model indicated by the import_names path. For Sampled EfficientZero, this is lzero.model.sampled_efficientzero_model.SampledEfficientZeroModel.

class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new eval_function object from a sequence or iterable

_replace(**kwds)

Return a new eval_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 5

reset

Alias for field number 1

set_attribute

Alias for field number 3

state_dict

Alias for field number 4

property eval_mode: eval_function
Overview:

Return the interfaces of the eval mode of the policy, which is used to evaluate the policy. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.

Returns:

The interfaces of the eval mode of the policy: a namedtuple whose fields are the corresponding internal methods.

Return type:

  • interfaces (Policy.eval_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new learn_function object from a sequence or iterable

_replace(**kwds)

Return a new learn_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

info

Alias for field number 2

load_state_dict

Alias for field number 7

monitor_vars

Alias for field number 3

reset

Alias for field number 1

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property learn_mode: learn_function
Overview:

Return the interfaces of the learn mode of the policy, which is used to train the model. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own learn mode.

Returns:

The interfaces of the learn mode of the policy: a namedtuple whose fields are the corresponding internal methods.

Return type:

  • interfaces (Policy.learn_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
sync_gradients(model: Module) None
Overview:

Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.

Parameters:

model (-) – The model to synchronize gradients.

Note

This method is only used in multi-GPU training and should be called after the backward method and before the optimizer step method. Users can also set the bp_update_sync config to control whether the gradient all-reduce and the optimizer update are synchronized.
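The ordering in this note (backward, then sync_gradients, then the optimizer step) ensures that every rank applies the same gradient. The arithmetic of an averaging all-reduce can be simulated in pure Python; whether DI-engine averages or only sums at this point is not stated here, so the division by world size is an assumption of the sketch, and the function name is hypothetical. In real PyTorch code each parameter's gradient would typically go through torch.distributed.all_reduce.

```python
from typing import List


def simulated_allreduce_mean(per_worker_grads: List[List[float]]) -> List[List[float]]:
    """Average gradient vectors element-wise across workers (hypothetical sketch).

    per_worker_grads[r] is the flattened gradient computed by worker r after
    backward(); every worker receives the same averaged vector back, so the
    subsequent optimizer step is identical on all workers.
    """
    world_size = len(per_worker_grads)
    summed = [sum(values) for values in zip(*per_worker_grads)]
    averaged = [s / world_size for s in summed]
    return [list(averaged) for _ in range(world_size)]


synced = simulated_allreduce_mean([[1.0, 2.0], [3.0, 6.0]])
```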

total_field = {'collect', 'eval', 'learn'}

Stochastic MuZeroPolicy

class lzero.policy.stochastic_muzero.StochasticMuZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]

Bases: MuZeroPolicy

Overview:

The policy class for Stochastic MuZero proposed in the paper https://openreview.net/pdf?id=X6D9bAHhBQ1.

__init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
Overview:

Initialize the policy instance according to the input configuration and model. This method initializes the different fields of the policy: learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. The enable_field argument specifies which fields to initialize; if it is None, all fields are initialized.

Parameters:
  • cfg (-) – The final merged config used to initialize policy. For the default config, see the config attribute and its comments of policy class.

  • model (-) – The neural network model used to initialize policy. If it is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be set to the model instance created by outside caller.

  • enable_field (-) – The field list to initialize. If it is None, then all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which is beneficial to save resources.

Note

A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.

_abc_impl = <_abc._abc_data object>
_create_model(cfg: EasyDict, model: Module | None = None) Module
Overview:

Create or validate the neural network model according to the input configuration and model. If the input model is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be verified as an instance of torch.nn.Module and set to the model instance created by outside caller.

Parameters:
  • cfg (-) – The final merged config used to initialize policy.

  • model (-) – The neural network model used to initialize policy. Users can refer to the default model defined in the corresponding policy when customizing their own model.

Returns:

The created neural network model. The different modes of the policy add distinct wrappers and plugins to this model for training, collection and evaluation.

Return type:

  • model (torch.nn.Module)

Raises:

- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.

_forward_collect(data: Tensor, action_mask: list = None, temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, ready_env_id: array = None) Dict[source]
Overview:

The forward function for collecting data in collect mode. It uses the model to run MCTS and chooses the action by sampling from the MCTS visit-count distribution.

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask, marking which actions are legal (can be selected).

  • temperature (-) – The temperature of the policy.

  • to_play (-) – The player to play.

  • ready_env_id (-) – The ids of the envs that are ready for data collection.

Shape:
  • data (torch.Tensor):
    • For Atari, its shape is \((N, C*S, H, W)\), where N is the number of collect_env, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For LunarLander, its shape is \((N, O)\), where N is the number of collect_env, O is the observation space size.

  • action_mask: \((N, action_space_size)\), where N is the number of collect_env.

  • temperature: \((1, )\).

  • to_play: \((N, 1)\), where N is the number of collect_env.

  • ready_env_id: None

Returns:

A dict keyed by env id. Each entry is a dict whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.

Return type:

  • output (Dict[int, Any])

_forward_eval(data: Tensor, action_mask: list, to_play: int = -1, ready_env_id: array = None) Dict[source]
Overview:

The forward function for evaluating the current policy in eval mode. It uses the model to run MCTS and, instead of sampling, deterministically chooses the action with the highest MCTS visit count (argmax).

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask, marking which actions are legal (can be selected).

  • to_play (-) – The player to play.

  • ready_env_id (-) – The ids of the envs that are ready for evaluation.

Shape:
  • data (torch.Tensor):
    • For Atari, \((N, C*S, H, W)\), where N is the number of evaluator_env, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For LunarLander, \((N, O)\), where N is the number of evaluator_env, O is the observation space size.

  • action_mask: \((N, action_space_size)\), where N is the number of evaluator_env.

  • to_play: \((N, 1)\), where N is the number of evaluator_env.

  • ready_env_id: None

Returns:

A dict keyed by env id. Each entry is a dict whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.

Return type:

  • output (Dict[int, Any])

_forward_learn(data: Tuple[Tensor]) Dict[str, float | int][source]
Overview:

The forward function for learning the policy in learn mode, which is the core of the learning process. The data is sampled from the replay buffer, the loss is computed by the loss function, and the loss is backpropagated to update the model.

Parameters:

data (-) – The data sampled from the replay buffer, a tuple of tensors: the first element is the current_batch and the second is the target_batch.

Returns:

The information dict to be logged, which contains current learning loss and learning statistics.

Return type:

  • info_dict (Dict[str, Union[float, int]])

_get_attribute(name: str) Any
Overview:

To control access to the policy attributes, we expose the different modes to the outside rather than the policy instance itself, and we provide a method to get an attribute of the policy in the different modes.

Parameters:

name (-) – The name of the attribute.

Returns:

The value of the attribute.

Return type:

  • value (Any)

Note

DI-engine’s policy first tries to access the _get_{name} method and then the _{name} attribute. If neither is found, it raises a NotImplementedError.

_get_batch_size() int | Dict[str, int]
_get_n_episode() int | None
_get_n_sample() int | None
_get_target_obs_index_in_step_k(step)
Overview:

Get the begin index and end index of the target obs in step k.

Parameters:

step (-) – The current step k.

Returns:

  • beg_index (int): The begin index of the target obs in step k.

  • end_index (int): The end index of the target obs in step k.

Return type:

  • Tuple[int, int]

Examples

>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
_get_train_sample(data)[source]
Overview:

For a given trajectory (transitions, a list of transitions), process the data into a list of samples that can be used for training directly. A train sample can be a processed transition (DQN with n-step TD) or several multi-timestep transitions (DRQN). This method is usually called in collectors to perform the necessary RL data preprocessing before training, which helps the learner amortize the relevant time cost. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.

Parameters:

transitions (-) – The trajectory data (a list of transition), each element is the same format as the return value of self._process_transition method.

Returns:

The processed train samples. Each element has a format similar to the input transitions but may contain extra training data, such as n-step rewards, advantages, etc.

Return type:

  • samples (List[Dict[str, Any]])

Note

We will vectorize the process_transition and get_train_sample methods in a following release. Users can customize this data processing procedure by overriding these two methods and the collector itself.

_init_collect() None[source]
Overview:

Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.

_init_eval() None[source]
Overview:

Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.

_init_learn() None[source]
Overview:

Learn mode init method. Called by self.__init__. Initialize the learn model, optimizer and MCTS utils.

_init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
Overview:

Initialize the multi-GPU data-parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function that all-reduces the gradients of the model parameters.

Parameters:
  • model (-) – The neural network model to be trained.

  • bp_update_sync (-) – Whether to update the model parameters synchronously after all-reducing their gradients. Asynchronous updates can proceed in parallel across network layers, like a pipeline, which saves time.

_load_state_dict_collect(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy collect mode, for example when loading a pretrained state_dict, auto-recovering from a checkpoint, or receiving a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The previously saved dict of policy collect state.

Tip

If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_eval(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy eval mode, for example when loading an auto-recover checkpoint or a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The previously saved dict of policy eval state.

Tip

If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_learn(state_dict: Dict[str, Any]) None[source]
Overview:

Load the state_dict variable into policy learn mode.

Parameters:

state_dict (-) – The previously saved dict of policy learn state.

_monitor_vars_learn() List[str][source]
Overview:

Register the variables to be monitored in learn mode. The registered variables will be logged to TensorBoard according to the return value of _forward_learn.

_process_transition(obs, policy_output, timestep)[source]
Overview:

Process and pack the transition data of one timestep into a dict, such as <s, a, r, s’, done>. Some policies need special processing and must pack their own necessary attributes (e.g. hidden state and logit), so this method is left to be implemented by the subclass.

Parameters:
  • obs (-) – The observation of the current timestep.

  • policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.

  • timestep (-) – The execution result namedtuple returned by the environment step method, except that all elements have been converted to tensor data. Usually, it contains the next obs, reward, done, info, etc.

Returns:

The processed transition data of the current timestep.

Return type:

  • transition (Dict[str, torch.Tensor])
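A minimal sketch of the packing step described above, assuming a hypothetical `Timestep` namedtuple with obs, reward, done and info fields (plain Python values instead of tensors). The 'logit' entry stands in for the policy-specific extras a subclass might add.

```python
from collections import namedtuple
from typing import Any, Dict

# Hypothetical stand-in for the env step result; real timesteps carry tensors.
Timestep = namedtuple('Timestep', ['obs', 'reward', 'done', 'info'])


def process_transition(obs: Any, policy_output: Dict[str, Any], timestep: Timestep) -> Dict[str, Any]:
    """Pack one timestep into an <s, a, r, s', done> dict (illustrative only)."""
    return {
        'obs': obs,                           # s
        'action': policy_output['action'],    # a
        'logit': policy_output.get('logit'),  # policy-specific extra field, may be None
        'reward': timestep.reward,            # r
        'next_obs': timestep.obs,             # s'
        'done': timestep.done,
    }
```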

_reset_collect(data_id: List[int] | None = None) None
Overview:

Reset the observation and action for the collector environment.

Parameters:

data_id (-) – List of data ids to reset (not used in this implementation).

_reset_eval(data_id: List[int] | None = None) None
Overview:

Reset the observation and action for the evaluator environment.

Parameters:

data_id (-) – List of data ids to reset (not used in this implementation).

_reset_learn(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables corresponding to data_id are reset; for example, different trajectories in data_id will have different RNN hidden states.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

This method is not mandatory to implement. Subclasses can override it if necessary.
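The data_id semantics above can be illustrated with a small, hypothetical holder of per-trajectory hidden states (a float stands in for an RNN hidden tensor). Passing None resets everything; passing a list resets only those trajectories. The class and its methods are illustrative, not part of DI-engine.

```python
from typing import Dict, List, Optional


class HiddenStateStore:
    """Hypothetical per-trajectory hidden-state holder (floats stand in for tensors)."""

    def __init__(self, init_value: float = 0.0):
        self._init_value = init_value
        self._hidden: Dict[int, float] = {}

    def get(self, data_id: int) -> float:
        return self._hidden.get(data_id, self._init_value)

    def set(self, data_id: int, value: float) -> None:
        self._hidden[data_id] = value

    def reset(self, data_id: Optional[List[int]] = None) -> None:
        if data_id is None:
            self._hidden.clear()  # reset every stateful variable
        else:
            for i in data_id:
                self._hidden.pop(i, None)  # reset only the listed trajectories
```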

_set_attribute(name: str, value: Any) None
Overview:

To control access to the policy attributes, we expose different modes to the outside rather than using the policy instance directly. We also provide this method to set an attribute of the policy in different modes. The new attribute will be named _{name}.

Parameters:
  • name (-) – The name of the attribute.

  • value (-) – The value of the attribute.

_state_dict_collect() Dict[str, Any]
Overview:

Return the state_dict of collect mode, which usually includes only the model. It is necessary in distributed training scenarios to auto-recover collectors.

Returns:

The dict of current policy collect state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover collectors; sometimes we can simply shut down the crashed collector and start a new one.

_state_dict_eval() Dict[str, Any]
Overview:

Return the state_dict of eval mode, which usually includes only the model. It is necessary in distributed training scenarios to auto-recover evaluators.

Returns:

The dict of current policy eval state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover evaluators; sometimes we can simply shut down the crashed evaluator and start a new one.

_state_dict_learn() Dict[str, Any][source]
Overview:

Return the state_dict of learn mode, usually including model, target_model and optimizer.

Returns:

The dict of current policy learn state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

property cfg: EasyDict
class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new collect_function object from a sequence or iterable

_replace(**kwds)

Return a new collect_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

get_train_sample

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 7

process_transition

Alias for field number 1

reset

Alias for field number 3

set_attribute

Alias for field number 5

state_dict

Alias for field number 6
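A minimal sketch of how a namedtuple like collect_function can expose bound methods as an immutable interface. `ToyPolicy` and its method bodies are hypothetical; only the field names and their order are taken from the listing above.

```python
from collections import namedtuple

# Field order as documented above; each field is also reachable by index.
collect_function = namedtuple('collect_function', [
    'forward', 'process_transition', 'get_train_sample', 'reset',
    'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict',
])


class ToyPolicy:
    """Hypothetical policy exposing only a collect mode."""

    def __init__(self):
        self._state = {'model': 'weights-v0'}

    def _forward_collect(self, obs):
        return {'action': 0}

    def _process_transition(self, obs, policy_output, timestep):
        return {'obs': obs, 'action': policy_output['action']}

    def _get_train_sample(self, transitions):
        return transitions  # identity implementation

    def _reset_collect(self, data_id=None):
        pass

    def _get_attribute(self, name):
        return getattr(self, '_' + name)

    def _set_attribute(self, name, value):
        setattr(self, '_' + name, value)

    def _state_dict_collect(self):
        return dict(self._state)

    def _load_state_dict_collect(self, state_dict):
        self._state = dict(state_dict)

    @property
    def collect_mode(self) -> collect_function:
        # Bundle bound methods; callers cannot reassign fields of the tuple.
        return collect_function(
            self._forward_collect, self._process_transition, self._get_train_sample,
            self._reset_collect, self._get_attribute, self._set_attribute,
            self._state_dict_collect, self._load_state_dict_collect,
        )
```

Because namedtuple fields are read-only, code holding `policy.collect_mode` can call the interface but cannot swap out its methods, which is the restriction described in the collect_mode overview.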

property collect_mode: collect_function
Overview:

Return the interfaces of collect mode of policy, which is used to collect training data. Here we use namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.

Returns:

The interfaces of collect mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.collect_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
config = {'action_type': 'fixed_action_space', 'afterstate_policy_loss_weight': 1, 'afterstate_value_loss_weight': 0.25, 'analyze_chance_distribution': False, 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'collector_env_num': 8, 'commitment_loss_weight': 1.0, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_offline': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 200, 'grad_clip_value': 10, 'gray_scale': False, 'gumbel_algo': False, 'ignore_done': False, 'learning_rate': 0, 'lr_piecewise_constant_decay': False, 'manual_temperature_decay': False, 'mcts_ctree': True, 'model': {'bias': True, 'categorical_distribution': True, 'chance_space_size': 2, 'continuous_action_space': False, 'frame_stack_num': 1, 'image_channel': 1, 'model_type': 'conv', 'num_channels': 64, 'num_res_blocks': 1, 'observation_shape': (4, 96, 96), 'self_supervised_learning_loss': False, 'support_scale': 300}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'n_episode': 8, 'num_simulations': 50, 'num_unroll_steps': 5, 'optim_type': 'Adam', 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'reanalyze_noise': True, 'replay_ratio': 0.25, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sampled_algo': False, 'ssl_loss_weight': 0, 'target_update_freq': 100, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 100000, 'transform2string': False, 'update_per_collect': 100, 'use_augmentation': False, 'use_max_priority_for_new_data': True, 'use_priority': True, 'use_ture_chance_label_in_chance_encoder': False, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
classmethod default_config() EasyDict
Overview:

Get the default config of policy. This method is used to create the default config of policy.

Returns:

The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.

Return type:

  • cfg (EasyDict)

Tip

This method deep-copies the config attribute of the class and returns the result, so users can modify the returned config without affecting the class default.
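A minimal sketch of the two behaviours described above: recursively merging a base config with a derived one, and deep-copying on return. Plain dicts stand in for EasyDict, and `deep_merge`, `BasePolicy` and `DerivedPolicy` are hypothetical names; DI-engine's own merge utility may differ in details.

```python
import copy


def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a copy of `base` (inputs are not mutated)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


class BasePolicy:
    config = {'cuda': True, 'model': {'num_channels': 64}}

    @classmethod
    def default_config(cls) -> dict:
        # Return a deep copy so callers can edit it freely.
        return copy.deepcopy(cls.config)


class DerivedPolicy(BasePolicy):
    # The derived default config extends the base one instead of replacing it.
    config = deep_merge(BasePolicy.config, {'model': {'num_res_blocks': 1}})
```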

default_model() Tuple[str, List[str]][source]
Overview:

Return the default model setting of this algorithm.

Returns:

model name and model import_names.
  • model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.

  • import_names (List[str]): The model class path list used in this algorithm.

Return type:

  • model_info (Tuple[str, List[str]])

Note

Users can define and use a customized network model, but it must follow the same interface as the model indicated by the import_names path. For MuZero, this is lzero.model.muzero_model.MuZeroModel.
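To make the (model_type, import_names) pair concrete, here is a hypothetical resolver that imports each listed module and looks the class up by name. This is an assumption for illustration: DI-engine resolves model_type through its model registry, which is populated when the modules in import_names are imported, rather than by plain attribute lookup.

```python
import importlib
from typing import List, Tuple


def resolve_model_class(model_info: Tuple[str, List[str]]):
    """Hypothetical resolver: import each module and look up `model_type` by attribute name."""
    model_type, import_names = model_info
    for module_name in import_names:
        module = importlib.import_module(module_name)  # importing may also register models
        if hasattr(module, model_type):
            return getattr(module, model_type)
    raise KeyError(f'{model_type} not found in {import_names}')
```

For instance, `resolve_model_class(('OrderedDict', ['collections']))` returns `collections.OrderedDict`; a custom model would be listed the same way, provided it keeps the interface of the default one.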

class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new eval_function object from a sequence or iterable

_replace(**kwds)

Return a new eval_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 5

reset

Alias for field number 1

set_attribute

Alias for field number 3

state_dict

Alias for field number 4

property eval_mode: eval_function
Overview:

Return the interfaces of eval mode of policy, which is used to evaluate the policy. Here we use namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.

Returns:

The interfaces of eval mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.eval_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new learn_function object from a sequence or iterable

_replace(**kwds)

Return a new learn_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

info

Alias for field number 2

load_state_dict

Alias for field number 7

monitor_vars

Alias for field number 3

reset

Alias for field number 1

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property learn_mode: learn_function
Overview:

Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own learn mode.

Returns:

The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.learn_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
sync_gradients(model: Module) None
Overview:

Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.

Parameters:

model (-) – The model to synchronize gradients.

Note

This method is only used in multi-GPU training, and it should be called after the backward method and before the step method. Users can also use the bp_update_sync config to control whether the gradient allreduce and the optimizer update are synchronized.
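The arithmetic behind gradient synchronization can be simulated without GPUs: each worker contributes its gradients, an allreduce sums them, and every worker ends up with the same average. This pure-Python sketch only illustrates that arithmetic; in PyTorch code the summation step is typically done with torch.distributed.all_reduce, and the function name here is hypothetical.

```python
from typing import List


def allreduce_mean(per_worker_grads: List[List[float]]) -> List[List[float]]:
    """Simulate allreduce(sum) followed by division by world size."""
    world_size = len(per_worker_grads)
    num_params = len(per_worker_grads[0])
    # Sum the i-th gradient entry across all workers.
    summed = [sum(worker[i] for worker in per_worker_grads) for i in range(num_params)]
    averaged = [s / world_size for s in summed]
    # After the allreduce, every worker holds an identical copy of the average.
    return [list(averaged) for _ in range(world_size)]
```

Since every replica applies the same averaged gradient, model parameters stay identical across GPUs after each optimizer step.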

total_field = {'collect', 'eval', 'learn'}

UniZeroPolicy

class lzero.policy.unizero.UniZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]

Bases: MuZeroPolicy

Overview:

The policy class for UniZero, the official implementation of the paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models. UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations of MuZero-style algorithms, particularly in environments that require capturing long-term dependencies. More details can be found at https://arxiv.org/abs/2406.10667.

__init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
Overview:

Initialize the policy instance according to the input configuration and model. This method initializes the different fields of the policy, including learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. The enable_field argument specifies which fields to initialize; if it is None, all fields are initialized.

Parameters:
  • cfg (-) – The final merged config used to initialize policy. For the default config, see the config attribute and its comments of policy class.

  • model (-) – The neural network model used to initialize policy. If it is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be set to the model instance created by outside caller.

  • enable_field (-) – The field list to initialize. If it is None, then all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which is beneficial to save resources.

Note

Derived policy classes should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.

_abc_impl = <_abc._abc_data object>
_create_model(cfg: EasyDict, model: Module | None = None) Module
Overview:

Create or validate the neural network model according to the input configuration and model. If the input model is None, then the model will be created according to default_model method and cfg.model field. Otherwise, the model will be verified as an instance of torch.nn.Module and set to the model instance created by outside caller.

Parameters:
  • cfg (-) – The final merged config used to initialize policy.

  • model (-) – The neural network model used to initialize policy. Users can refer to the default model defined in the corresponding policy to customize their own model.

Returns:

The created neural network model. The different modes of policy will add distinct wrappers and plugins to the model, which is used to train, collect and evaluate.

Return type:

  • model (torch.nn.Module)

Raises:

- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.
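A minimal sketch of the create-or-validate logic described above. So that it runs without PyTorch, a hypothetical `Module` class stands in for torch.nn.Module and a zero-argument `factory` stands in for building the model from default_model and cfg.model.

```python
from typing import Callable, Optional


class Module:
    """Hypothetical stand-in for torch.nn.Module so the sketch runs without PyTorch."""


class TinyModel(Module):
    pass


def create_model(model: Optional[object], factory: Callable[[], Module]) -> Module:
    if model is None:
        # In the real policy, the model is built from default_model() and cfg.model.
        return factory()
    if not isinstance(model, Module):
        raise RuntimeError(f'model must be a Module instance, got {type(model).__name__}')
    return model  # a caller-supplied model is used as-is
```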

_forward_collect(data: Tensor, action_mask: list = None, temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, ready_env_id: array = None) Dict[source]
Overview:

The forward function for collecting data in collect mode. It uses the model to run MCTS search, and in collect mode the action is chosen by sampling.

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask, indicating which actions are legal; masked-out actions cannot be selected.

  • temperature (-) – The temperature used when sampling actions from the MCTS search results.

  • to_play (-) – The player to play.

  • ready_env_id (-) – The id of the env that is ready to collect.

Shape:
  • data (torch.Tensor):
    • For Atari, \((N, C*S, H, W)\), where N is the number of collect_env, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For lunarlander, \((N, O)\), where N is the number of collect_env, O is the observation space size.

  • action_mask: \((N, action_space_size)\), where N is the number of collect_env.

  • temperature: \((1, )\).

  • to_play: \((N, 1)\), where N is the number of collect_env.

  • ready_env_id: None

Returns:

Dict type data, with keys including action, distributions, visit_count_distribution_entropy, value, pred_value, policy_logits.

Return type:

  • output (Dict[int, Any])
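A minimal sketch of temperature-based action sampling as commonly used in MuZero-style collectors, where the probability of action a is proportional to N(a) ** (1 / temperature) over MCTS visit counts N. This is an assumption about the general technique; LightZero's exact selection code (including how epsilon is used) may differ, and `sample_action` is a hypothetical name.

```python
import random
from typing import List, Optional


def sample_action(
    visit_counts: List[int],
    action_mask: List[int],
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Sample an action with probability proportional to N(a) ** (1 / temperature)."""
    if temperature <= 0:
        raise ValueError('temperature must be positive')
    rng = rng or random.Random()
    # Masked-out (illegal) actions get zero weight and can never be sampled.
    weights = [
        float(n) ** (1.0 / temperature) if legal else 0.0
        for n, legal in zip(visit_counts, action_mask)
    ]
    if sum(weights) == 0:
        raise ValueError('no legal action with positive visit count')
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]
```

Lower temperatures sharpen the distribution toward the most-visited action, while temperature 1 samples in proportion to the raw visit counts.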

_forward_eval(data: Tensor, action_mask: list, to_play: int = -1, ready_env_id: array = None) Dict[source]
Overview:

The forward function for evaluating the current policy in eval mode. It uses the model to run MCTS search and, unlike collect mode, chooses the action greedily (argmax) rather than by sampling.

Parameters:
  • data (-) – The input data, i.e. the observation.

  • action_mask (-) – The action mask, indicating which actions are legal; masked-out actions cannot be selected.

  • to_play (-) – The player to play.

  • ready_env_id (-) – The id of the env that is ready to evaluate.

Shape:
  • data (torch.Tensor):
    • For Atari, \((N, C*S, H, W)\), where N is the number of evaluator environments, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.

    • For lunarlander, \((N, O)\), where N is the number of evaluator environments, O is the observation space size.

  • action_mask: \((N, action_space_size)\), where N is the number of evaluator environments.

  • to_play: \((N, 1)\), where N is the number of evaluator environments.

  • ready_env_id: None

Returns:

Dict type data, with keys including action, distributions, visit_count_distribution_entropy, value, pred_value, policy_logits.

Return type:

  • output (Dict[int, Any])
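For contrast with sampling in collect mode, a minimal sketch of greedy selection, assuming (as in MuZero-style methods) that the argmax is taken over MCTS visit counts restricted to legal actions. `greedy_action` is a hypothetical helper; ties go to the lowest action index because `max` returns the first maximum.

```python
from typing import List


def greedy_action(visit_counts: List[int], action_mask: List[int]) -> int:
    """Pick the legal action with the highest visit count (lowest index wins ties)."""
    legal = [a for a, m in enumerate(action_mask) if m]
    if not legal:
        raise ValueError('no legal action')
    return max(legal, key=lambda a: visit_counts[a])
```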

_forward_learn(data: Tuple[Tensor]) Dict[str, float | int][source]
Overview:

The forward function for learning the policy in learn mode, which is the core of the learning process. The data is sampled from the replay buffer; the loss is computed by the loss function and then backpropagated to update the model.

Parameters:

data (-) – The data sampled from the replay buffer, a tuple of tensors whose first element is current_batch and whose second element is target_batch.

Returns:

The information dict to be logged, which contains current learning loss and learning statistics.

Return type:

  • info_dict (Dict[str, Union[float, int]])
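A minimal sketch of how weighted loss terms could be combined into a total loss and flattened into the logged info dict. The weights are taken from the UniZero default config keys value_loss_weight, policy_loss_weight and reward_loss_weight; the real UniZero objective contains further terms (e.g. those weighted in world_model_cfg), and `combine_losses` is a hypothetical helper working on plain floats instead of tensors.

```python
from typing import Dict


def combine_losses(losses: Dict[str, float], weights: Dict[str, float]) -> Dict[str, float]:
    """Weighted sum of per-term losses plus a flat info dict for logging."""
    total = sum(weights[name] * value for name, value in losses.items())
    info = {f'{name}_loss': float(value) for name, value in losses.items()}
    info['total_loss'] = float(total)
    return info


# value_loss_weight, policy_loss_weight and reward_loss_weight from the default config.
weights = {'value': 0.25, 'policy': 1.0, 'reward': 1.0}
```

The keys of the returned dict are the kind of names that _monitor_vars_learn would register for logging.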

_get_attribute(name: str) Any
Overview:

To control access to the policy attributes, we expose different modes to the outside rather than using the policy instance directly. We also provide this method to get an attribute of the policy in different modes.

Parameters:

name (-) – The name of the attribute.

Returns:

The value of the attribute.

Return type:

  • value (Any)

Note

DI-engine’s policy first tries to call the _get_{name} method, and then tries to access the _{name} attribute. If neither is found, it raises a NotImplementedError.
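A minimal sketch of the lookup order described in the note, paired with the _{name} naming convention of _set_attribute. `AttributeAccessPolicy` and its attribute values are hypothetical; it mirrors the documented behaviour rather than reproducing DI-engine's source.

```python
from typing import Any


class AttributeAccessPolicy:
    """Hypothetical policy showing the _get_{name} then _{name} lookup order."""

    def __init__(self):
        self._batch_size = 256  # found via the _{name} fallback

    def _get_n_episode(self) -> int:
        return 8  # found via the _get_{name} method, which takes priority

    def _get_attribute(self, name: str) -> Any:
        getter = getattr(self, '_get_' + name, None)
        if callable(getter):
            return getter()
        if hasattr(self, '_' + name):
            return getattr(self, '_' + name)
        raise NotImplementedError(name)

    def _set_attribute(self, name: str, value: Any) -> None:
        setattr(self, '_' + name, value)  # stored as _{name}
```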

_get_batch_size() int | Dict[str, int]
_get_n_episode() int | None
_get_n_sample() int | None
_get_target_obs_index_in_step_k(step)
Overview:

Get the begin and end indices of the target obs in step k.

Parameters:

step (-) – The current step k.

Returns:

The begin and end indices of the target obs in step k.
  • beg_index (int): The begin index of the target obs in step k.

  • end_index (int): The end index of the target obs in step k.

Return type:

  • (beg_index, end_index) (Tuple[int, int])

Examples

>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
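The example output can be reproduced by slicing the stacked observation in channel-sized strides: for a conv model, step k starts at image_channel * k and spans frame_stack_num frames. The sketch below is written to match that example; the 'mlp' branch, which uses a flat observation size in place of the channel count, is an assumption, and the standalone function signature is hypothetical.

```python
from typing import Optional, Tuple


def get_target_obs_index_in_step_k(
    step: int,
    model_type: str,
    frame_stack_num: int,
    image_channel: Optional[int] = None,
    observation_shape: Optional[int] = None,
) -> Tuple[int, int]:
    if model_type == 'conv':
        # Each frame occupies `image_channel` channels; step k is shifted by k frames.
        beg_index = image_channel * step
        end_index = image_channel * (step + frame_stack_num)
    elif model_type == 'mlp':
        # Assumed analogue for vector observations of size `observation_shape`.
        beg_index = observation_shape * step
        end_index = observation_shape * (step + frame_stack_num)
    else:
        raise ValueError(f'unsupported model_type: {model_type}')
    return beg_index, end_index
```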
_get_train_sample(data)
Overview:

For a given trajectory (a list of transitions), process it into a list of samples that can be used directly for training. A train sample can be a single processed transition (e.g. DQN with n-step TD) or a group of multi-timestep transitions (e.g. DRQN). This method is usually called in collectors to perform the necessary RL data preprocessing before training, which helps the learner amortize the relevant time cost. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.

Parameters:

data (-) – The trajectory data (a list of transitions); each element has the same format as the return value of the self._process_transition method.

Returns:

The processed train samples. Each element has a format similar to the input transitions, but may contain additional training data such as the n-step reward, advantage, etc.

Return type:

  • samples (List[Dict[str, Any]])

Note

We plan to vectorize the process_transition and get_train_sample methods in a future release. Users can customize this data processing procedure by overriding these two methods and the collector itself.

_init_collect() None[source]
Overview:

Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.

_init_eval() None[source]
Overview:

Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.

_init_learn() None[source]
Overview:

Learn mode init method. Called by self.__init__. Initialize the learn model, optimizer and MCTS utils.

_init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
Overview:

Initialize the multi-GPU data-parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function that allreduces the gradients of the model parameters.

Parameters:
  • model (-) – The neural network model to be trained.

  • bp_update_sync (-) – Whether to update the model parameters synchronously after the gradients have been allreduced. Asynchronous updates can proceed in parallel across network layers, like a pipeline, which saves time.

_load_state_dict_collect(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy collect mode, for example when loading a pretrained state_dict, an auto-recover checkpoint, or a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The previously saved dict of policy collect state.

Tip

If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_eval(state_dict: Dict[str, Any]) None
Overview:

Load the state_dict variable into policy eval mode, for example when loading an auto-recover checkpoint or a model replica from the learner in distributed training scenarios.

Parameters:

state_dict (-) – The previously saved dict of policy eval state.

Tip

If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.

_load_state_dict_learn(state_dict: Dict[str, Any]) None[source]
Overview:

Load the state_dict variable into policy learn mode.

Parameters:

state_dict (-) – The previously saved dict of policy learn state.

_monitor_vars_learn() List[str][source]
Overview:

Register the variables to be monitored in learn mode. The registered variables will be logged to TensorBoard according to the return value of _forward_learn.

_process_transition(obs, policy_output, timestep)
Overview:

Process and pack the transition data of one timestep into a dict, such as <s, a, r, s’, done>. Some policies need special processing and must pack their own necessary attributes (e.g. hidden state and logit), so this method is left to be implemented by the subclass.

Parameters:
  • obs (-) – The observation of the current timestep.

  • policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.

  • timestep (-) – The execution result namedtuple returned by the environment step method, except that all elements have been converted to tensor data. Usually, it contains the next obs, reward, done, info, etc.

Returns:

The processed transition data of the current timestep.

Return type:

  • transition (Dict[str, torch.Tensor])

_reset_collect(env_id: int = None, current_steps: int = None, reset_init_data: bool = True) None[source]
Overview:

This method resets the collection process for a specific environment. It clears caches and releases memory when certain conditions on current_steps are met, to keep memory usage under control. If reset_init_data is True, the initial data will be reset.

Parameters:
  • env_id (-) – The ID of the environment to reset. If None or list, the function returns immediately.

  • current_steps (-) – The current step count in the environment. Used to determine whether to clear caches.

  • reset_init_data (-) – Whether to reset the initial data. If True, the initial data will be reset.
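A minimal sketch of the reset flow described above: return early for None or list env_id, optionally reset the env's initial data, and clear a shared cache every clear_interval steps. The class, the clear_interval value and the cache layout are hypothetical and are not taken from LightZero.

```python
from typing import Dict, List, Optional, Union


class CollectCacheManager:
    """Hypothetical holder for per-env initial data and a shared inference cache."""

    def __init__(self, clear_interval: int = 200):
        self.clear_interval = clear_interval  # assumed value, not taken from LightZero
        self.init_data: Dict[int, Optional[str]] = {}
        self.cache: Dict[str, str] = {}

    def reset(self, env_id: Union[int, List[int], None] = None,
              current_steps: Optional[int] = None, reset_init_data: bool = True) -> None:
        # Documented early exit: None or a list of ids does nothing.
        if env_id is None or isinstance(env_id, list):
            return
        if reset_init_data:
            self.init_data[env_id] = None  # placeholder for the env's initial obs/action
        # Periodically drop cached entries to bound memory usage.
        if current_steps is not None and current_steps % self.clear_interval == 0:
            self.cache.clear()
```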

_reset_eval(env_id: int = None, current_steps: int = None, reset_init_data: bool = True) None[source]
Overview:

This method resets the evaluation process for a specific environment. It clears caches and releases memory when certain conditions on current_steps are met, to keep memory usage under control. If reset_init_data is True, the initial data will be reset.

Parameters:
  • env_id (-) – The ID of the environment to reset. If None or list, the function returns immediately.

  • current_steps (-) – The current step count in the environment. Used to determine whether to clear caches.

  • reset_init_data (-) – Whether to reset the initial data. If True, the initial data will be reset.

_reset_learn(data_id: List[int] | None = None) None
Overview:

Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables corresponding to data_id are reset; for example, different trajectories in data_id will have different RNN hidden states.

Parameters:

data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.

Note

This method is not mandatory to implement. Subclasses can override it if necessary.

_set_attribute(name: str, value: Any) None
Overview:

To control access to the policy attributes, we expose different modes to the outside rather than using the policy instance directly. We also provide this method to set an attribute of the policy in different modes. The new attribute will be named _{name}.

Parameters:
  • name (-) – The name of the attribute.

  • value (-) – The value of the attribute.

_state_dict_collect() Dict[str, Any]
Overview:

Return the state_dict of collect mode, which usually includes only the model. It is necessary in distributed training scenarios to auto-recover collectors.

Returns:

The dict of current policy collect state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover collectors; sometimes we can simply shut down the crashed collector and start a new one.

_state_dict_eval() Dict[str, Any]
Overview:

Return the state_dict of eval mode, which usually includes only the model. It is necessary in distributed training scenarios to auto-recover evaluators.

Returns:

The dict of current policy eval state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

Tip

Not all scenarios need to auto-recover evaluators; sometimes we can simply shut down the crashed evaluator and start a new one.

_state_dict_learn() Dict[str, Any][source]
Overview:

Return the state_dict of learn mode, usually including model, target_model and optimizer.

Returns:

The dict of current policy learn state, for saving and restoring.

Return type:

  • state_dict (Dict[str, Any])

property cfg: EasyDict
class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new collect_function object from a sequence or iterable

_replace(**kwds)

Return a new collect_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

get_train_sample

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 7

process_transition

Alias for field number 1

reset

Alias for field number 3

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property collect_mode: collect_function
Overview:

Return the interfaces of collect mode of policy, which is used to collect training data. Here we use namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.

Returns:

The interfaces of collect mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.

Return type:

  • interfaces (Policy.collect_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
config = {'action_type': 'fixed_action_space', 'analysis_sim_norm': False, 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'collect_with_pure_policy': False, 'collector_env_num': 8, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_freq': 2000, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 400, 'grad_clip_value': 5, 'gray_scale': False, 'gumbel_algo': False, 'ignore_done': False, 'learning_rate': 0.0001, 'lr_piecewise_constant_decay': False, 'manual_temperature_decay': False, 'mcts_ctree': True, 'model': {'analysis_sim_norm': False, 'bias': True, 'categorical_distribution': True, 'continuous_action_space': False, 'frame_stack_num': 1, 'image_channel': 3, 'learn': {'learner': {'hook': {'save_ckpt_after_iter': 10000}}}, 'model_type': 'conv', 'norm_type': 'BN', 'num_channels': 64, 'num_res_blocks': 1, 'observation_shape': (3, 64, 64), 'res_connection_in_dynamics': True, 'self_supervised_learning_loss': True, 'support_scale': 50, 'world_model_cfg': {'action_space_size': 6, 'analysis_dormant_ratio': False, 'analysis_sim_norm': False, 'attention': 'causal', 'attn_pdrop': 0.1, 'context_length': 8, 'continuous_action_space': False, 'device': 'cpu', 'dormant_threshold': 0.025, 'embed_dim': 768, 'embed_pdrop': 0.1, 'env_num': 8, 'gamma': 1, 'group_size': 8, 'gru_gating': False, 'latent_recon_loss_weight': 0.0, 'max_blocks': 10, 'max_cache_size': 5000, 'max_tokens': 20, 'num_heads': 8, 'num_layers': 2, 'obs_type': 'image', 'perceptual_loss_weight': 0.0, 'policy_entropy_weight': 0.0001, 'predict_latent_loss_type': 'group_kl', 'resid_pdrop': 0.1, 'support_size': 101, 'tokens_per_block': 2}}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'multi_gpu': False, 'n_episode': 8, 'num_simulations': 50, 'num_unroll_steps': 10, 'optim_type': 'AdamW', 'policy_entropy_loss_weight': 0, 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'replay_ratio': 0.25, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sample_type': 'transition', 'sampled_algo': False, 'ssl_loss_weight': 0, 'target_update_freq': 100, 'target_update_freq_for_intrinsic_reward': 1000, 'target_update_theta': 0.05, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 100000, 'train_start_after_envsteps': 0, 'transform2string': False, 'type': 'unizero', 'update_per_collect': None, 'use_augmentation': False, 'use_priority': False, 'use_rnd_model': False, 'use_ture_chance_label_in_chance_encoder': False, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
classmethod default_config() EasyDict
Overview:

Return the default config of the policy.

Returns:

The default config of the corresponding policy. For a derived policy class, the default config of its base class is recursively merged with the derived class's own default config.

Return type:

  • cfg (EasyDict)

Tip

This method deep-copies the config attribute of the class and returns the copy, so users can modify the returned config without affecting the class attribute.
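The Returns text and the Tip above describe two behaviours: base-class defaults are merged recursively with the derived class's own defaults, and the result is a deep copy. The following is a minimal sketch of both ideas using plain dicts instead of EasyDict; the classes BasePolicy and ToyUniZeroPolicy and the helper deep_merge are hypothetical and are not DI-engine or LightZero code.

```python
import copy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict in which keys from `override` recursively replace those in `base`."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


class BasePolicy:
    # Hypothetical base defaults.
    config = {'cuda': False, 'learn': {'batch_size': 64, 'learning_rate': 1e-3}}

    @classmethod
    def default_config(cls) -> Dict[str, Any]:
        # Deep-copy this class's own `config` so callers never receive the class attribute itself.
        own = copy.deepcopy(cls.__dict__.get('config', {}))
        parent = cls.__base__
        if parent is object:
            return own
        # Recursively build the parent's defaults, then let this class's keys override them.
        return deep_merge(parent.default_config(), own)


class ToyUniZeroPolicy(BasePolicy):
    # Hypothetical derived defaults: override one nested key and add a new one.
    config = {'learn': {'batch_size': 256}, 'num_simulations': 50}


cfg = ToyUniZeroPolicy.default_config()
cfg['learn']['batch_size'] = 1  # edits the copy only; the class attributes are untouched
```

Because every level returns a fresh copy, repeated calls to default_config always start from the pristine class defaults.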

default_model() Tuple[str, List[str]][source]
Overview:

Return the default model setting of this algorithm for demonstration.

Returns:

The model name and the model import names:
  • model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.

  • import_names (List[str]): The model class path list used in this algorithm.

Return type:

  • model_info (Tuple[str, List[str]])

Note

The user can define and use a customized network model, but it must obey the same interface as the model indicated by the import_names path. For UniZero, this is lzero.model.unizero_model.UniZeroModel.
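default_model only returns names; something else has to turn them into a model instance. A common pattern in DI-engine-style code is to import every module in import_names, which runs the registration decorator on the model class, and then look model_type up in a registry. The sketch below shows that pattern with a hypothetical MODEL_REGISTRY, register decorator and create_model helper (not the library's own code); the stdlib module json merely stands in for a real path such as lzero.model.unizero_model.

```python
import importlib
from typing import Any, Callable, Dict, List, Tuple, Type

# Hypothetical stand-in for a model registry.
MODEL_REGISTRY: Dict[str, Type] = {}


def register(name: str) -> Callable[[Type], Type]:
    """Class decorator that records `cls` in MODEL_REGISTRY under `name`."""
    def decorator(cls: Type) -> Type:
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


def create_model(model_info: Tuple[str, List[str]], **kwargs: Any) -> Any:
    model_type, import_names = model_info
    # Importing each module runs its @register decorators, which fills the registry.
    for module_path in import_names:
        importlib.import_module(module_path)
    if model_type not in MODEL_REGISTRY:
        raise KeyError(f"model type {model_type!r} is not registered")
    return MODEL_REGISTRY[model_type](**kwargs)


@register('ToyModel')
class ToyModel:
    """Hypothetical model class; a real one would subclass torch.nn.Module."""

    def __init__(self, hidden_dim: int = 8) -> None:
        self.hidden_dim = hidden_dim


# ('ToyModel', ['json']) plays the role of the tuple returned by default_model().
model = create_model(('ToyModel', ['json']), hidden_dim=64)
```

Keeping the class path as a string list lets a config refer to models that live in optional or user-supplied modules without importing them eagerly.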

class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new eval_function object from a sequence or iterable

_replace(**kwds)

Return a new eval_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 2

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

load_state_dict

Alias for field number 5

reset

Alias for field number 1

set_attribute

Alias for field number 3

state_dict

Alias for field number 4
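The entries above are the standard members that collections.namedtuple generates. The short sketch below rebuilds a tuple with the same field order, using dummy lambdas in place of the policy's internal methods (an assumption for illustration), to show what "Alias for field number N" and immutability mean in practice.

```python
from collections import namedtuple

# Same field order as eval_function above.
eval_function = namedtuple(
    'eval_function',
    ['forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict'],
)

# Dummy callables stand in for the policy's real internal methods.
iface = eval_function(
    forward=lambda obs: {'action': 0},
    reset=lambda data_id=None: None,
    get_attribute=lambda name: None,
    set_attribute=lambda name, value: None,
    state_dict=lambda: {},
    load_state_dict=lambda state_dict: None,
)

# "Alias for field number 2": attribute access and index access return the same object.
same_object = iface.get_attribute is iface[2]

# Fields cannot be reassigned in place...
try:
    iface.forward = None
except AttributeError:
    immutable = True
else:
    immutable = False

# ...but _replace builds a new tuple with one field swapped and the others shared.
patched = iface._replace(forward=lambda obs: {'action': 1})
```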

property eval_mode: eval_function
Overview:

Return the interfaces of the eval mode of the policy, which is used to evaluate the policy. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.

Returns:

The eval-mode interfaces of the policy: a namedtuple whose fields are bound to different internal methods.

Return type:

  • interfaces (Policy.eval_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
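A policy exposes eval_mode by packing its internal eval methods into the eval_function namedtuple, so a caller holding policy.eval_mode can reach only those six entry points. Below is a hedged sketch with a hypothetical ToyPolicy; its internal method names (_forward_eval, _reset_eval and so on) and its per-environment output format are assumptions, and real LightZero policies take extra arguments such as action masks in their forward pass.

```python
from collections import namedtuple
from typing import Any, Dict, List, Optional

# Same field order as Policy.eval_function documented above.
EvalFunction = namedtuple(
    'eval_function',
    ['forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict'],
)


class ToyPolicy:
    """Hypothetical policy that always chooses action 0."""

    def __init__(self) -> None:
        self._eval_calls = 0

    def _forward_eval(self, obs: Dict[int, Any]) -> Dict[int, Dict[str, int]]:
        self._eval_calls += 1
        # One output per ready environment id (assumed format).
        return {env_id: {'action': 0} for env_id in obs}

    def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
        self._eval_calls = 0

    def _get_attribute(self, name: str) -> Any:
        return getattr(self, '_' + name)

    def _set_attribute(self, name: str, value: Any) -> None:
        setattr(self, '_' + name, value)

    def _state_dict_eval(self) -> Dict[str, Any]:
        return {'eval_calls': self._eval_calls}

    def _load_state_dict_eval(self, state_dict: Dict[str, Any]) -> None:
        self._eval_calls = state_dict['eval_calls']

    @property
    def eval_mode(self) -> EvalFunction:
        # Bind the internal methods to the public, immutable eval interface.
        return EvalFunction(
            self._forward_eval, self._reset_eval, self._get_attribute,
            self._set_attribute, self._state_dict_eval, self._load_state_dict_eval,
        )


policy = ToyPolicy()
policy_eval = policy.eval_mode
ready_obs = {0: [0.0, 1.0], 1: [1.0, 0.0]}  # hypothetical observations keyed by env id
inference_output = policy_eval.forward(ready_obs)
actions = {env_id: out['action'] for env_id, out in inference_output.items()}
```

Since the tuple only holds bound methods, an evaluator cannot accidentally call training-only methods through it.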
class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)

Bases: tuple

_asdict()

Return a new dict which maps field names to their values.

_field_defaults = {}
_fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
classmethod _make(iterable)

Make a new learn_function object from a sequence or iterable

_replace(**kwds)

Return a new learn_function object replacing specified fields with new values

count(value, /)

Return number of occurrences of value.

forward

Alias for field number 0

get_attribute

Alias for field number 4

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

info

Alias for field number 2

load_state_dict

Alias for field number 7

monitor_vars

Alias for field number 3

reset

Alias for field number 1

set_attribute

Alias for field number 5

state_dict

Alias for field number 6

property learn_mode: learn_function
Overview:

Return the interfaces of the learn mode of the policy, which is used to train the model. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own learn mode.

Returns:

The learn-mode interfaces of the policy: a namedtuple whose fields are bound to different internal methods.

Return type:

  • interfaces (Policy.learn_function)

Examples

>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
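The learn interface is what a learner calls once per training iteration, and its state_dict/load_state_dict pair is what checkpointing relies on. The sketch below uses a hypothetical ToyLearner that fits y = w * x with plain SGD instead of a neural network, to show a forward, checkpoint and restore round trip through a learn_function-shaped tuple; the method names and the returned statistics are illustrative assumptions.

```python
import copy
from collections import namedtuple
from typing import Any, Dict, List, Tuple

# Same field order as Policy.learn_function documented above.
LearnFunction = namedtuple(
    'learn_function',
    ['forward', 'reset', 'info', 'monitor_vars', 'get_attribute',
     'set_attribute', 'state_dict', 'load_state_dict'],
)


class ToyLearner:
    """Hypothetical policy fitting y = w * x by SGD on the mean squared error."""

    def __init__(self, lr: float = 0.1) -> None:
        self._w = 0.0
        self._lr = lr
        self._train_iter = 0

    def _forward_learn(self, data: List[Tuple[float, float]]) -> Dict[str, float]:
        # Gradient of the mean squared error with respect to w over the batch.
        grad = sum(2 * (self._w * x - y) * x for x, y in data) / len(data)
        self._w -= self._lr * grad
        self._train_iter += 1
        loss = sum((self._w * x - y) ** 2 for x, y in data) / len(data)
        return {'total_loss': loss, 'w': self._w}

    def _monitor_vars_learn(self) -> List[str]:
        return ['total_loss', 'w']

    def _state_dict_learn(self) -> Dict[str, Any]:
        return {'w': self._w, 'train_iter': self._train_iter}

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        self._w = state_dict['w']
        self._train_iter = state_dict['train_iter']

    @property
    def learn_mode(self) -> LearnFunction:
        return LearnFunction(
            self._forward_learn,
            lambda data_id=None: None,  # reset: nothing to reset in this toy
            lambda: 'ToyLearner',       # info
            self._monitor_vars_learn,
            lambda name: getattr(self, '_' + name),
            lambda name, value: setattr(self, '_' + name, value),
            self._state_dict_learn,
            self._load_state_dict_learn,
        )


policy_learn = ToyLearner().learn_mode
data = [(1.0, 2.0), (2.0, 4.0)]  # samples from y = 2x
for _ in range(50):
    train_output = policy_learn.forward(data)

# Checkpoint, then restore into a fresh instance.
checkpoint = copy.deepcopy(policy_learn.state_dict())
restored = ToyLearner().learn_mode
restored.load_state_dict(checkpoint)
```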
monitor_weights_and_grads(model)[source]
recompute_pos_emb_diff_and_clear_cache() None[source]
Overview:

Clear the caches and precompute positional embedding matrices in the model.

sync_gradients(model: Module) None
Overview:

Synchronize (allreduce) the gradients of model parameters in data-parallel multi-GPU training.

Parameters:

model (-) – The model whose gradients are synchronized.

Note

This method is only used in multi-GPU training. It should be called after the backward pass and before the optimizer step. The bp_update_sync config option controls whether the gradient allreduce is synchronized with the optimizer update.
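sync_gradients relies on a collective allreduce so that every data-parallel worker applies the same update. The single-process sketch below only simulates what an averaging allreduce computes, with plain lists standing in for gradient tensors and a hypothetical allreduce_mean helper; in real multi-GPU training the reduction is performed by a communication backend (for example torch.distributed.all_reduce), and whether the result is summed or then divided by the world size depends on the implementation.

```python
from typing import Dict, List

Grads = Dict[str, List[float]]


def allreduce_mean(worker_grads: List[Grads]) -> List[Grads]:
    """Simulate an averaging allreduce: every worker ends up with the element-wise mean."""
    world_size = len(worker_grads)
    averaged: Grads = {}
    for name in worker_grads[0]:
        # Sum this parameter's gradient element-wise across workers, then divide by world size.
        summed = [sum(vals) for vals in zip(*(g[name] for g in worker_grads))]
        averaged[name] = [v / world_size for v in summed]
    # Each worker receives its own copy of the same averaged gradients.
    return [{name: list(vals) for name, vals in averaged.items()} for _ in range(world_size)]


# Two hypothetical workers that computed different gradients on different data shards.
worker_grads = [
    {'weight': [1.0, 2.0], 'bias': [0.5]},
    {'weight': [3.0, 4.0], 'bias': [1.5]},
]
synced = allreduce_mean(worker_grads)
```

After the reduction every worker holds identical gradients, so the subsequent optimizer step keeps the model replicas in sync; this is why the call must sit between the backward pass and the optimizer step.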

total_field = {'collect', 'eval', 'learn'}