Policy
AlphaZeroPolicy
- class lzero.policy.alphazero.AlphaZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)
Bases:
Policy
- Overview:
The policy class for AlphaZero.
- __init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) -> None
- Overview:
Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy, namely learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. enable_field specifies which fields to initialize; if it is None, all fields are initialized.
- Parameters:
cfg – The final merged config used to initialize the policy. For the default config, see the config attribute of the policy class and its comments.
model – The neural network model used to initialize the policy. If it is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model is set to the model instance created by the outside caller.
enable_field – The list of fields to initialize. If it is None, all fields are initialized. Otherwise, only the fields in enable_field are initialized, which saves resources.
Note
A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.
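The field selection described above can be sketched with a toy class. MiniPolicy and its initialized list are hypothetical stand-ins, not DI-engine's Policy; the sketch only assumes what the text states: enable_field=None enables every field in total_field, and each enabled field is set up by its _init_<field> method.

```python
from typing import List, Optional


class MiniPolicy:
    """Hypothetical toy that mimics how __init__ picks fields from enable_field."""

    total_field = {'collect', 'eval', 'learn'}

    def __init__(self, enable_field: Optional[List[str]] = None) -> None:
        # None means "initialize every field in total_field".
        fields = self.total_field if enable_field is None else set(enable_field)
        self.initialized: List[str] = []
        # Only the enabled fields get their _init_<field> method called.
        for field in sorted(fields):
            getattr(self, f'_init_{field}')()

    def _init_learn(self) -> None:
        self.initialized.append('learn')

    def _init_collect(self) -> None:
        self.initialized.append('collect')

    def _init_eval(self) -> None:
        self.initialized.append('eval')


# Only the collect field is set up, e.g. for a collector-only process.
print(MiniPolicy(['collect']).initialized)  # ['collect']
```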
- _create_model(cfg: EasyDict, model: Module | None = None) -> Module
- Overview:
Create or validate the neural network model according to the input config and model. If the input model is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model is verified to be an instance of torch.nn.Module and set to the model instance created by the outside caller.
- Parameters:
cfg – The final merged config used to initialize the policy.
model – The neural network model used to initialize the policy. Users can refer to the default model defined in the corresponding policy to customize their own model.
- Returns:
The created neural network model. The different modes of the policy add distinct wrappers and plugins to the model, which are used for training, collecting and evaluation.
- Return type:
model (torch.nn.Module)
- Raises:
- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.
- _forward_collect(obs: Dict, temperature: float = 1) -> Dict[str, Tensor]
- Overview:
The forward function for collecting data in collect mode. It executes MCTS search using the real environment.
- Parameters:
obs – The dict of observations; the key is env_id and the value is the corresponding observation at this timestep.
temperature – The temperature for MCTS search.
- Returns:
The dict of outputs; the key is env_id and the value is the corresponding policy output at this timestep, including action, probs and so on.
- Return type:
output (Dict[str, torch.Tensor])
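AlphaZero-style collection typically turns root visit counts N(a) into a sampling distribution p(a) proportional to N(a)^(1/T), where T is the temperature. The sketch below shows that standard recipe; visit_count_probs and sample_action are hypothetical names, and whether AlphaZeroPolicy computes its probs exactly this way is an assumption.

```python
import random
from typing import Dict, Optional, Tuple


def visit_count_probs(visit_counts: Dict[int, int], temperature: float = 1.0) -> Dict[int, float]:
    """Map root visit counts N(a) to probabilities p(a) proportional to N(a) ** (1 / T)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    powered = {a: n ** (1.0 / temperature) for a, n in visit_counts.items()}
    total = sum(powered.values())  # assumes at least one action was visited
    return {a: v / total for a, v in powered.items()}


def sample_action(visit_counts: Dict[int, int], temperature: float = 1.0,
                  rng: Optional[random.Random] = None) -> Tuple[int, Dict[int, float]]:
    """Sample an action from the tempered visit-count distribution."""
    rng = rng or random.Random()
    probs = visit_count_probs(visit_counts, temperature)
    actions = list(probs)
    action = rng.choices(actions, weights=[probs[a] for a in actions], k=1)[0]
    return action, probs


# A lower temperature sharpens the distribution toward the most-visited action.
print(visit_count_probs({0: 10, 1: 30}, temperature=1.0))  # {0: 0.25, 1: 0.75}
print(visit_count_probs({0: 10, 1: 30}, temperature=0.5))  # {0: 0.1, 1: 0.9}
```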
- _forward_eval(obs: Dict) -> Dict[str, Tensor]
- Overview:
The forward function for evaluating the current policy in eval mode, similar to self._forward_collect.
- Parameters:
obs – The dict of observations; the key is env_id and the value is the corresponding observation at this timestep.
- Returns:
The dict of outputs; the key is env_id and the value is the corresponding policy output at this timestep, including action, probs and so on.
- Return type:
output (Dict[str, torch.Tensor])
- _forward_learn(inputs: Dict[str, Tensor]) -> Dict[str, float]
- Overview:
Policy forward function of learn mode (training the policy and updating its parameters). Here, forward means that the policy takes a batch of training data sampled from the replay buffer and returns the output, including various training information such as loss value, policy entropy, q value, priority, and so on. This method is left to be implemented by the subclass, and more arguments can be added to the data item if necessary.
- Parameters:
inputs – The input data used for policy forward, i.e. a batch of training samples. In each sample dict, the key is the name of the data item and the value is the corresponding data. Usually, in the _forward_learn method, the data should be stacked along the batch dimension by utility functions such as default_preprocess_learn.
- Returns:
The training information of policy forward, including metrics for monitoring training, such as loss, priority, q value and policy entropy, and data for the next training step, such as priority. Note that the output values should be Python native scalars rather than PyTorch tensors, which is more convenient for outside callers.
- Return type:
output (Dict[str, float])
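A minimal sketch of the two conventions described above, assuming samples arrive as a list of dicts: stacking them along the batch dimension, and returning native Python scalars instead of tensors. stack_batch and to_native are hypothetical helpers that only illustrate the idea; they are not default_preprocess_learn or any other DI-engine function.

```python
from typing import Any, Dict, List


def stack_batch(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical helper: turn a list of per-sample dicts into one dict of batched lists."""
    keys = samples[0].keys()  # assumes every sample has the same keys
    return {k: [s[k] for s in samples] for k in keys}


def to_native(info: Dict[str, Any]) -> Dict[str, Any]:
    """Convert values exposing .item() (e.g. 0-d tensors) into plain Python scalars."""
    return {k: v.item() if hasattr(v, 'item') else v for k, v in info.items()}


batch = stack_batch([{'obs': 1, 'reward': 0.5}, {'obs': 2, 'reward': 1.0}])
print(batch)  # {'obs': [1, 2], 'reward': [0.5, 1.0]}
```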
- _get_attribute(name: str) -> Any
- Overview:
To control access to the policy attributes, we expose the different modes to the outside rather than using the policy instance directly. This method gets the value of a policy attribute in the different modes.
- Parameters:
name – The name of the attribute.
- Returns:
The value of the attribute.
- Return type:
value (Any)
Note
DI-engine’s policy first tries to access the _get_{name} method and then the _{name} attribute. If neither is found, it raises a NotImplementedError.
- _get_batch_size() -> int | Dict[str, int]
- _get_n_episode() -> int | None
- _get_n_sample() -> int | None
- _get_train_sample(data)
- Overview:
For given trajectory data (a list of transitions), process it into a list of samples that can be used for training directly. A train sample can be a processed transition (e.g. DQN with n-step TD) or a multi-timestep transition (e.g. DRQN). This method is usually called in collectors to perform the necessary RL data preprocessing before training, which helps the learner amortize the relevant time cost. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.
- Parameters:
data – The trajectory data (a list of transitions); each element has the same format as the return value of the self._process_transition method.
- Returns:
The processed train samples; each element has a format similar to the input transitions, but may contain more data for training, such as n-step reward, advantage, etc.
- Return type:
samples (List[Dict[str, Any]])
Note
We will vectorize the process_transition and get_train_sample methods in a future release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
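Two illustrative shapes for _get_train_sample, assuming each transition is a dict with 'reward' and 'done' keys: the identity version mentioned above, and an n-step variant that attaches a truncated discounted return to each transition. Both functions and the nstep_return key are hypothetical, and the n-step version omits the bootstrap value that a real n-step TD target would add.

```python
from typing import Any, Dict, List


def get_train_sample_identity(transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Identity version: each transition already is a train sample, and any
    # further processing is deferred to _forward_learn.
    return list(transitions)


def get_train_sample_nstep(transitions: List[Dict[str, Any]], n: int = 3,
                           gamma: float = 0.99) -> List[Dict[str, Any]]:
    """Hypothetical n-step variant: attach a truncated discounted return to each transition."""
    samples = []
    for i, t in enumerate(transitions):
        ret, discount = 0.0, 1.0
        for future in transitions[i:i + n]:
            ret += discount * future['reward']
            discount *= gamma
            if future['done']:  # do not accumulate rewards past the episode end
                break
        # No bootstrap value is added here; a real n-step TD target would add one.
        samples.append({**t, 'nstep_return': ret})
    return samples
```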
- _init_collect() -> None
- Overview:
Collect mode init method, called by self.__init__. Initializes the collect model and the MCTS utilities.
- _init_eval() -> None
- Overview:
Eval mode init method, called by self.__init__. Initializes the eval model and the MCTS utilities.
- _init_learn() -> None
- Overview:
Initialize the learn mode of the policy, including related attributes and modules. This method is called in the __init__ method if the learn field is in enable_field. Almost every policy has its own learn mode, so this method must be overridden in the subclass.
Note
For member variables that need to be saved and loaded, refer to the _state_dict_learn and _load_state_dict_learn methods.
Note
For member variables that need to be monitored, refer to the _monitor_vars_learn method.
Note
If you want to set special member variables in the _init_learn method, prefix their names with _learn_ to avoid conflicts with other modes, e.g. self._learn_attr1.
- _init_multi_gpu_setting(model: Module, bp_update_sync: bool) -> None
- Overview:
Initialize the multi-GPU data-parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function that allreduces the gradients of the model parameters.
- Parameters:
model – The neural network model to be trained.
bp_update_sync – Whether to update the model parameters synchronously after allreducing their gradients. Asynchronous updates can run in parallel across different network layers, like a pipeline, which saves time.
- _load_state_dict_collect(state_dict: Dict[str, Any]) -> None
- Overview:
Load the state_dict variable into policy collect mode, e.g. to load a pretrained state_dict, auto-recover from a checkpoint, or receive a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict – The dict of policy collect state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_eval(state_dict: Dict[str, Any]) -> None
- Overview:
Load the state_dict variable into policy eval mode, e.g. to auto-recover from a checkpoint or receive a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict – The dict of policy eval state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_learn(state_dict: Dict[str, Any]) -> None
- Overview:
Load the state_dict variable into policy learn mode.
- Parameters:
state_dict – The dict of policy learn state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _monitor_vars_learn() -> List[str]
- Overview:
Register the variables to be monitored in learn mode. The registered variables are taken from the return value of _forward_learn and logged in TensorBoard.
- _process_transition(obs: Dict, model_output: Dict[str, Tensor], timestep: namedtuple) -> Dict
- Overview:
Generate one timestep of transition data as a dict from the observation, the model output and the environment timestep.
- _reset_collect(data_id: List[int] | None = None) -> None
- Overview:
Reset stateful variables for collect mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, the different environments/episodes being collected in data_id have different RNN hidden states.
- Parameters:
data_id – The ids of the data whose stateful variables should be reset.
Note
This method is not mandatory. A subclass can override it if necessary.
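A minimal sketch of data_id-based resetting, assuming one scalar state per environment id as a stand-in for an RNN hidden state. PerEnvState is hypothetical; it only mirrors the rule above: None resets everything, otherwise only the listed ids are reset.

```python
from typing import Dict, List, Optional


class PerEnvState:
    """Hypothetical per-environment state store, e.g. one RNN hidden state per env id."""

    def __init__(self, env_ids: List[int], init_value: float = 0.0) -> None:
        self._init_value = init_value
        self.state: Dict[int, float] = {i: init_value for i in env_ids}

    def reset(self, data_id: Optional[List[int]] = None) -> None:
        # None resets every environment; otherwise only the listed ids are reset.
        ids = list(self.state) if data_id is None else data_id
        for i in ids:
            self.state[i] = self._init_value


store = PerEnvState([0, 1, 2])
store.state.update({0: 5.0, 1: 6.0, 2: 7.0})
store.reset([1])  # e.g. only env 1 finished its episode
print(store.state)  # {0: 5.0, 1: 0.0, 2: 7.0}
```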
- _reset_eval(data_id: List[int] | None = None) -> None
- Overview:
Reset stateful variables for eval mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, the different environments/episodes being evaluated in data_id have different RNN hidden states.
- Parameters:
data_id – The ids of the data whose stateful variables should be reset.
Note
This method is not mandatory. A subclass can override it if necessary.
- _reset_learn(data_id: List[int] | None = None) -> None
- Overview:
Reset stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, the different trajectories in data_id have different RNN hidden states.
- Parameters:
data_id – The ids of the data whose stateful variables should be reset.
Note
This method is not mandatory. A subclass can override it if necessary.
- _set_attribute(name: str, value: Any) -> None
- Overview:
To control access to the policy attributes, we expose the different modes to the outside rather than using the policy instance directly. This method sets a policy attribute in the different modes; the new attribute is named _{name}.
- Parameters:
name – The name of the attribute.
value – The value of the attribute.
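The lookup order from the _get_attribute note and the _{name} naming convention of _set_attribute can be sketched together. AttributeAccess is a hypothetical toy class, not DI-engine's Policy; the 256 returned by its _get_batch_size merely echoes batch_size in the default config and is not read from it.

```python
from typing import Any


class AttributeAccess:
    """Hypothetical toy class mirroring the documented get/set attribute conventions."""

    def _get_batch_size(self) -> int:
        return 256

    def _get_attribute(self, name: str) -> Any:
        # 1) Prefer a dedicated _get_{name} method.
        getter = getattr(self, f'_get_{name}', None)
        if callable(getter):
            return getter()
        # 2) Fall back to a plain _{name} attribute.
        if hasattr(self, f'_{name}'):
            return getattr(self, f'_{name}')
        # 3) Neither exists.
        raise NotImplementedError(f"attribute '{name}' is not available")

    def _set_attribute(self, name: str, value: Any) -> None:
        # The new attribute is stored as _{name}.
        setattr(self, f'_{name}', value)


p = AttributeAccess()
p._set_attribute('learn_step', 10)
print(p._get_attribute('batch_size'), p._get_attribute('learn_step'))  # 256 10
```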
- _state_dict_collect() -> Dict[str, Any]
- Overview:
Return the state_dict of collect mode, which usually only includes the model. It is needed in distributed training scenarios to auto-recover collectors.
- Returns:
The dict of the current policy collect state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not all scenarios need to auto-recover collectors; sometimes we can simply shut down the crashed collector and start a new one.
- _state_dict_eval() -> Dict[str, Any]
- Overview:
Return the state_dict of eval mode, which usually only includes the model. It is needed in distributed training scenarios to auto-recover evaluators.
- Returns:
The dict of the current policy eval state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not all scenarios need to auto-recover evaluators; sometimes we can simply shut down the crashed evaluator and start a new one.
- _state_dict_learn() -> Dict[str, Any]
- Overview:
Return the state_dict of learn mode, usually including the model and the optimizer.
- Returns:
The dict of the current policy learn state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
- property cfg: EasyDict
- class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new collect_function object from a sequence or iterable
- _replace(**kwds)
Return a new collect_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- get_train_sample
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 7
- process_transition
Alias for field number 1
- reset
Alias for field number 3
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property collect_mode: collect_function
- Overview:
Return the interfaces of the collect mode of the policy, which is used to collect training data. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, a derived subclass can override the interfaces to customize its own collect mode.
- Returns:
The interfaces of the collect mode of the policy, a namedtuple whose fields are different internal methods.
- Return type:
interfaces (Policy.collect_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- config = {'batch_size': 256, 'collector_env_num': 8, 'cuda': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'grad_clip_value': 10, 'gumbel_algo': False, 'learning_rate': 0.2, 'manual_temperature_decay': False, 'mcts': {'max_moves': 512, 'num_simulations': 50, 'pb_c_base': 19652, 'pb_c_init': 1.25, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25}, 'model': {'num_channels': 32, 'num_res_blocks': 1, 'observation_shape': (3, 6, 6)}, 'momentum': 0.9, 'multi_gpu': False, 'optim_type': 'SGD', 'other': {'replay_buffer': {'replay_buffer_size': 1000000, 'save_episode': False}}, 'piecewise_decay_lr_scheduler': True, 'replay_ratio': 0.25, 'sampled_algo': False, 'tensor_float_32': False, 'threshold_training_steps_for_final_lr': 500000, 'threshold_training_steps_for_final_temperature': 100000, 'torch_compile': False, 'update_per_collect': None, 'value_weight': 1.0, 'weight_decay': 0.0001}
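The MCTS entries in the config above (pb_c_base, pb_c_init, root_dirichlet_alpha, root_noise_weight) correspond to the pUCT exploration constant and the root exploration noise described in the AlphaZero/MuZero papers. The sketch below writes out those standard formulas with this config's default values; the function names are hypothetical, the value term is used unnormalized for simplicity, and LightZero's actual tree search may differ in details.

```python
import math
import random
from typing import List, Optional


def puct_score(parent_visits: int, child_visits: int, child_prior: float, child_value: float,
               pb_c_base: float = 19652, pb_c_init: float = 1.25) -> float:
    """pUCT score Q(s,a) + P(s,a) * c(s) * sqrt(N(s)) / (1 + N(s,a)), MuZero-paper style."""
    # The exploration weight grows slowly (logarithmically) with the parent visit count.
    pb_c = math.log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init
    pb_c *= math.sqrt(parent_visits) / (child_visits + 1)
    # child_value is used as-is here; MuZero min-max normalizes Q before this sum.
    return child_value + pb_c * child_prior


def add_root_dirichlet_noise(priors: List[float], alpha: float = 0.3, weight: float = 0.25,
                             rng: Optional[random.Random] = None) -> List[float]:
    """Mix Dirichlet(alpha) noise into root priors: (1 - weight) * p + weight * noise."""
    rng = rng or random.Random()
    # A Dirichlet sample is a vector of independent Gamma(alpha, 1) draws, normalized to sum 1.
    gammas = [rng.gammavariate(alpha, 1.0) for _ in priors]
    total = sum(gammas)
    noise = [g / total for g in gammas]
    return [(1 - weight) * p + weight * n for p, n in zip(priors, noise)]


print(puct_score(parent_visits=100, child_visits=0, child_prior=1.0, child_value=0.0))
print(add_root_dirichlet_noise([0.5, 0.5], rng=random.Random(0)))
```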
- classmethod default_config() -> EasyDict
- Overview:
Get the default config of the policy.
- Returns:
The default config of the corresponding policy. For a derived policy class, the default config of the base class is recursively merged with its own default config.
- Return type:
cfg (EasyDict)
Tip
This method deep-copies the config attribute of the class and returns the result, so users don’t need to worry about modifying the returned config.
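A minimal sketch of the deep-copy and recursive-merge behaviour described above, using plain dicts instead of EasyDict. BasePolicy, DerivedPolicy and deep_merge are hypothetical; walking the class MRO from base to derived and merging each class's own config is one way to realise the described merge, not necessarily how DI-engine does it.

```python
import copy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict: override's entries win, nested dicts are merged recursively."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


class BasePolicy:
    config: Dict[str, Any] = {'batch_size': 256, 'mcts': {'num_simulations': 50, 'pb_c_init': 1.25}}

    @classmethod
    def default_config(cls) -> Dict[str, Any]:
        merged: Dict[str, Any] = {}
        # Walk from the most basic class to the most derived one, merging each class's own config.
        for klass in reversed(cls.__mro__):
            own = klass.__dict__.get('config')
            if own is not None:
                merged = deep_merge(merged, own)
        return merged  # a fresh copy: mutating it never touches the class attributes


class DerivedPolicy(BasePolicy):
    config = {'mcts': {'num_simulations': 25}}


cfg = DerivedPolicy.default_config()
print(cfg)  # {'batch_size': 256, 'mcts': {'num_simulations': 25, 'pb_c_init': 1.25}}
cfg['mcts']['num_simulations'] = 1
print(DerivedPolicy.config['mcts']['num_simulations'])  # 25
```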
- default_model() -> Tuple[str, List[str]]
- Overview:
Return the default model setting of this algorithm for demonstration.
- Returns:
- model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.
- import_names (List[str]): The model class path list used in this algorithm.
- Return type:
Tuple[str, List[str]]
- class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new eval_function object from a sequence or iterable
- _replace(**kwds)
Return a new eval_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 5
- reset
Alias for field number 1
- set_attribute
Alias for field number 3
- state_dict
Alias for field number 4
- property eval_mode: eval_function
- Overview:
Return the interfaces of the eval mode of the policy, which is used to evaluate the model. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, a derived subclass can override the interfaces to customize its own eval mode.
- Returns:
The interfaces of the eval mode of the policy, a namedtuple whose fields are different internal methods.
- Return type:
interfaces (Policy.eval_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new learn_function object from a sequence or iterable
- _replace(**kwds)
Return a new learn_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- info
Alias for field number 2
- load_state_dict
Alias for field number 7
- monitor_vars
Alias for field number 3
- reset
Alias for field number 1
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property learn_mode: learn_function
- Overview:
Return the interfaces of the learn mode of the policy, which is used to train the model. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, a derived subclass can override the interfaces to customize its own learn mode.
- Returns:
The interfaces of the learn mode of the policy, a namedtuple whose fields are different internal methods.
- Return type:
interfaces (Policy.learn_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
- sync_gradients(model: Module) -> None
- Overview:
Synchronize (allreduce) the gradients of the model parameters in data-parallel multi-GPU training.
- Parameters:
model – The model whose gradients are synchronized.
Note
This method is only used in multi-GPU training, and it should be called after the backward method and before the step method. Users can also use the bp_update_sync config to control whether the gradient allreduce and the optimizer update are synchronized.
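What gradient synchronization computes can be illustrated without any GPUs: an allreduce with a sum op gives every rank the element-wise sum of all ranks' gradients, and dividing by the world size yields the same averaged gradient everywhere. This pure-Python toy is a hypothetical simulation, not the torch.distributed implementation; whether sync_gradients averages rather than only sums is an assumption.

```python
from typing import List


def allreduce_mean(per_rank_grads: List[List[float]]) -> List[List[float]]:
    """Toy simulation: after sync, every rank holds the element-wise mean gradient."""
    world_size = len(per_rank_grads)
    # An allreduce with a SUM op gives every rank the element-wise sum...
    summed = [sum(values) for values in zip(*per_rank_grads)]
    # ...and dividing by the world size turns it into an average.
    mean = [s / world_size for s in summed]
    return [list(mean) for _ in range(world_size)]


# Order within one training step on each rank (hypothetical outline):
#   loss.backward()  ->  sync gradients  ->  optimizer.step()
print(allreduce_mean([[1.0, 2.0], [3.0, 4.0]]))  # [[2.0, 3.0], [2.0, 3.0]]
```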
- total_field = {'collect', 'eval', 'learn'}
MuZeroPolicy
- class lzero.policy.muzero.MuZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)
Bases:
Policy
- Overview:
- if self._cfg.model.model_type in ["conv", "mlp"]:
The policy class for MuZero.
- if self._cfg.model.model_type in ["conv_context", "mlp_context"]:
The policy class for MuZero w/ Context, a variant of MuZero. This variant keeps the same training settings as MuZero but differs at inference, where it uses a k-step recursively predicted latent representation at the root node, as proposed in the UniZero paper https://arxiv.org/abs/2406.10667.
- __init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) -> None
- Overview:
Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy, namely learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. enable_field specifies which fields to initialize; if it is None, all fields are initialized.
- Parameters:
cfg – The final merged config used to initialize the policy. For the default config, see the config attribute of the policy class and its comments.
model – The neural network model used to initialize the policy. If it is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model is set to the model instance created by the outside caller.
enable_field – The list of fields to initialize. If it is None, all fields are initialized. Otherwise, only the fields in enable_field are initialized, which saves resources.
Note
A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.
- _create_model(cfg: EasyDict, model: Module | None = None) -> Module
- Overview:
Create or validate the neural network model according to the input config and model. If the input model is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model is verified to be an instance of torch.nn.Module and set to the model instance created by the outside caller.
- Parameters:
cfg – The final merged config used to initialize the policy.
model – The neural network model used to initialize the policy. Users can refer to the default model defined in the corresponding policy to customize their own model.
- Returns:
The created neural network model. The different modes of the policy add distinct wrappers and plugins to the model, which are used for training, collecting and evaluation.
- Return type:
model (torch.nn.Module)
- Raises:
- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.
- _forward_collect(data: Tensor, action_mask: list = None, temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, ready_env_id: array = None) -> Dict
- Overview:
The forward function for collecting data in collect mode. It uses the model to execute MCTS search and chooses the action by sampling.
- Parameters:
data – The input data, i.e. the observation.
action_mask – The action mask, which marks the actions that can and cannot be selected.
temperature – The temperature of the policy.
to_play – The player to play.
epsilon – The epsilon of the epsilon-greedy exploration.
ready_env_id – The ids of the envs that are ready to collect.
- Shape:
- data (torch.Tensor): For Atari, \((N, C*S, H, W)\), where N is the number of collect envs, C is the number of channels, S is the number of stacked frames, H is the height of the image and W is its width. For LunarLander, \((N, O)\), where N is the number of collect envs and O is the observation space size.
- action_mask: \((N, \text{action\_space\_size})\), where N is the number of collect envs.
- temperature: \((1, )\).
- to_play: \((N, 1)\), where N is the number of collect envs.
- epsilon: \((1, )\).
- ready_env_id: None.
- Returns:
Dict type data whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.
- Return type:
output (Dict[int, Any])
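A minimal sketch of collect-mode action selection for a single environment, assuming action_mask uses 1 for selectable actions, every selectable action received at least one visit, and epsilon-greedy exploration wraps temperature sampling. collect_action is hypothetical; in MuZeroPolicy, whether and how epsilon is applied depends on its config. The returned entropy corresponds to the visit_count_distribution_entropy output key.

```python
import math
import random
from typing import List, Optional, Tuple


def collect_action(visit_counts: List[int], action_mask: List[int], temperature: float = 1.0,
                   epsilon: float = 0.25, rng: Optional[random.Random] = None) -> Tuple[int, float]:
    """Hypothetical collect-mode choice for one env; returns (action, visit-count entropy)."""
    rng = rng or random.Random()
    # Assumption: mask value 1 marks a selectable action.
    legal = [a for a, m in enumerate(action_mask) if m == 1]
    counts = [visit_counts[a] for a in legal]
    total = sum(counts)  # assumes the legal actions were visited at least once
    probs = [c / total for c in counts]
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    if rng.random() < epsilon:
        # Exploration branch: uniform over legal actions.
        return rng.choice(legal), entropy
    # Otherwise sample with temperature: weight proportional to N(a) ** (1 / T).
    weights = [c ** (1.0 / temperature) for c in counts]
    return rng.choices(legal, weights=weights, k=1)[0], entropy
```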
- _forward_eval(data: Tensor, action_mask: list, to_play: int = -1, ready_env_id: array = None) -> Dict
- Overview:
The forward function for evaluating the current policy in eval mode. It uses the model to execute MCTS search and chooses the action with the highest value (argmax) rather than sampling.
- Parameters:
data – The input data, i.e. the observation.
action_mask – The action mask, which marks the actions that can and cannot be selected.
to_play – The player to play.
ready_env_id – The ids of the envs that are ready to be evaluated.
- Shape:
- data (torch.Tensor): For Atari, \((N, C*S, H, W)\), where N is the number of envs, C is the number of channels, S is the number of stacked frames, H is the height of the image and W is its width. For LunarLander, \((N, O)\), where N is the number of envs and O is the observation space size.
- action_mask: \((N, \text{action\_space\_size})\), where N is the number of envs.
- to_play: \((N, 1)\), where N is the number of envs.
- ready_env_id: None.
- Returns:
Dict type data whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.
- Return type:
output (Dict[int, Any])
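The text above says eval mode takes an argmax instead of sampling. This sketch assumes the argmax is taken over root visit counts, which is common in MuZero-style implementations but not confirmed here, and that action_mask uses 1 for selectable actions; eval_action is a hypothetical name.

```python
from typing import List


def eval_action(visit_counts: List[int], action_mask: List[int]) -> int:
    """Deterministic eval-mode choice: the selectable action with the most root visits."""
    # Assumption: mask value 1 marks a selectable action.
    legal = [a for a, m in enumerate(action_mask) if m == 1]
    # max() returns the first maximal element, so ties go to the lowest action index.
    return max(legal, key=lambda a: visit_counts[a])


print(eval_action([5, 40, 40, 100], [1, 1, 1, 0]))  # 1 (action 3 is masked out)
```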
- _forward_learn(data: Tuple[Tensor]) -> Dict[str, float | int]
- Overview:
The forward function of learn mode, which is the core of the learning process. The data is sampled from the replay buffer, the loss is computed by the loss function, and the loss is backpropagated to update the model.
- Parameters:
data – The data sampled from the replay buffer, a tuple of tensors: the first is current_batch and the second is target_batch.
- Returns:
The information dict to be logged, containing the current learning loss and learning statistics.
- Return type:
info_dict (Dict[str, Union[float, int]])
- _get_attribute(name: str) -> Any
- Overview:
To control access to the policy attributes, we expose the different modes to the outside rather than using the policy instance directly. This method gets the value of a policy attribute in the different modes.
- Parameters:
name – The name of the attribute.
- Returns:
The value of the attribute.
- Return type:
value (Any)
Note
DI-engine’s policy first tries to access the _get_{name} method and then the _{name} attribute. If neither is found, it raises a NotImplementedError.
- _get_batch_size() -> int | Dict[str, int]
- _get_n_episode() -> int | None
- _get_n_sample() -> int | None
- _get_target_obs_index_in_step_k(step)
- Overview:
Get the begin index and end index of the target obs at step k.
- Parameters:
step – The current step k.
- Returns:
- beg_index (int): The begin index of the target obs at step k.
- end_index (int): The end index of the target obs at step k.
- Return type:
Tuple[int, int]
Examples
>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
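The index arithmetic behind the example above can be sketched for a 'conv' model whose observation stacks frame_stack_num frames of image_channel channels each, so that each unroll step shifts the window by one frame. The formula and the name target_obs_index_conv are assumptions chosen to be consistent with the documented (0, 12) result, not code copied from LightZero.

```python
from typing import Tuple


def target_obs_index_conv(step: int, image_channel: int, frame_stack_num: int) -> Tuple[int, int]:
    """Channel slice [beg, end) of the stacked target observation at unroll step k."""
    # Each unroll step advances the frame window by one frame of image_channel channels,
    # and the window spans frame_stack_num frames.
    beg_index = image_channel * step
    end_index = image_channel * (step + frame_stack_num)
    return beg_index, end_index


print(target_obs_index_conv(0, image_channel=3, frame_stack_num=4))  # (0, 12)
print(target_obs_index_conv(1, image_channel=3, frame_stack_num=4))  # (3, 15)
```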
- _get_train_sample(data)
- Overview:
For given trajectory data (a list of transitions), process it into a list of samples that can be used for training directly. A train sample can be a processed transition (e.g. DQN with n-step TD) or a multi-timestep transition (e.g. DRQN). This method is usually called in collectors to perform the necessary RL data preprocessing before training, which helps the learner amortize the relevant time cost. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.
- Parameters:
data – The trajectory data (a list of transitions); each element has the same format as the return value of the self._process_transition method.
- Returns:
The processed train samples; each element has a format similar to the input transitions, but may contain more data for training, such as n-step reward, advantage, etc.
- Return type:
samples (List[Dict[str, Any]])
Note
We will vectorize the process_transition and get_train_sample methods in a future release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
- _init_collect() -> None
- Overview:
Collect mode init method, called by self.__init__. Initializes the collect model and the MCTS utilities.
- _init_eval() -> None
- Overview:
Eval mode init method, called by self.__init__. Initializes the eval model and the MCTS utilities.
- _init_learn() -> None
- Overview:
Learn mode init method, called by self.__init__. Initializes the learn model, the optimizer and the MCTS utilities.
- _init_multi_gpu_setting(model: Module, bp_update_sync: bool) -> None
- Overview:
Initialize the multi-GPU data-parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function that allreduces the gradients of the model parameters.
- Parameters:
model – The neural network model to be trained.
bp_update_sync – Whether to update the model parameters synchronously after allreducing their gradients. Asynchronous updates can run in parallel across different network layers, like a pipeline, which saves time.
- _load_state_dict_collect(state_dict: Dict[str, Any]) -> None
- Overview:
Load the state_dict variable into policy collect mode, e.g. to load a pretrained state_dict, auto-recover from a checkpoint, or receive a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict – The dict of policy collect state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_eval(state_dict: Dict[str, Any]) -> None
- Overview:
Load the state_dict variable into policy eval mode, e.g. to auto-recover from a checkpoint or receive a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict – The dict of policy eval state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_learn(state_dict: Dict[str, Any]) -> None
- Overview:
Load the state_dict variable into policy learn mode.
- Parameters:
state_dict – The dict of policy learn state saved before.
- _monitor_vars_learn() -> List[str]
- Overview:
Register the variables to be monitored in learn mode. The registered variables are taken from the return value of _forward_learn and logged in TensorBoard.
- _process_transition(obs, policy_output, timestep)[source]
- Overview:
Process and pack one timestep of transition data into a dict, such as <s, a, r, s’, done>. Some policies need special processing to pack their own necessary attributes (e.g. hidden state and logit), so this method is left to be implemented by the subclass.
- Parameters:
obs (-) – The observation of the current timestep.
policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.
timestep (-) – The execution result namedtuple returned by the environment step method, except all the elements have been transformed into tensor data. Usually, it contains the next obs, reward, done, info, etc.
- Returns:
The processed transition data of the current timestep.
- Return type:
transition (Dict[str, torch.Tensor])
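A minimal sketch of the packing step, assuming a timestep namedtuple with obs, reward, done and info fields and a policy_output dict with action and value keys. The Timestep type and the dict keys are illustrative assumptions; the fields actually packed by this policy may differ.

```python
from collections import namedtuple
from typing import Any, Dict

# Hypothetical stand-in for the env step result described above.
Timestep = namedtuple("Timestep", ["obs", "reward", "done", "info"])


def process_transition(obs: Any, policy_output: Dict[str, Any], timestep: Timestep) -> Dict[str, Any]:
    """Pack one <s, a, r, s', done> transition plus one extra policy attribute."""
    return {
        "obs": obs,                          # s
        "action": policy_output["action"],   # a
        "value": policy_output["value"],     # extra attribute kept by this sketch
        "reward": timestep.reward,           # r
        "next_obs": timestep.obs,            # s'
        "done": timestep.done,
    }


step = Timestep(obs=[0.2, 0.1], reward=1.0, done=False, info={})
transition = process_transition([0.0, 0.0], {"action": 3, "value": 0.7}, step)
print(transition["action"], transition["next_obs"])  # 3 [0.2, 0.1]
```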
- _reset_collect(data_id: List[int] | None = None) None [source]
- Overview:
Reset the observation and action for the collector environment.
- Parameters:
data_id (-) – List of data ids to reset (not used in this implementation).
- _reset_eval(data_id: List[int] | None = None) None [source]
- Overview:
Reset the observation and action for the evaluator environment.
- Parameters:
data_id (-) – List of data ids to reset (not used in this implementation).
- _reset_learn(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, the stateful variables are reset according to data_id. For example, different trajectories in data_id will have different hidden states in an RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.
Note
This method is not mandatory to implement. A subclass can override this method if necessary.
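The data_id semantics can be sketched with a hypothetical bank of per-trajectory hidden states (the HiddenStateBank class is an illustrative assumption, not part of this policy): None resets everything, while a list of ids resets only those trajectories.

```python
from typing import Dict, List, Optional


class HiddenStateBank:
    """Hypothetical per-trajectory RNN hidden states kept by a learn mode."""

    def __init__(self, num_trajectories: int, hidden_size: int) -> None:
        self.hidden_size = hidden_size
        self.states: Dict[int, List[float]] = {
            i: [0.0] * hidden_size for i in range(num_trajectories)
        }

    def reset(self, data_id: Optional[List[int]] = None) -> None:
        # None means "reset every stateful variable"; otherwise only the listed ids.
        ids = list(self.states) if data_id is None else data_id
        for i in ids:
            self.states[i] = [0.0] * self.hidden_size


bank = HiddenStateBank(num_trajectories=3, hidden_size=2)
for i in bank.states:
    bank.states[i] = [1.0, 1.0]  # pretend some steps updated every hidden state
bank.reset(data_id=[1])
print(bank.states)  # {0: [1.0, 1.0], 1: [0.0, 0.0], 2: [1.0, 1.0]}
```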
- _set_attribute(name: str, value: Any) None
- Overview:
To control access to the policy attributes, we expose different modes to the outside rather than the policy instance itself, and we also provide a method to set an attribute of the policy in each mode. The new attribute will be named _{name}.
- Parameters:
name (-) – The name of the attribute.
value (-) – The value of the attribute.
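The naming convention can be shown with a hypothetical ToyPolicy; the sketch assumes the setter simply prefixes the name with an underscore, as the overview states, and the attribute name collect_epsilon is illustrative.

```python
from typing import Any


class ToyPolicy:
    """Hypothetical policy showing the _{name} naming convention."""

    def _set_attribute(self, name: str, value: Any) -> None:
        # The attribute is stored with a leading underscore, e.g. "eps" -> "_eps".
        setattr(self, "_" + name, value)


policy = ToyPolicy()
policy._set_attribute("collect_epsilon", 0.25)
print(policy._collect_epsilon)  # 0.25
```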
- _state_dict_collect() Dict[str, Any]
- Overview:
Return the state_dict of collect mode, usually including only the model, which is necessary for distributed training scenarios to auto-recover collectors.
- Returns:
The dict of current policy collect state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not every scenario needs to auto-recover collectors; sometimes we can simply shut down the crashed collector and start a new one.
- _state_dict_eval() Dict[str, Any]
- Overview:
Return the state_dict of eval mode, usually including only the model, which is necessary for distributed training scenarios to auto-recover evaluators.
- Returns:
The dict of current policy eval state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not every scenario needs to auto-recover evaluators; sometimes we can simply shut down the crashed evaluator and start a new one.
- _state_dict_learn() Dict[str, Any] [source]
- Overview:
Return the state_dict of learn mode, usually including model, target_model and optimizer.
- Returns:
The dict of current policy learn state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
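A save/restore round trip for learn mode can be sketched as follows. ToyLearner is hypothetical and uses plain dicts in place of the model, target_model and optimizer state dicts; deep copies keep the snapshot independent of later training steps.

```python
import copy
from typing import Any, Dict


class ToyLearner:
    """Hypothetical learn mode holding model, target_model and optimizer state."""

    def __init__(self) -> None:
        self.model = {"w": 0.0}
        self.target_model = {"w": 0.0}
        self.optimizer = {"step": 0, "momentum_buffer": 0.0}

    def state_dict_learn(self) -> Dict[str, Any]:
        # Deep-copy so later training steps do not mutate the saved snapshot.
        return copy.deepcopy({
            "model": self.model,
            "target_model": self.target_model,
            "optimizer": self.optimizer,
        })

    def load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        self.model = copy.deepcopy(state_dict["model"])
        self.target_model = copy.deepcopy(state_dict["target_model"])
        self.optimizer = copy.deepcopy(state_dict["optimizer"])


learner = ToyLearner()
learner.model["w"] = 1.0
learner.optimizer["step"] = 10
snapshot = learner.state_dict_learn()

restored = ToyLearner()
restored.load_state_dict_learn(snapshot)
print(restored.model, restored.optimizer["step"])  # {'w': 1.0} 10
```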
- property cfg: EasyDict
- class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new collect_function object from a sequence or iterable
- _replace(**kwds)
Return a new collect_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- get_train_sample
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 7
- process_transition
Alias for field number 1
- reset
Alias for field number 3
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
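Because collect_function is a namedtuple, each interface is reachable by name or by the field number listed above, and the tuple is immutable. The sketch below rebuilds a namedtuple with the same field order; the placeholder functions are hypothetical stand-ins for the policy's internal methods.

```python
from collections import namedtuple

# Same field order as the collect_function listing above.
CollectFunction = namedtuple(
    "CollectFunction",
    ["forward", "process_transition", "get_train_sample", "reset",
     "get_attribute", "set_attribute", "state_dict", "load_state_dict"],
)


def placeholder(*args, **kwargs):
    """Hypothetical stand-in for an internal policy method."""
    return None


def custom_forward(obs):
    """Hypothetical customized forward for a derived policy."""
    return {"action": 0}


interfaces = CollectFunction(*[placeholder] * len(CollectFunction._fields))
print(CollectFunction._fields.index("get_attribute"))  # 4

# Fields cannot be reassigned; _replace builds a modified copy instead.
custom = interfaces._replace(forward=custom_forward)
print(custom.forward([0.0]))  # {'action': 0}
```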
- property collect_mode: collect_function
- Overview:
Return the interfaces of collect mode of policy, which is used to collect data for training. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.
- Returns:
The interfaces of collect mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (Policy.collect_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- config = {'action_type': 'fixed_action_space', 'analysis_dormant_ratio': False, 'analysis_sim_norm': False, 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'cal_dormant_ratio': False, 'collect_with_pure_policy': False, 'collector_env_num': 8, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_offline': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 200, 'grad_clip_value': 10, 'gray_scale': False, 'gumbel_algo': False, 'ignore_done': False, 'learning_rate': 0.2, 'manual_temperature_decay': False, 'mcts_ctree': True, 'model': {'analysis_dormant_ratio': False, 'analysis_sim_norm': False, 'bias': True, 'categorical_distribution': True, 'continuous_action_space': False, 'discrete_action_encoding_type': 'one_hot', 'frame_stack_num': 1, 'harmony_balance': False, 'image_channel': 1, 'model_type': 'conv', 'norm_type': 'BN', 'num_channels': 64, 'num_res_blocks': 1, 'observation_shape': (4, 96, 96), 'res_connection_in_dynamics': True, 'self_supervised_learning_loss': False, 'support_scale': 300}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'multi_gpu': False, 'n_episode': 8, 'num_segments': 8, 'num_simulations': 50, 'num_unroll_steps': 5, 'optim_type': 'SGD', 'piecewise_decay_lr_scheduler': True, 'policy_entropy_weight': 0, 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'reanalyze_noise': True, 'replay_ratio': 0.25, 'reuse_search': False, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sampled_algo': False, 'ssl_loss_weight': 0, 'target_update_freq': 100, 'target_update_freq_for_intrinsic_reward': 1000, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 100000, 'transform2string': False, 'update_per_collect': None, 'use_augmentation': False, 'use_priority': False, 'use_rnd_model': False, 'use_ture_chance_label_in_chance_encoder': False, 'use_wandb': False, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
- classmethod default_config() EasyDict
- Overview:
Get the default config of policy. This method is used to create the default config of policy.
- Returns:
The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.
- Return type:
cfg (EasyDict)
Tip
This method deep-copies the config attribute of the class and returns the result, so users can modify the returned config without affecting the class default.
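The deep-copy guarantee can be sketched with a hypothetical ToyPolicy whose class-level config is a plain dict (the real method returns an EasyDict). Editing a nested field of the returned config leaves the class default untouched.

```python
import copy


class ToyPolicy:
    # Hypothetical class-level default, mirroring the config attribute above.
    config = {"learning_rate": 0.2, "model": {"num_res_blocks": 1}}

    @classmethod
    def default_config(cls) -> dict:
        # Return an independent copy so callers can edit it freely.
        return copy.deepcopy(cls.config)


cfg = ToyPolicy.default_config()
cfg["model"]["num_res_blocks"] = 2
print(ToyPolicy.config["model"]["num_res_blocks"])  # 1
```

A shallow copy would not be enough here: the nested model dict would still be shared with the class.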
- default_model() Tuple[str, List[str]] [source]
- Overview:
Return the default model setting of this algorithm for demonstration.
- Returns:
- model_info: The model name and the model import_names.
model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.
import_names (List[str]): The model class path list used in this algorithm.
- Return type:
model_info (Tuple[str, List[str]])
Note
The user can define and use a customized network model but must follow the same interface definition indicated by the import_names path. For MuZero, this is lzero.model.muzero_model.MuZeroModel.
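One way to picture how a (model_type, import_names) pair is consumed: import every module in import_names (registering model classes as a side effect), then look model_type up in a registry. The sketch below imitates that flow with a hypothetical in-process registry, ToyModel and create_model; it does not use DI-engine's actual registry.

```python
import importlib
from typing import Callable, Dict, List, Tuple

# Hypothetical registry: model_type string -> model class.
REGISTRY: Dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    def decorator(cls: type) -> type:
        REGISTRY[name] = cls
        return cls
    return decorator


@register("ToyModel")
class ToyModel:
    def __init__(self, num_channels: int = 64) -> None:
        self.num_channels = num_channels


def default_model() -> Tuple[str, List[str]]:
    # Module paths listed here would be imported to trigger registration.
    return "ToyModel", []


def create_model(model_kwargs: dict) -> object:
    model_type, import_names = default_model()
    for module_path in import_names:
        importlib.import_module(module_path)
    if model_type not in REGISTRY:
        raise KeyError(f"model_type {model_type!r} is not registered")
    return REGISTRY[model_type](**model_kwargs)


model = create_model({"num_channels": 32})
print(type(model).__name__, model.num_channels)  # ToyModel 32
```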
- class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new eval_function object from a sequence or iterable
- _replace(**kwds)
Return a new eval_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 5
- reset
Alias for field number 1
- set_attribute
Alias for field number 3
- state_dict
Alias for field number 4
- property eval_mode: eval_function
- Overview:
Return the interfaces of eval mode of policy, which is used to evaluate the policy. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.
- Returns:
The interfaces of eval mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (Policy.eval_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new learn_function object from a sequence or iterable
- _replace(**kwds)
Return a new learn_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- info
Alias for field number 2
- load_state_dict
Alias for field number 7
- monitor_vars
Alias for field number 3
- reset
Alias for field number 1
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property learn_mode: learn_function
- Overview:
Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own learn mode.
- Returns:
The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (Policy.learn_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
- set_train_iter_env_step(train_iter, env_step) None [source]
- Overview:
Set the train_iter and env_step for the policy.
- Parameters:
train_iter (-) – The train_iter for the policy.
env_step (-) – The env_step for the policy.
- sync_gradients(model: Module) None
- Overview:
Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.
- Parameters:
model (-) – The model to synchronize gradients.
Note
This method is only used in multi-gpu training, and it should be called after the backward method and before the step method. The user can also use the bp_update_sync config to control whether gradient allreduce and optimizer updates are synchronized.
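Conceptually, the allreduce sums each parameter's gradient across all workers and divides by the number of workers, so every replica applies the same averaged update. The pure-Python sketch below simulates this with lists instead of torch.distributed; the two-worker layout and the plain SGD step are assumptions for illustration.

```python
from typing import List


def allreduce_mean(grads_per_worker: List[List[float]]) -> List[float]:
    """Average each gradient entry across workers (what allreduce + divide achieves)."""
    world_size = len(grads_per_worker)
    return [sum(column) / world_size for column in zip(*grads_per_worker)]


def sgd_step(params: List[float], grads: List[float], lr: float) -> List[float]:
    """Plain SGD update applied identically on every replica."""
    return [p - lr * g for p, g in zip(params, grads)]


# Each worker computed local gradients in its own backward pass.
local_grads = [
    [0.5, -1.0],  # worker 0
    [1.5, 1.0],   # worker 1
]
synced = allreduce_mean(local_grads)           # sync after backward ...
params = sgd_step([1.0, 1.0], synced, lr=0.1)  # ... and before the optimizer step
print(synced, params)  # [1.0, 0.0] [0.9, 1.0]
```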
- total_field = {'collect', 'eval', 'learn'}
EfficientZeroPolicy
- class lzero.policy.efficientzero.EfficientZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Bases:
MuZeroPolicy
- Overview:
The policy class for EfficientZero proposed in the paper https://arxiv.org/abs/2111.00210.
- __init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
- Overview:
Initialize the policy instance according to the input configuration and model. This method initializes different fields in the policy, including learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. The enable_field argument specifies which fields to initialize; if it is None, all fields are initialized.
- Parameters:
cfg (-) – The final merged config used to initialize the policy. For the default config, see the config attribute of the policy class and its comments.
model (-) – The neural network model used to initialize the policy. If it is None, the model will be created according to the default_model method and the cfg.model field. Otherwise, the model will be set to the model instance created by the outside caller.
enable_field (-) – The field list to initialize. If it is None, all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which helps save resources.
Note
For the derived policy class, it should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.
- _abc_impl = <_abc._abc_data object>
- _create_model(cfg: EasyDict, model: Module | None = None) Module
- Overview:
Create or validate the neural network model according to the input configuration and model. If the input model is None, the model will be created according to the default_model method and the cfg.model field. Otherwise, the model will be verified as an instance of torch.nn.Module and set to the model instance created by the outside caller.
- Parameters:
cfg (-) – The final merged config used to initialize policy.
model (-) – The neural network model used to initialize policy. User can refer to the default model defined in the corresponding policy to customize its own model.
- Returns:
The created neural network model. The different modes of policy will add distinct wrappers and plugins to the model, which is used to train, collect and evaluate.
- Return type:
model (torch.nn.Module)
- Raises:
- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.
- _forward_collect(data: Tensor, action_mask: list = None, temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, ready_env_id: array = None)[source]
- Overview:
The forward function for collecting data in collect mode. It uses the model to run MCTS and chooses the action by sampling.
- Parameters:
data (-) – The input data, i.e. the observation.
action_mask (-) – The action mask indicating which actions are legal (1) and which cannot be selected (0).
temperature (-) – The temperature of the policy.
to_play (-) – The player to play.
ready_env_id (-) – The id of the env that is ready to collect.
- Shape:
- data (torch.Tensor): For Atari, \((N, C*S, H, W)\), where N is the number of collect_env, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image. For lunarlander, \((N, O)\), where N is the number of collect_env, O is the observation space size.
- action_mask: \((N, action_space_size)\), where N is the number of collect_env.
- temperature: \((1, )\).
- to_play: \((N, 1)\), where N is the number of collect_env.
- ready_env_id: None
- Returns:
Dict type data whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.
- Return type:
output (Dict[int, Any])
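A common way to sample in collect mode is to draw the action from the root visit-count distribution reshaped by the temperature, p(a) proportional to N(a)^(1/T). The sketch below assumes that rule (standard AlphaZero/MuZero practice) and uses Python's random module; it is not LightZero's own action-selection helper.

```python
import random
from typing import List, Optional


def visit_count_probs(visit_counts: List[int], temperature: float) -> List[float]:
    """p(a) proportional to N(a) ** (1 / T); lower T concentrates on the most visited."""
    scaled = [n ** (1.0 / temperature) for n in visit_counts]
    total = sum(scaled)
    return [s / total for s in scaled]


def sample_collect_action(visit_counts: List[int], temperature: float,
                          rng: Optional[random.Random] = None) -> int:
    """Sample an index into visit_counts according to the tempered distribution."""
    rng = rng or random.Random()
    probs = visit_count_probs(visit_counts, temperature)
    return rng.choices(range(len(visit_counts)), weights=probs, k=1)[0]


print(visit_count_probs([1, 3], temperature=1.0))  # [0.25, 0.75]
print(visit_count_probs([1, 3], temperature=0.5))  # roughly [0.1, 0.9]
action = sample_collect_action([10, 30, 10], temperature=1.0, rng=random.Random(0))
```

Temperatures above 1 flatten the distribution (more exploration); temperatures below 1 sharpen it toward the most visited action.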
- _forward_eval(data: Tensor, action_mask: list, to_play: -1, ready_env_id: array = None)[source]
- Overview:
The forward function for evaluating the current policy in eval mode. It uses the model to run MCTS and chooses the action deterministically (the argmax of the root visit counts) rather than by sampling.
- Parameters:
data (-) – The input data, i.e. the observation.
action_mask (-) – The action mask indicating which actions are legal (1) and which cannot be selected (0).
to_play (-) – The player to play.
ready_env_id (-) – The id of the env that is ready to evaluate.
- Shape:
- data (torch.Tensor): For Atari, \((N, C*S, H, W)\), where N is the number of evaluator_env, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image. For lunarlander, \((N, O)\), where N is the number of evaluator_env, O is the observation space size.
- action_mask: \((N, action_space_size)\), where N is the number of evaluator_env.
- to_play: \((N, 1)\), where N is the number of evaluator_env.
- ready_env_id: None
- Returns:
Dict type data whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.
- Return type:
output (Dict[int, Any])
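Deterministic selection in eval mode can be sketched as follows, assuming the search expands only legal actions so that the visit counts are indexed within the legal-action set; the chosen index is then mapped back to the full action space through the action mask (1 = legal). The function name is hypothetical.

```python
from typing import List


def select_eval_action(visit_counts: List[int], action_mask: List[int]) -> int:
    """Pick the most visited legal action and map it back to the full action space."""
    legal_actions = [a for a, legal in enumerate(action_mask) if legal == 1]
    if len(legal_actions) != len(visit_counts):
        raise ValueError("expected one visit count per legal action")
    # Ties resolve to the first maximal index.
    best_index = max(range(len(visit_counts)), key=visit_counts.__getitem__)
    return legal_actions[best_index]


# Actions 1, 2 and 4 are legal; their visit counts are 5, 20 and 10.
print(select_eval_action([5, 20, 10], action_mask=[0, 1, 1, 0, 1]))  # 2
```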
- _forward_learn(data: Tensor) Dict[str, float | int] [source]
- Overview:
The forward function for learning policy in learn mode, which is the core of the learning process. The data is sampled from replay buffer. The loss is calculated by the loss function and the loss is backpropagated to update the model.
- Parameters:
data (-) – The data sampled from replay buffer, which is a tuple of tensors. The first tensor is the current_batch, the second tensor is the target_batch.
- Returns:
The information dict to be logged, which contains current learning loss and learning statistics.
- Return type:
info_dict (Dict[str, Union[float, int]])
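A simplified view of how the learning loss is assembled: the total loss is a weighted sum of per-head losses, with weights taken from the policy_loss_weight, value_loss_weight, reward_loss_weight and ssl_loss_weight entries of this class's default config. The sketch uses scalar losses; the real method computes them over unrolled steps and batches (EfficientZero predicts a value prefix rather than a per-step reward), and those details are not reproduced here.

```python
from typing import Dict

# Loss weights copied from the EfficientZeroPolicy default config on this page.
LOSS_WEIGHTS = {"policy": 1.0, "value": 0.25, "reward": 1.0, "ssl": 2.0}


def weighted_total_loss(losses: Dict[str, float],
                        weights: Dict[str, float] = LOSS_WEIGHTS) -> float:
    """Combine per-head losses into the single scalar that is backpropagated."""
    return sum(weights[name] * value for name, value in losses.items())


total = weighted_total_loss({"policy": 1.0, "value": 2.0, "reward": 0.5, "ssl": 0.1})
print(round(total, 6))  # 2.2
```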
- _get_attribute(name: str) Any
- Overview:
To control access to the policy attributes, we expose different modes to the outside rather than the policy instance itself, and we also provide a method to get an attribute of the policy in each mode.
- Parameters:
name (-) – The name of the attribute.
- Returns:
The value of the attribute.
- Return type:
value (Any)
Note
DI-engine’s policy will first try to access the _get_{name} method, and then the _{name} attribute. If neither is found, it raises a NotImplementedError.
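The lookup order in the note can be sketched with a hypothetical ToyPolicy: a _get_{name} method takes precedence, an _{name} attribute is the fallback, and anything else raises NotImplementedError. The attribute values are illustrative.

```python
from typing import Any


class ToyPolicy:
    """Hypothetical policy illustrating the attribute lookup order."""

    def __init__(self) -> None:
        self._batch_size = 256

    def _get_n_episode(self) -> int:
        return 8

    def _get_attribute(self, name: str) -> Any:
        # 1) prefer a _get_{name} method, 2) fall back to an _{name} attribute.
        getter = getattr(self, "_get_" + name, None)
        if callable(getter):
            return getter()
        if hasattr(self, "_" + name):
            return getattr(self, "_" + name)
        raise NotImplementedError(name)


policy = ToyPolicy()
print(policy._get_attribute("n_episode"), policy._get_attribute("batch_size"))  # 8 256
```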
- _get_batch_size() int | Dict[str, int]
- _get_n_episode() int | None
- _get_n_sample() int | None
- _get_target_obs_index_in_step_k(step)
- Overview:
Get the begin index and end index of the target obs in step k.
- Parameters:
step (-) – The current step k.
- Returns:
- beg_index (int): The begin index of the target obs in step k.
- end_index (int): The end index of the target obs in step k.
- Return type:
Tuple[int, int]
Examples
>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
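The example above is consistent with a simple rule for conv models: the stacked observation holds image_channel channels per frame, so the target window for step k starts k frames in and spans frame_stack_num frames. The function below sketches that arithmetic under this assumption; it takes the config values as plain arguments instead of reading self._cfg.

```python
from typing import Tuple


def target_obs_index_in_step_k(step: int, image_channel: int, frame_stack_num: int) -> Tuple[int, int]:
    """Channel slice [beg, end) of the stacked observation used as the target at step k."""
    beg_index = image_channel * step
    end_index = image_channel * (step + frame_stack_num)
    return beg_index, end_index


print(target_obs_index_in_step_k(0, image_channel=3, frame_stack_num=4))  # (0, 12)
print(target_obs_index_in_step_k(1, image_channel=3, frame_stack_num=4))  # (3, 15)
```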
- _get_train_sample(data)[source]
- Overview:
For given trajectory data (transitions, a list of transitions), process it into a list of samples that can be used for training directly. A train sample can be a processed transition (DQN with n-step TD) or several multi-timestep transitions (DRQN). This method is usually used in collectors to execute the necessary RL data preprocessing before training, which helps the learner amortize the relevant time consumption. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.
- Parameters:
transitions (-) – The trajectory data (a list of transitions), each element having the same format as the return value of the self._process_transition method.
- Returns:
The processed train samples, each element having a similar format to the input transitions but possibly containing more data for training, such as n-step reward, advantage, etc.
- Return type:
samples (List[Dict[str, Any]])
Note
We will vectorize the process_transition and get_train_sample methods in a future release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
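As an example of the n-step reward preprocessing mentioned above (the DQN-style case, not the MCTS-specific processing this policy performs), the sketch below turns one episode into samples carrying an n-step discounted reward sum, truncated at the end of the trajectory and without a bootstrapped value term. The function name and dict keys are hypothetical.

```python
from typing import Any, Dict, List


def nstep_train_samples(transitions: List[Dict[str, Any]], nstep: int, gamma: float) -> List[Dict[str, Any]]:
    """Attach an n-step discounted reward sum to every transition (no value bootstrap)."""
    samples = []
    for t in range(len(transitions)):
        window = transitions[t:t + nstep]  # shorter near the end of the trajectory
        nstep_reward = sum(gamma ** i * tr["reward"] for i, tr in enumerate(window))
        samples.append({**transitions[t], "nstep_reward": nstep_reward})
    return samples


trajectory = [{"obs": i, "reward": 1.0} for i in range(3)]
samples = nstep_train_samples(trajectory, nstep=2, gamma=0.5)
print([s["nstep_reward"] for s in samples])  # [1.5, 1.5, 1.0]
```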
- _init_collect() None [source]
- Overview:
Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.
- _init_eval() None [source]
- Overview:
Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.
- _init_learn() None [source]
- Overview:
Learn mode init method. Called by self.__init__. Initialize the learn model, optimizer and MCTS utils.
- _init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
- Overview:
Initialize the multi-gpu data parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function that allreduces the gradients of the model parameters.
- Parameters:
model (-) – The neural network model to be trained.
bp_update_sync (-) – Whether to update the model parameters synchronously after allreducing their gradients. Asynchronous updates can run in parallel across different network layers, like a pipeline, which can save time.
- _load_state_dict_collect(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy collect mode, such as loading a pretrained state_dict, auto-recovering from a checkpoint, or receiving a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy collect state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument in load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_eval(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy eval mode, such as auto-recovering from a checkpoint or receiving a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy eval state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument in load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_learn(state_dict: Dict[str, Any]) None [source]
- Overview:
Load the state_dict variable into policy learn mode.
- Parameters:
state_dict (-) – The dict of policy learn state saved before.
- _monitor_vars_learn() List[str] [source]
- Overview:
Register the variables to be monitored in learn mode. The registered variables will be logged in tensorboard according to the return value of _forward_learn.
- _process_transition(obs, policy_output, timestep)[source]
- Overview:
Process and pack one timestep of transition data into a dict, such as <s, a, r, s’, done>. Some policies need special processing to pack their own necessary attributes (e.g. hidden state and logit), so this method is left to be implemented by the subclass.
- Parameters:
obs (-) – The observation of the current timestep.
policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.
timestep (-) – The execution result namedtuple returned by the environment step method, except all the elements have been transformed into tensor data. Usually, it contains the next obs, reward, done, info, etc.
- Returns:
The processed transition data of the current timestep.
- Return type:
transition (Dict[str, torch.Tensor])
- _reset_collect(data_id: List[int] | None = None) None
- Overview:
Reset the observation and action for the collector environment.
- Parameters:
data_id (-) – List of data ids to reset (not used in this implementation).
- _reset_eval(data_id: List[int] | None = None) None
- Overview:
Reset the observation and action for the evaluator environment.
- Parameters:
data_id (-) – List of data ids to reset (not used in this implementation).
- _reset_learn(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, the stateful variables are reset according to data_id. For example, different trajectories in data_id will have different hidden states in an RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.
Note
This method is not mandatory to implement. A subclass can override this method if necessary.
- _set_attribute(name: str, value: Any) None
- Overview:
To control access to the policy attributes, we expose different modes to the outside rather than the policy instance itself, and we also provide a method to set an attribute of the policy in each mode. The new attribute will be named _{name}.
- Parameters:
name (-) – The name of the attribute.
value (-) – The value of the attribute.
- _state_dict_collect() Dict[str, Any]
- Overview:
Return the state_dict of collect mode, usually including only the model, which is necessary for distributed training scenarios to auto-recover collectors.
- Returns:
The dict of current policy collect state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not every scenario needs to auto-recover collectors; sometimes we can simply shut down the crashed collector and start a new one.
- _state_dict_eval() Dict[str, Any]
- Overview:
Return the state_dict of eval mode, usually including only the model, which is necessary for distributed training scenarios to auto-recover evaluators.
- Returns:
The dict of current policy eval state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not every scenario needs to auto-recover evaluators; sometimes we can simply shut down the crashed evaluator and start a new one.
- _state_dict_learn() Dict[str, Any] [source]
- Overview:
Return the state_dict of learn mode, usually including model and optimizer.
- Returns:
The dict of current policy learn state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
- property cfg: EasyDict
- class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new collect_function object from a sequence or iterable
- _replace(**kwds)
Return a new collect_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- get_train_sample
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 7
- process_transition
Alias for field number 1
- reset
Alias for field number 3
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property collect_mode: collect_function
- Overview:
Return the interfaces of collect mode of policy, which is used to collect data for training. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.
- Returns:
The interfaces of collect mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (Policy.collect_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- config = {'action_type': 'fixed_action_space', 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'collect_with_pure_policy': False, 'collector_env_num': 8, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_offline': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 200, 'grad_clip_value': 10, 'gray_scale': False, 'gumbel_algo': False, 'ignore_done': False, 'learning_rate': 0.2, 'lstm_horizon_len': 5, 'manual_temperature_decay': False, 'mcts_ctree': True, 'model': {'bias': True, 'categorical_distribution': True, 'continuous_action_space': False, 'discrete_action_encoding_type': 'one_hot', 'frame_stack_num': 1, 'image_channel': 1, 'lstm_hidden_size': 512, 'model_type': 'conv', 'norm_type': 'BN', 'observation_shape': (4, 96, 96), 'res_connection_in_dynamics': True, 'self_supervised_learning_loss': True, 'support_scale': 300}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'multi_gpu': False, 'n_episode': 8, 'num_simulations': 50, 'num_unroll_steps': 5, 'optim_type': 'SGD', 'piecewise_decay_lr_scheduler': True, 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'reanalyze_noise': True, 'replay_ratio': 0.25, 'reuse_search': False, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sampled_algo': False, 'ssl_loss_weight': 2, 'target_update_freq': 100, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 100000, 'transform2string': False, 'update_per_collect': None, 'use_augmentation': False, 'use_priority': False, 'use_ture_chance_label_in_chance_encoder': False, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
- classmethod default_config() EasyDict
- Overview:
Get the default config of policy. This method is used to create the default config of policy.
- Returns:
The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.
- Return type:
cfg (EasyDict)
Tip
This method deep-copies the config attribute of the class and returns the result, so users can modify the returned config without affecting the class default.
- default_model() Tuple[str, List[str]] [source]
- Overview:
Return the default model setting of this algorithm.
- Returns:
- model_info: The model name and the model import_names.
model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.
import_names (List[str]): The model class path list used in this algorithm.
- Return type:
model_info (Tuple[str, List[str]])
Note
The user can define and use a customized network model but must follow the same interface definition indicated by the import_names path. For EfficientZero, this is lzero.model.efficientzero_model.EfficientZeroModel.
- class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new eval_function object from a sequence or iterable
- _replace(**kwds)
Return a new eval_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 5
- reset
Alias for field number 1
- set_attribute
Alias for field number 3
- state_dict
Alias for field number 4
- property eval_mode: eval_function
- Overview:
Return the interfaces of eval mode of policy, which are used to evaluate the model. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.
- Returns:
The interfaces of eval mode of policy: a namedtuple whose fields hold different internal methods.
- Return type:
interfaces (Policy.eval_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new learn_function object from a sequence or iterable
- _replace(**kwds)
Return a new learn_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- info
Alias for field number 2
- load_state_dict
Alias for field number 7
- monitor_vars
Alias for field number 3
- reset
Alias for field number 1
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property learn_mode: learn_function
- Overview:
Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own learn mode.
- Returns:
The interfaces of learn mode of policy: a namedtuple whose fields hold different internal methods.
- Return type:
interfaces (Policy.learn_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
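The namedtuple matters because its fields are read-only. A caller holding `policy.learn_mode` can call the exposed methods but cannot rebind them. The minimal sketch below assumes a hypothetical `TinyPolicy` with only two learn-mode entry points. It is not DI-engine's `Policy` class.

```python
from collections import namedtuple


class TinyPolicy:
    """Hypothetical policy exposing a restricted learn-mode interface as a namedtuple."""
    learn_function = namedtuple('learn_function', ['forward', 'state_dict'])

    def __init__(self) -> None:
        self._step = 0

    def _forward_learn(self, data):
        self._step += 1
        return {'total_loss': 0.0, 'step': self._step}

    def _state_dict_learn(self):
        return {'step': self._step}

    @property
    def learn_mode(self):
        # Each field is a bound method; callers only see these entry points.
        return TinyPolicy.learn_function(self._forward_learn, self._state_dict_learn)


learn = TinyPolicy().learn_mode
learn.forward(None)
print(learn.state_dict())  # {'step': 1}
```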
- set_train_iter_env_step(train_iter, env_step) None
- Overview:
Set the train_iter and env_step for the policy.
- Parameters:
train_iter – The training iteration count for the policy.
env_step – The environment step count for the policy.
- sync_gradients(model: Module) None
- Overview:
Synchronize (allreduce) the gradients of model parameters in data-parallel multi-GPU training.
- Parameters:
model – The model whose gradients are synchronized.
Note
This method is only used in multi-GPU training. It should be called after the `backward` method and before the optimizer's `step` method. Users can also use the `bp_update_sync` config to control whether the gradient allreduce and the optimizer updates are synchronized.
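Conceptually, the allreduce in `sync_gradients` leaves every data-parallel worker with the same reduced gradient before the optimizer step. The stdlib sketch below simulates this with plain lists instead of `torch.distributed`. It assumes the reduction averages over the world size, which is a common convention; the exact reduction used by DI-engine is not stated here. `allreduce_mean` is a hypothetical helper.

```python
from typing import List


def allreduce_mean(per_worker_grads: List[List[float]]) -> List[List[float]]:
    """Simulate an averaging allreduce: every worker ends up with the element-wise mean."""
    world_size = len(per_worker_grads)
    num_params = len(per_worker_grads[0])
    # Sum each parameter's gradient across workers, then divide by the world size.
    mean = [sum(worker[i] for worker in per_worker_grads) / world_size for i in range(num_params)]
    # Every worker receives an identical copy of the averaged gradient.
    return [list(mean) for _ in range(world_size)]


grads = [[1.0, 2.0], [3.0, 6.0]]  # two workers, two scalar parameters each
print(allreduce_mean(grads))  # [[2.0, 4.0], [2.0, 4.0]]
```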
- total_field = {'collect', 'eval', 'learn'}
Gumbel AlphaZeroPolicy
- class lzero.policy.gumbel_alphazero.GumbelAlphaZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Bases:
Policy
- Overview:
The policy class for GumbelAlphaZero.
- __init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
- Overview:
Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy, including `learn`, `collect` and `eval`:
- The `learn` field is used to train the policy.
- The `collect` field is used to collect data for training.
- The `eval` field is used to evaluate the policy.
`enable_field` specifies which fields to initialize. If it is None, all fields are initialized.
- Parameters:
cfg – The final merged config used to initialize the policy. For the default config, see the `config` attribute of the policy class and its comments.
model – The neural network model used to initialize the policy. If it is None, the model is created according to the `default_model` method and the `cfg.model` field. Otherwise, the model is set to the `model` instance created by the outside caller.
enable_field – The list of fields to initialize. If it is None, all fields are initialized. Otherwise, only the fields in `enable_field` are initialized, which helps save resources.
Note
A derived policy class should implement the `_init_learn`, `_init_collect` and `_init_eval` methods to initialize the corresponding fields.
- _abc_impl = <_abc._abc_data object>
- _create_model(cfg: EasyDict, model: Module | None = None) Module
- Overview:
Create or validate the neural network model according to the input config and model. If the input model is None, the model is created according to the `default_model` method and the `cfg.model` field. Otherwise, the model is verified to be an instance of `torch.nn.Module` and set to the `model` instance created by the outside caller.
- Parameters:
cfg – The final merged config used to initialize the policy.
model – The neural network model used to initialize the policy. Users can refer to the default model defined in the corresponding policy to customize their own model.
- Returns:
The created neural network model. The different modes of the policy add distinct wrappers and plugins to the model, which are used to train, collect and evaluate.
- Return type:
model (torch.nn.Module)
- Raises:
- RuntimeError – If the input model is not None and is not an instance of `torch.nn.Module`.
- _forward_collect(obs: Dict, temperature: float = 1) Dict[str, Tensor] [source]
- Overview:
The forward function for collecting data in collect mode. It uses the real environment to execute the MCTS search.
- Parameters:
obs – A dict of observations. The key is env_id and the value is the corresponding observation at this timestep.
temperature – The temperature for the MCTS search.
- Returns:
A dict of outputs. The key is env_id and the value is the corresponding policy output at this timestep, including action, probs and so on.
- Return type:
output (Dict[str, torch.Tensor])
- _forward_eval(obs: Dict) Dict[str, Tensor] [source]
- Overview:
The forward function for evaluating the current policy in eval mode, similar to `self._forward_collect`.
- Parameters:
obs – A dict of observations. The key is env_id and the value is the corresponding observation at this timestep.
- Returns:
A dict of outputs. The key is env_id and the value is the corresponding policy output at this timestep, including action, probs and so on.
- Return type:
output (Dict[str, torch.Tensor])
- _forward_learn(inputs: Dict[str, Tensor]) Dict[str, float] [source]
- Overview:
Policy forward function of learn mode (training the policy and updating its parameters). Here, "forward" means that the policy takes a batch of training data from the replay buffer and returns the output result. The output includes various training information such as loss value, policy entropy, q value and priority. This method is left to be implemented by the subclass, and more arguments can be added in the `data` item if necessary.
- Parameters:
data – The input data used for policy forward, including a batch of training samples. For each element in the list, the key of the dict is the name of the data item and the value is the corresponding data. Usually, in the `_forward_learn` method, data should be stacked in the batch dimension by utility functions such as `default_preprocess_learn`.
- Returns:
The training information of policy forward. This includes metrics for monitoring training, such as loss, priority, q value and policy entropy, and data for the next training step, such as priority. Note that the output data items should be Python native scalars rather than PyTorch tensors, which is convenient for the outside to use.
- Return type:
output (Dict[str, float])
- _get_attribute(name: str) Any
- Overview:
To control access to the policy attributes, the different modes are exposed to the outside rather than the policy instance itself. This method gets an attribute of the policy in the different modes.
- Parameters:
name – The name of the attribute.
- Returns:
The value of the attribute.
- Return type:
value (Any)
Note
DI-engine’s policy first tries to access the `_get_{name}` method, and then the `_{name}` attribute. If neither is found, it raises a `NotImplementedError`.
- _get_batch_size() int | Dict[str, int]
- _get_n_episode() int | None
- _get_n_sample() int | None
- _get_train_sample(data)[source]
- Overview:
For given trajectory data (a list of transitions), process it into a list of samples that can be used for training directly. A train sample can be a processed transition (DQN with n-step TD) or several multi-timestep transitions (DRQN). This method is usually used in collectors to execute the necessary RL data preprocessing before training, which can help the learner amortize the relevant time consumption. Alternatively, you can implement this method as an identity function and do the data processing in the `self._forward_learn` method.
- Parameters:
transitions – The trajectory data (a list of transitions). Each element has the same format as the return value of the `self._process_transition` method.
- Returns:
The processed train samples. Each element has a similar format to the input transitions, but may contain more data for training, such as n-step reward, advantage, etc.
- Return type:
samples (List[Dict[str, Any]])
Note
We will vectorize the `process_transition` and `get_train_sample` methods in a following release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
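The "processed transition (DQN with n-step TD)" case mentioned above can be illustrated by attaching an n-step discounted reward to each transition. The stdlib sketch below rests on several assumptions: the helper `add_nstep_return`, the `'reward'`/`'done'` keys, and the truncation at episode end are all hypothetical, and no value bootstrap is included. AlphaZero-style policies in this library may simply return the data unchanged.

```python
from typing import Any, Dict, List


def add_nstep_return(transitions: List[Dict[str, Any]], n: int, gamma: float) -> List[Dict[str, Any]]:
    """Return new samples with an 'nstep_return' field (hypothetical format, no bootstrap)."""
    samples = []
    for t in range(len(transitions)):
        ret, discount = 0.0, 1.0
        # Accumulate up to n discounted rewards, stopping at the end of the episode.
        for k in range(t, min(t + n, len(transitions))):
            ret += discount * transitions[k]['reward']
            discount *= gamma
            if transitions[k]['done']:
                break
        sample = dict(transitions[t])  # shallow copy keeps the original transition intact
        sample['nstep_return'] = ret
        samples.append(sample)
    return samples


traj = [{'reward': 1.0, 'done': False}, {'reward': 1.0, 'done': False}, {'reward': 1.0, 'done': True}]
print([s['nstep_return'] for s in add_nstep_return(traj, n=2, gamma=0.5)])  # [1.5, 1.5, 1.0]
```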
- _init_collect() None [source]
- Overview:
Collect mode init method. Called by `self.__init__`. Initializes the collect model and MCTS utils.
- _init_eval() None [source]
- Overview:
Evaluate mode init method. Called by `self.__init__`. Initializes the eval model and MCTS utils.
- _init_learn() None [source]
- Overview:
Initialize the learn mode of the policy, including related attributes and modules. This method is called in the `__init__` method if the `learn` field is in `enable_field`. Almost every policy has its own learn mode, so this method must be overridden in subclasses.
Note
For the member variables that need to be saved and loaded, please refer to the `_state_dict_learn` and `_load_state_dict_learn` methods.
Note
For the member variables that need to be monitored, please refer to the `_monitor_vars_learn` method.
Note
If you want to set some special member variables in the `_init_learn` method, you’d better name them with the prefix `_learn_` to avoid conflicts with other modes, such as `self._learn_attr1`.
- _init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
- Overview:
Initialize the multi-GPU data-parallel training setting. This includes broadcasting the model parameters at the beginning of training and preparing the hook function that allreduces the gradients of the model parameters.
- Parameters:
model – The neural network model to be trained.
bp_update_sync – Whether to update the model parameters synchronously after allreducing the gradients. An asynchronous update can run in parallel across different network layers, like a pipeline, which can save time.
- _load_state_dict_collect(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into the policy collect mode. Typical uses are loading a pretrained state_dict, auto-recovering from a checkpoint, or receiving a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict – The dict of policy collect state saved before.
Tip
If you only want to load some parts of the model, you can simply set the `strict` argument of `load_state_dict` to `False`, or refer to `ding.torch_utils.checkpoint_helper` for more complicated operations.
- _load_state_dict_eval(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into the policy eval mode. Typical uses are auto-recovering from a checkpoint or receiving a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict – The dict of policy eval state saved before.
Tip
If you only want to load some parts of the model, you can simply set the `strict` argument of `load_state_dict` to `False`, or refer to `ding.torch_utils.checkpoint_helper` for more complicated operations.
- _load_state_dict_learn(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into the policy learn mode.
- Parameters:
state_dict – The dict of policy learn state saved before.
Tip
If you only want to load some parts of the model, you can simply set the `strict` argument of `load_state_dict` to `False`, or refer to `ding.torch_utils.checkpoint_helper` for more complicated operations.
- _monitor_vars_learn() List[str] [source]
- Overview:
Register the variables to be monitored in learn mode. The registered variables are logged in TensorBoard according to the return value of `_forward_learn`.
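The sketch below shows how registered names might be used to pick values out of the `_forward_learn` output before logging. Both `select_monitored` and the variable names are hypothetical, and the actual logging path in DI-engine may differ.

```python
from typing import Any, Dict, List


def select_monitored(learn_output: Dict[str, Any], monitor_vars: List[str]) -> Dict[str, float]:
    """Keep only the registered variables; missing keys are skipped rather than raising."""
    return {name: float(learn_output[name]) for name in monitor_vars if name in learn_output}


monitor_vars = ['total_loss', 'cur_lr']  # hypothetical return value of _monitor_vars_learn()
learn_output = {'total_loss': 0.73, 'cur_lr': 0.2, 'priority': [1.0, 0.5]}  # hypothetical _forward_learn output
print(select_monitored(learn_output, monitor_vars))  # {'total_loss': 0.73, 'cur_lr': 0.2}
```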
- _process_transition(obs: Dict, model_output: Dict[str, Tensor], timestep: namedtuple) Dict [source]
- Overview:
Generate the dict-type transition (one timestep) data used for policy learning.
- _reset_collect(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for collect mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If `data_id` is None, all stateful variables are reset. Otherwise, only the stateful variables specified by `data_id` are reset. For example, the different environments/episodes being collected in `data_id` have different RNN hidden states.
- Parameters:
data_id – The ids of the data, which specify the stateful variables to reset.
Note
This method is not mandatory to implement. Subclasses can override it if necessary.
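The `data_id` semantics above work as follows: None resets everything, otherwise only the listed ids are reset. This can be sketched with a per-environment state table. `HiddenStateTable` is hypothetical and stores a float per env in place of a real RNN hidden state.

```python
from typing import Dict, List, Optional


class HiddenStateTable:
    """Hypothetical per-environment state store, e.g. RNN hidden states keyed by env id."""

    def __init__(self, env_ids: List[int], init_value: float = 0.0) -> None:
        self._init_value = init_value
        self.state: Dict[int, float] = {i: init_value for i in env_ids}

    def reset(self, data_id: Optional[List[int]] = None) -> None:
        # None means "reset everything"; otherwise only the listed env ids are reset.
        ids = list(self.state) if data_id is None else data_id
        for i in ids:
            self.state[i] = self._init_value


table = HiddenStateTable(env_ids=[0, 1, 2])
table.state.update({0: 5.0, 1: 7.0, 2: 9.0})  # pretend these were updated during collection
table.reset([1])    # only env 1 finished its episode
print(table.state)  # {0: 5.0, 1: 0.0, 2: 9.0}
table.reset()       # reset all environments
print(table.state)  # {0: 0.0, 1: 0.0, 2: 0.0}
```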
- _reset_eval(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for eval mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If `data_id` is None, all stateful variables are reset. Otherwise, only the stateful variables specified by `data_id` are reset. For example, the different environments/episodes being evaluated in `data_id` have different RNN hidden states.
- Parameters:
data_id – The ids of the data, which specify the stateful variables to reset.
Note
This method is not mandatory to implement. Subclasses can override it if necessary.
- _reset_learn(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If `data_id` is None, all stateful variables are reset. Otherwise, only the stateful variables specified by `data_id` are reset. For example, the different trajectories in `data_id` have different RNN hidden states.
- Parameters:
data_id – The ids of the data, which specify the stateful variables to reset.
Note
This method is not mandatory to implement. Subclasses can override it if necessary.
- _set_attribute(name: str, value: Any) None
- Overview:
To control access to the policy attributes, the different modes are exposed to the outside rather than the policy instance itself. This method sets an attribute of the policy in the different modes. The new attribute is named `_{name}`.
- Parameters:
name – The name of the attribute.
value – The value of the attribute.
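The lookup order described for `_get_attribute` can be sketched together with the `_{name}` naming used by `_set_attribute`. The order is: the `_get_{name}` method first, then the `_{name}` attribute, otherwise `NotImplementedError`. `AttrPolicySketch` is a hypothetical class, not DI-engine's implementation.

```python
from typing import Any


class AttrPolicySketch:
    """Hypothetical stand-in mirroring the documented _get_attribute/_set_attribute lookup order."""

    def __init__(self) -> None:
        self._batch_size = 256  # plain attribute, found via the `_{name}` fallback

    def _get_n_episode(self) -> int:  # found via the `_get_{name}` method first
        return 8

    def _get_attribute(self, name: str) -> Any:
        getter = getattr(self, '_get_' + name, None)
        if getter is not None:
            return getter()
        if hasattr(self, '_' + name):
            return getattr(self, '_' + name)
        raise NotImplementedError(name)

    def _set_attribute(self, name: str, value: Any) -> None:
        # New attributes are stored with a leading underscore.
        setattr(self, '_' + name, value)


policy = AttrPolicySketch()
policy._set_attribute('learning_rate', 0.2)
print(policy._get_attribute('n_episode'), policy._get_attribute('learning_rate'))  # 8 0.2
```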
- _state_dict_collect() Dict[str, Any]
- Overview:
Return the state_dict of collect mode. It usually only includes the model and is necessary for auto-recovering collectors in distributed training scenarios.
- Returns:
The dict of the current policy collect state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not all scenarios need to auto-recover collectors. Sometimes we can directly shut down the crashed collector and start a new one.
- _state_dict_eval() Dict[str, Any]
- Overview:
Return the state_dict of eval mode. It usually only includes the model and is necessary for auto-recovering evaluators in distributed training scenarios.
- Returns:
The dict of the current policy eval state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not all scenarios need to auto-recover evaluators. Sometimes we can directly shut down the crashed evaluator and start a new one.
- _state_dict_learn() Dict[str, Any]
- Overview:
Return the state_dict of learn mode, usually including model and optimizer.
- Returns:
The dict of the current policy learn state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
- property cfg: EasyDict
- class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new collect_function object from a sequence or iterable
- _replace(**kwds)
Return a new collect_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- get_train_sample
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 7
- process_transition
Alias for field number 1
- reset
Alias for field number 3
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property collect_mode: collect_function
- Overview:
Return the interfaces of collect mode of policy, which are used to collect training data. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.
- Returns:
The interfaces of collect mode of policy: a namedtuple whose fields hold different internal methods.
- Return type:
interfaces (Policy.collect_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- config = {'batch_size': 256, 'collector_env_num': 8, 'cuda': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'grad_clip_value': 10, 'learning_rate': 0.2, 'manual_temperature_decay': False, 'mcts': {'action_space_size': 9, 'continuous_action_space': False, 'gumbel_rng': 0.0, 'gumbel_scale': 10.0, 'legal_actions': None, 'max_moves': 512, 'max_num_considered_actions': 6, 'maxvisit_init': 50, 'num_of_sampled_actions': 2, 'num_simulations': 50, 'pb_c_base': 19652, 'pb_c_init': 1.25, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'value_scale': 0.1}, 'mcts_ctree': True, 'model': {'num_channels': 32, 'num_res_blocks': 1, 'observation_shape': (3, 6, 6)}, 'momentum': 0.9, 'optim_type': 'SGD', 'other': {'replay_buffer': {'replay_buffer_size': 1000000, 'save_episode': False}}, 'piecewise_decay_lr_scheduler': True, 'replay_ratio': 0.25, 'sampled_algo': False, 'tensor_float_32': False, 'threshold_training_steps_for_final_lr': 500000, 'threshold_training_steps_for_final_temperature': 100000, 'torch_compile': False, 'type': 'gumbel_alphazero', 'update_per_collect': None, 'value_weight': 1.0, 'weight_decay': 0.0001}
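Three temperature-related keys in this config work together: `manual_temperature_decay`, `fixed_temperature_value` and `threshold_training_steps_for_final_temperature`. With the defaults above, decay is off and the fixed value 0.25 is used. The sketch below shows one plausible reading. It assumes a step-wise decay (1.0, then 0.5 from 50% of the threshold, then 0.25 from 75%) when manual decay is enabled; treat these breakpoints as an assumption and check the source of your version. `visit_count_temperature` is written here as a standalone hypothetical helper.

```python
def visit_count_temperature(manual_temperature_decay: bool, fixed_temperature_value: float,
                            threshold_training_steps_for_final_temperature: int, trained_steps: int) -> float:
    """Hypothetical temperature schedule driven by the config keys above."""
    if not manual_temperature_decay:
        # Decay disabled: always use the fixed value (0.25 in the default config above).
        return fixed_temperature_value
    # Assumed step-wise decay relative to the threshold.
    if trained_steps < 0.5 * threshold_training_steps_for_final_temperature:
        return 1.0
    if trained_steps < 0.75 * threshold_training_steps_for_final_temperature:
        return 0.5
    return 0.25


print(visit_count_temperature(True, 0.25, 100000, 60000))  # 0.5
```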
- classmethod default_config() EasyDict
- Overview:
Get the default config of policy. This method is used to create the default config of policy.
- Returns:
The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.
- Return type:
cfg (
EasyDict
)
Tip
This method will deepcopy the
config
attribute of the class and return the result. So users don’t need to worry about the modification of the returned config.
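The recursive merge mentioned under Returns can be illustrated with plain dicts. `deep_merge` is a hypothetical helper written for this sketch, not DI-engine's merge utility. It merges nested dicts key by key, lets values from the more specific config win, and leaves the base config unmodified.

```python
import copy
from typing import Any, Dict


def deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict where nested dicts are merged recursively and `override` wins on conflicts."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


base_cfg = {'cuda': False, 'mcts': {'num_simulations': 50, 'max_num_considered_actions': 6}}
child_cfg = {'cuda': True, 'mcts': {'num_simulations': 25}}
print(deep_merge(base_cfg, child_cfg))
# {'cuda': True, 'mcts': {'num_simulations': 25, 'max_num_considered_actions': 6}}
```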
- default_model() Tuple[str, List[str]] [source]
- Overview:
Return the default model setting of this algorithm for demonstration.
- Returns:
- model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.
- import_names (List[str]): The model class path list used in this algorithm.
- Return type:
model_info (Tuple[str, List[str]])
- class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new eval_function object from a sequence or iterable
- _replace(**kwds)
Return a new eval_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 5
- reset
Alias for field number 1
- set_attribute
Alias for field number 3
- state_dict
Alias for field number 4
- property eval_mode: eval_function
- Overview:
Return the interfaces of eval mode of policy, which are used to evaluate the model. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.
- Returns:
The interfaces of eval mode of policy: a namedtuple whose fields hold different internal methods.
- Return type:
interfaces (Policy.eval_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new learn_function object from a sequence or iterable
- _replace(**kwds)
Return a new learn_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- info
Alias for field number 2
- load_state_dict
Alias for field number 7
- monitor_vars
Alias for field number 3
- reset
Alias for field number 1
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property learn_mode: learn_function
- Overview:
Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own learn mode.
- Returns:
The interfaces of learn mode of policy: a namedtuple whose fields hold different internal methods.
- Return type:
interfaces (Policy.learn_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
- sync_gradients(model: Module) None
- Overview:
Synchronize (allreduce) the gradients of model parameters in data-parallel multi-GPU training.
- Parameters:
model – The model whose gradients are synchronized.
Note
This method is only used in multi-GPU training. It should be called after the `backward` method and before the optimizer's `step` method. Users can also use the `bp_update_sync` config to control whether the gradient allreduce and the optimizer updates are synchronized.
- total_field = {'collect', 'eval', 'learn'}
Gumbel MuZeroPolicy
- class lzero.policy.gumbel_muzero.GumbelMuZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Bases:
MuZeroPolicy
- Overview:
The policy class for Gumbel MuZero proposed in the paper https://openreview.net/forum?id=bERaNdoegnO.
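The paper linked above selects the root actions to search without replacement using the Gumbel-Top-k trick: add independent Gumbel(0, 1) noise to the logit of each legal action and keep the k highest. In the Gumbel AlphaZero default config, k presumably corresponds to `max_num_considered_actions`. A minimal stdlib sketch follows. `gumbel_top_k` is a hypothetical helper, and the library's tree code may differ in details such as noise scaling.

```python
import math
import random
from typing import List, Sequence


def gumbel_top_k(logits: Sequence[float], legal: Sequence[int], k: int, rng: random.Random) -> List[int]:
    """Sample k distinct legal actions without replacement via the Gumbel-Top-k trick."""
    scored = []
    for a in legal:
        u = rng.random()
        # Gumbel(0, 1) sample: -log(-log(U)); guard against U == 0.
        g = -math.log(-math.log(max(u, 1e-12)))
        scored.append((logits[a] + g, a))
    scored.sort(reverse=True)  # highest perturbed logit first
    return [a for _, a in scored[:min(k, len(scored))]]


rng = random.Random(0)
print(gumbel_top_k([0.1, 2.0, -1.0, 0.5], legal=[0, 1, 3], k=2, rng=rng))  # two distinct actions from [0, 1, 3]
```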
- __init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
- Overview:
Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy, including `learn`, `collect` and `eval`:
- The `learn` field is used to train the policy.
- The `collect` field is used to collect data for training.
- The `eval` field is used to evaluate the policy.
`enable_field` specifies which fields to initialize. If it is None, all fields are initialized.
- Parameters:
cfg – The final merged config used to initialize the policy. For the default config, see the `config` attribute of the policy class and its comments.
model – The neural network model used to initialize the policy. If it is None, the model is created according to the `default_model` method and the `cfg.model` field. Otherwise, the model is set to the `model` instance created by the outside caller.
enable_field – The list of fields to initialize. If it is None, all fields are initialized. Otherwise, only the fields in `enable_field` are initialized, which helps save resources.
Note
A derived policy class should implement the `_init_learn`, `_init_collect` and `_init_eval` methods to initialize the corresponding fields.
- _abc_impl = <_abc._abc_data object>
- _create_model(cfg: EasyDict, model: Module | None = None) Module
- Overview:
Create or validate the neural network model according to the input config and model. If the input model is None, the model is created according to the `default_model` method and the `cfg.model` field. Otherwise, the model is verified to be an instance of `torch.nn.Module` and set to the `model` instance created by the outside caller.
- Parameters:
cfg – The final merged config used to initialize the policy.
model – The neural network model used to initialize the policy. Users can refer to the default model defined in the corresponding policy to customize their own model.
- Returns:
The created neural network model. The different modes of the policy add distinct wrappers and plugins to the model, which are used to train, collect and evaluate.
- Return type:
model (torch.nn.Module)
- Raises:
- RuntimeError – If the input model is not None and is not an instance of `torch.nn.Module`.
- _forward_collect(data: Tensor, action_mask: list = None, temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, ready_env_id: array = None) Dict [source]
- Overview:
The forward function for collecting data in collect mode. It uses the model to execute the MCTS search and chooses the action by sampling.
- Parameters:
data – The input data, i.e. the observation.
action_mask – The action mask, which indicates the legal actions. Actions that are masked out cannot be selected.
temperature – The temperature of the policy.
to_play – The player to play.
ready_env_id – The ids of the envs that are ready to collect.
- Shape:
- data (torch.Tensor): For Atari, \((N, C*S, H, W)\), where N is the number of collect_env, C is the number of channels, S is the number of stacked frames, H is the height of the image and W is the width of the image. For LunarLander, \((N, O)\), where N is the number of collect_env and O is the observation space size.
- action_mask: \((N, \text{action\_space\_size})\), where N is the number of collect_env.
- temperature: \((1, )\).
- to_play: \((N, 1)\), where N is the number of collect_env.
- ready_env_id: None
- Returns:
A dict whose keys include `action`, `distributions`, `visit_count_distribution_entropy`, `value`, `roots_completed_value`, `improved_policy_probs`, `pred_value` and `policy_logits`.
- Return type:
output (Dict[int, Any])
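The Overviews contrast sampling in collect mode with argmax in eval mode. The sketch below illustrates that contrast for a single root. It uses the classic AlphaZero/MuZero visit-count rule pi(a) ∝ N(a)^(1/T) and treats the action mask as 1 = legal. Both are assumptions for illustration: Gumbel MuZero may instead derive its action from its Gumbel-based search outputs (e.g. `improved_policy_probs` among the returned keys). `visit_count_policy` and `select_action` are hypothetical helpers.

```python
import random
from typing import List, Sequence


def visit_count_policy(visits: Sequence[int], action_mask: Sequence[int], temperature: float) -> List[float]:
    """pi(a) proportional to N(a)^(1/T) over legal actions (mask value 1 = legal)."""
    weights = [(n ** (1.0 / temperature)) if m == 1 else 0.0 for n, m in zip(visits, action_mask)]
    total = sum(weights)
    if total == 0:  # no visits on legal actions: fall back to uniform over legal actions
        legal = sum(action_mask)
        return [m / legal for m in action_mask]
    return [w / total for w in weights]


def select_action(visits: Sequence[int], action_mask: Sequence[int], temperature: float,
                  rng: random.Random, deterministic: bool) -> int:
    probs = visit_count_policy(visits, action_mask, temperature)
    if deterministic:  # eval mode: argmax
        return max(range(len(probs)), key=lambda a: probs[a])
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]  # collect mode: sample


print(visit_count_policy([0, 10, 30], [1, 1, 1], temperature=1.0))  # [0.0, 0.25, 0.75]
print(select_action([0, 10, 30], [1, 1, 0], 1.0, random.Random(0), deterministic=True))  # 1
```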
- _forward_eval(data: Tensor, action_mask: list, to_play: int = -1, ready_env_id: array = None) Dict [source]
- Overview:
The forward function for evaluating the current policy in eval mode. It uses the model to execute the MCTS search and chooses the action with the highest value (argmax) rather than sampling.
- Parameters:
data – The input data, i.e. the observation.
action_mask – The action mask, which indicates the legal actions. Actions that are masked out cannot be selected.
to_play – The player to play.
ready_env_id – The ids of the envs that are ready to evaluate.
- Shape:
- data (torch.Tensor): For Atari, \((N, C*S, H, W)\), where N is the number of evaluator envs, C is the number of channels, S is the number of stacked frames, H is the height of the image and W is the width of the image. For LunarLander, \((N, O)\), where N is the number of evaluator envs and O is the observation space size.
- action_mask: \((N, \text{action\_space\_size})\), where N is the number of evaluator envs.
- to_play: \((N, 1)\), where N is the number of evaluator envs.
- ready_env_id: None
- Returns:
A dict whose keys include `action`, `distributions`, `visit_count_distribution_entropy`, `value`, `pred_value` and `policy_logits`.
- Return type:
output (Dict[int, Any])
- _forward_learn(data: Tensor) Dict[str, float | int] [source]
- Overview:
The forward function for learning policy in learn mode, which is the core of the learning process. The data is sampled from replay buffer. The loss is calculated by the loss function and the loss is backpropagated to update the model.
- Parameters:
data – The data sampled from the replay buffer, a tuple of tensors: the first is the current_batch and the second is the target_batch.
- Returns:
The information dict to be logged, which contains the current learning loss and learning statistics.
- Return type:
info_dict (Dict[str, Union[float, int]])
- _get_attribute(name: str) Any
- Overview:
To control access to the policy attributes, the different modes are exposed to the outside rather than the policy instance itself. This method gets an attribute of the policy in the different modes.
- Parameters:
name – The name of the attribute.
- Returns:
The value of the attribute.
- Return type:
value (Any)
Note
DI-engine’s policy first tries to access the `_get_{name}` method, and then the `_{name}` attribute. If neither is found, it raises a `NotImplementedError`.
- _get_batch_size() int | Dict[str, int]
- _get_n_episode() int | None
- _get_n_sample() int | None
- _get_target_obs_index_in_step_k(step)
- Overview:
Get the begin index and end index of the target obs in step k.
- Parameters:
step (-) – The current step k.
- Returns:
beg_index (int): The begin index of the target obs in step k.
end_index (int): The end index of the target obs in step k.
- Return type:
(beg_index, end_index) (
Tuple[int, int]
)
Examples
>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
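The example above is consistent with the following index arithmetic: when frames are stacked along the channel axis, step k covers channels image_channel * k up to image_channel * (k + frame_stack_num). The standalone function below is a sketch under that assumption; its name and the flat-vector 'mlp' branch are hypothetical and not taken from this page.

```python
from typing import Tuple


def target_obs_index_in_step_k(step: int, model_type: str, image_channel: int = 1,
                               frame_stack_num: int = 1,
                               observation_shape: int = 1) -> Tuple[int, int]:
    """Hypothetical helper: slice bounds of the stacked target obs for unroll step k."""
    if model_type == 'conv':
        # Frames are stacked along the channel axis, image_channel channels per frame.
        beg_index = image_channel * step
        end_index = image_channel * (step + frame_stack_num)
    elif model_type == 'mlp':
        # Assumed layout: flat vectors of length observation_shape, one after another.
        beg_index = observation_shape * step
        end_index = observation_shape * (step + frame_stack_num)
    else:
        raise ValueError(f'unsupported model_type: {model_type}')
    return beg_index, end_index
```

With image_channel=3 and frame_stack_num=4, step 0 gives (0, 12) as in the example, and step 1 shifts the window by one frame to (3, 15).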
- _get_train_sample(data)[source]
- Overview:
For a given trajectory (transitions, a list of transitions), process it into a list of samples that can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) or some multi-timestep transitions (DRQN). This method is usually used in collectors to execute the necessary RL data preprocessing before training, which helps the learner amortize the relevant time consumption. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.
- Parameters:
transitions (-) – The trajectory data (a list of transitions); each element has the same format as the return value of the self._process_transition method.
- Returns:
The processed train samples. Each element has a similar format to the input transitions, but may contain more data for training, such as nstep reward, advantage, etc.
- Return type:
samples (
List[Dict[str, Any]]
)
Note
We will vectorize the process_transition and get_train_sample methods in a following release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
- _init_collect() None [source]
- Overview:
Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.
- _init_eval() None [source]
- Overview:
Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.
- _init_learn() None [source]
- Overview:
Learn mode init method. Called by self.__init__. Initialize the learn model, optimizer and MCTS utils.
- _init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
- Overview:
Initialize the multi-GPU data-parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function that allreduces the gradients of the model parameters.
- Parameters:
model (-) – The neural network model to be trained.
bp_update_sync (-) – Whether to update the model parameters synchronously after allreducing their gradients. Asynchronous updates can proceed in parallel across different network layers, like a pipeline, which can save time.
- _load_state_dict_collect(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy collect mode, for example when loading a pretrained state_dict, auto-recovering from a checkpoint, or receiving a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy collect state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument in load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_eval(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy eval mode, for example when auto-recovering from a checkpoint or receiving a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy eval state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument in load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_learn(state_dict: Dict[str, Any]) None [source]
- Overview:
Load the state_dict variable into policy learn mode.
- Parameters:
state_dict (-) – The dict of policy learn state saved before.
- _monitor_vars_learn() List[str] [source]
- Overview:
Register the variables to be monitored in learn mode. The registered variables will be logged in tensorboard according to the return value of _forward_learn.
- _process_transition(obs, policy_output, timestep)[source]
- Overview:
Process and pack one timestep of transition data into a dict, such as <s, a, r, s’, done>. Some policies need special processing to pack their own necessary attributes (e.g. hidden state and logit), so this method is left to be implemented by the subclass.
- Parameters:
obs (-) – The observation of the current timestep.
policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.
timestep (-) – The execution result namedtuple returned by the environment step method, except all the elements have been transformed into tensor data. Usually, it contains the next obs, reward, done, info, etc.
- Returns:
The processed transition data of the current timestep.
- Return type:
transition (
Dict[str, torch.Tensor]
)
- _reset_collect(data_id: List[int] | None = None) None
- Overview:
Reset the observation and action for the collector environment.
- Parameters:
data_id (-) – List of data ids to reset (not used in this implementation).
- _reset_eval(data_id: List[int] | None = None) None
- Overview:
Reset the observation and action for the evaluator environment.
- Parameters:
data_id (-) – List of data ids to reset (not used in this implementation).
- _reset_learn(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different trajectories in data_id will have different hidden states in an RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
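The data_id convention above (None resets everything, a list resets only those ids) can be sketched with a per-environment state table. The class is hypothetical and uses plain floats in place of real RNN hidden states.

```python
from typing import Dict, List, Optional


class HiddenStateBank:
    """Hypothetical per-id stateful variables, reset following the data_id convention."""

    def __init__(self, env_ids: List[int], init_value: float = 0.0) -> None:
        self._init_value = init_value
        self.hidden: Dict[int, float] = {i: init_value for i in env_ids}

    def reset(self, data_id: Optional[List[int]] = None) -> None:
        # data_id=None resets every stateful variable; otherwise only the listed ids.
        ids = list(self.hidden) if data_id is None else data_id
        for i in ids:
            self.hidden[i] = self._init_value
```

Resetting by id lets one finished episode start fresh without disturbing the state of environments that are still running.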
- _set_attribute(name: str, value: Any) None
- Overview:
In order to control access to the policy attributes, we expose different modes to the outside rather than using the policy instance directly. We also provide a method to set an attribute of the policy in different modes. The new attribute will be named _{name}.
- Parameters:
name (-) – The name of the attribute.
value (-) – The value of the attribute.
- _state_dict_collect() Dict[str, Any]
- Overview:
Return the state_dict of collect mode, which usually includes only the model. It is necessary in distributed training scenarios to auto-recover collectors.
- Returns:
The dict of current policy collect state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
Tip
Not all scenarios need to auto-recover collectors; sometimes we can simply shut down the crashed collector and start a new one.
- _state_dict_eval() Dict[str, Any]
- Overview:
Return the state_dict of eval mode, which usually includes only the model. It is necessary in distributed training scenarios to auto-recover evaluators.
- Returns:
The dict of current policy eval state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
Tip
Not all scenarios need to auto-recover evaluators; sometimes we can simply shut down the crashed evaluator and start a new one.
- _state_dict_learn() Dict[str, Any] [source]
- Overview:
Return the state_dict of learn mode, usually including model, target_model and optimizer.
- Returns:
The dict of current policy learn state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
- property cfg: EasyDict
- class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new collect_function object from a sequence or iterable
- _replace(**kwds)
Return a new collect_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- get_train_sample
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 7
- process_transition
Alias for field number 1
- reset
Alias for field number 3
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property collect_mode: collect_function
- Overview:
Return the interfaces of collect mode of policy, which is used to collect data for training. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.
- Returns:
The interfaces of collect mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (
Policy.collect_function
)
Examples
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- config = {'action_type': 'fixed_action_space', 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'collector_env_num': 8, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_offline': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 200, 'grad_clip_value': 10, 'gray_scale': False, 'gumbel_algo': True, 'ignore_done': False, 'learning_rate': 0.2, 'manual_temperature_decay': False, 'max_num_considered_actions': 4, 'mcts_ctree': True, 'model': {'bias': True, 'categorical_distribution': True, 'continuous_action_space': False, 'discrete_action_encoding_type': 'one_hot', 'frame_stack_num': 1, 'image_channel': 1, 'model_type': 'conv', 'norm_type': 'BN', 'num_channels': 64, 'num_res_blocks': 1, 'observation_shape': (4, 96, 96), 'res_connection_in_dynamics': True, 'self_supervised_learning_loss': False, 'support_scale': 300}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'multi_gpu': False, 'n_episode': 8, 'num_simulations': 50, 'num_unroll_steps': 5, 'optim_type': 'SGD', 'piecewise_decay_lr_scheduler': True, 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'reanalyze_noise': True, 'replay_ratio': 0.25, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sampled_algo': False, 'ssl_loss_weight': 0, 'target_update_freq': 100, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 100000, 'transform2string': False, 'update_per_collect': None, 'use_augmentation': False, 'use_priority': True, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
- classmethod default_config() EasyDict
- Overview:
Get the default config of policy. This method is used to create the default config of policy.
- Returns:
The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.
- Return type:
cfg (
EasyDict
)
Tip
This method deep-copies the config attribute of the class and returns the result, so users can modify the returned config without affecting the class default.
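The deep copy in the tip above is what keeps nested defaults safe. The sketch below uses a plain dict in place of EasyDict to avoid the extra dependency; the class and its config values are hypothetical.

```python
import copy


class ConfigSketch:
    # Hypothetical class-level default, standing in for the policy's `config` attribute.
    config = {'batch_size': 256, 'model': {'num_channels': 32}}

    @classmethod
    def default_config(cls) -> dict:
        # A deep copy also duplicates nested dicts such as 'model'.
        return copy.deepcopy(cls.config)
```

A shallow copy (dict(cls.config)) would still share the nested 'model' dict, so editing cfg['model'] would silently change the class default.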
- default_model() Tuple[str, List[str]] [source]
- Overview:
Return this algorithm's default model setting for demonstration.
- Returns:
model_info: The model name and model import_names.
model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.
import_names (List[str]): The model class path list used in this algorithm.
- Return type:
model_info (
Tuple[str, List[str]]
)
Note
The user can define and use customized network model but must obey the same interface definition indicated by import_names path. For MuZero,
lzero.model.muzero_model.MuZeroModel
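A minimal sketch of how a (model_type, import_names) pair can be resolved: import the listed modules so that their classes get registered, then look the type up by name. The registry, decorator, and model here are hypothetical stand-ins for ModelRegistry, not its actual API, and collections is imported only as a placeholder module path.

```python
import importlib
from typing import Callable, Dict, List, Tuple

_REGISTRY: Dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    """Hypothetical decorator: record a model class under a string key."""
    def deco(cls: type) -> type:
        _REGISTRY[name] = cls
        return cls
    return deco


@register('toy_model')
class ToyModel:
    def __init__(self, num_channels: int = 32) -> None:
        self.num_channels = num_channels


def default_model() -> Tuple[str, List[str]]:
    # A real policy lists the module path(s) that define and register its model class.
    return 'toy_model', ['collections']


def create_model(**kwargs):
    model_type, import_names = default_model()
    for path in import_names:
        importlib.import_module(path)  # in real code, importing triggers registration
    return _REGISTRY[model_type](**kwargs)
```

Keeping the class behind a string key is what lets a customized model replace the default, as long as it is registered under the expected name and keeps the same interface.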
- class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new eval_function object from a sequence or iterable
- _replace(**kwds)
Return a new eval_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 5
- reset
Alias for field number 1
- set_attribute
Alias for field number 3
- state_dict
Alias for field number 4
- property eval_mode: eval_function
- Overview:
Return the interfaces of eval mode of policy, which is used to evaluate the policy. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.
- Returns:
The interfaces of eval mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (
Policy.eval_function
)
Examples
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new learn_function object from a sequence or iterable
- _replace(**kwds)
Return a new learn_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- info
Alias for field number 2
- load_state_dict
Alias for field number 7
- monitor_vars
Alias for field number 3
- reset
Alias for field number 1
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property learn_mode: learn_function
- Overview:
Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own learn mode.
- Returns:
The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (
Policy.learn_function
)
Examples
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
- set_train_iter_env_step(train_iter, env_step) None
- Overview:
Set the train_iter and env_step for the policy.
- Parameters:
train_iter (-) – The train_iter for the policy.
env_step (-) – The env_step for the policy.
- sync_gradients(model: Module) None
- Overview:
Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.
- Parameters:
model (-) – The model to synchronize gradients.
Note
This method is only used in multi-GPU training, and it should be called after the backward method and before the step method. Users can also use the bp_update_sync config to control whether the gradient allreduce and the optimizer updates are synchronized.
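To make the allreduce step concrete, the sketch below simulates its arithmetic in plain Python: each replica's gradient for a parameter is summed across replicas and divided by the world size, so every replica then applies the same averaged gradient. It is a hypothetical simulation only; it does not call torch.distributed, and averaging rather than plain summing is an assumption.

```python
from typing import List


def allreduce_mean(per_worker_grads: List[List[float]]) -> List[List[float]]:
    """Simulate a mean-allreduce: every worker ends up with the element-wise average."""
    world_size = len(per_worker_grads)
    # zip(*...) groups the same gradient element from every worker.
    summed = [sum(vals) for vals in zip(*per_worker_grads)]
    averaged = [s / world_size for s in summed]
    return [list(averaged) for _ in range(world_size)]
```

This is also why the call order matters: gradients must exist (after backward) before they can be reduced, and they must be reduced before the optimizer step consumes them.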
- total_field = {'collect', 'eval', 'learn'}
Sampled AlphaZeroPolicy
- class lzero.policy.sampled_alphazero.SampledAlphaZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Bases:
Policy
- Overview:
The policy class for Sampled AlphaZero.
- __init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
- Overview:
Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy: learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. The enable_field argument specifies which fields to initialize; if it is None, all fields will be initialized.
- Parameters:
cfg (-) – The final merged config used to initialize policy. For the default config, see the config attribute and its comments of policy class.
model (-) – The neural network model used to initialize policy. If it is None, then the model will be created according to the default_model method and the cfg.model field. Otherwise, the model will be set to the model instance created by the outside caller.
enable_field (-) – The field list to initialize. If it is None, then all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which saves resources.
Note
A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.
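How enable_field selects which _init_* methods run can be sketched as follows. The class is a hypothetical stand-in, not the real Policy constructor; rejecting unknown field names is an assumption added for illustration.

```python
from typing import List, Optional


class EnableFieldSketch:
    """Hypothetical stand-in: only the requested modes run their _init_* method."""

    total_field = {'learn', 'collect', 'eval'}

    def __init__(self, enable_field: Optional[List[str]] = None) -> None:
        fields = self.total_field if enable_field is None else set(enable_field)
        unknown = fields - self.total_field
        if unknown:
            raise ValueError(f'unknown fields: {sorted(unknown)}')
        self.initialized: List[str] = []
        # Fixed order so the result does not depend on set iteration order.
        for field in ('learn', 'collect', 'eval'):
            if field in fields:
                getattr(self, '_init_' + field)()

    def _init_learn(self) -> None:
        self.initialized.append('learn')    # e.g. learn model and optimizer

    def _init_collect(self) -> None:
        self.initialized.append('collect')  # e.g. collect model and MCTS utils

    def _init_eval(self) -> None:
        self.initialized.append('eval')     # e.g. eval model and MCTS utils
```

For instance, an evaluation-only process could pass enable_field=['eval'] and skip building the optimizer and the collect-side MCTS utilities.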
- _abc_impl = <_abc._abc_data object>
- _calculate_policy_loss_disc(policy_probs: Tensor, target_policy: Tensor, target_sampled_actions: Tensor, valid_action_lengths: Tensor) Tensor [source]
- _create_model(cfg: EasyDict, model: Module | None = None) Module
- Overview:
Create or validate the neural network model according to the input configuration and model. If the input model is None, then the model will be created according to the default_model method and the cfg.model field. Otherwise, the model will be verified as an instance of torch.nn.Module and set to the model instance created by the outside caller.
- Parameters:
cfg (-) – The final merged config used to initialize policy.
model (-) – The neural network model used to initialize policy. User can refer to the default model defined in the corresponding policy to customize its own model.
- Returns:
The created neural network model. The different modes of policy will add distinct wrappers and plugins to the model, which is used to train, collect and evaluate.
- Return type:
model (
torch.nn.Module
)
- Raises:
- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.
- _forward_collect(obs: Dict, temperature: float = 1) Dict[str, Tensor] [source]
- Overview:
The forward function for collecting data in collect mode. It uses the real env to execute MCTS search.
- Parameters:
obs (-) – The dict of obs, the key is env_id and the value is the corresponding obs in this timestep.
temperature (-) – The temperature for MCTS search.
- Returns:
The dict of output, where the key is env_id and the value is the corresponding policy output in this timestep, including action, probs and so on.
- Return type:
output (
Dict[str, torch.Tensor]
)
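In AlphaZero-style collection, a common way to apply the temperature is to turn the root visit counts N(a) into a distribution proportional to N(a)^(1/T): T = 1 follows the visit counts, and T close to 0 approaches greedy selection. The sketch below assumes that scheme; this page does not state the exact formula this policy uses, and both function names are hypothetical.

```python
import random
from typing import List, Optional


def visit_count_to_probs(visit_counts: List[int], temperature: float) -> List[float]:
    """Map root visit counts to action probabilities proportional to N(a) ** (1 / T)."""
    if temperature <= 0:
        raise ValueError('temperature must be positive')
    max_count = max(visit_counts)
    if max_count <= 0:
        raise ValueError('at least one action must have been visited')
    # Dividing by the max count first avoids overflow at low temperatures.
    scaled = [(n / max_count) ** (1.0 / temperature) for n in visit_counts]
    total = sum(scaled)
    return [s / total for s in scaled]


def select_action(visit_counts: List[int], temperature: float,
                  rng: Optional[random.Random] = None) -> int:
    rng = rng or random.Random()
    probs = visit_count_to_probs(visit_counts, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

A higher temperature keeps exploration in the collected data, while evaluation typically acts greedily on the visit counts.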
- _forward_eval(obs: Dict) Dict[str, Tensor] [source]
- Overview:
The forward function for evaluating the current policy in eval mode, similar to self._forward_collect.
- Parameters:
obs (-) – The dict of obs, the key is env_id and the value is the corresponding obs in this timestep.
- Returns:
The dict of output, where the key is env_id and the value is the corresponding policy output in this timestep, including action, probs and so on.
- Return type:
output (
Dict[str, torch.Tensor]
)
- _forward_learn(inputs: Dict[str, Tensor]) Dict[str, float] [source]
- Overview:
Policy forward function of learn mode (training policy and updating parameters). Forward means that the policy takes a batch of training data from the replay buffer and returns the output result, including various training information such as loss value, policy entropy, q value, priority, and so on. This method is left to be implemented by the subclass, and more items can be added to data if necessary.
- Parameters:
data (-) – The input data used for policy forward, including a batch of training samples. For each element in the list, the key of the dict is the name of the data item and the value is the corresponding data. Usually, in the _forward_learn method, data should be stacked in the batch dimension by some utility function such as default_preprocess_learn.
- Returns:
The training information of policy forward, including some metrics for monitoring training such as loss, priority, q value, policy entropy, and some data for next step training such as priority. Note the output data item should be Python native scalar rather than PyTorch tensor, which is convenient for the outside to use.
- Return type:
output (
Dict[int, Any]
)
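The parameter description above says training samples are stacked along the batch dimension before _forward_learn consumes them. The sketch below shows the idea with plain lists: a list of per-sample dicts becomes one dict of per-key lists. It is a hypothetical stand-in for utilities such as default_preprocess_learn, which operate on tensors and handle more cases.

```python
from typing import Any, Dict, List


def stack_samples(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical collate: list of sample dicts -> dict of per-key batches."""
    if not samples:
        raise ValueError('cannot stack an empty batch')
    keys = samples[0].keys()
    # Every sample is assumed to share the same keys.
    return {k: [s[k] for s in samples] for k in keys}
```

With tensors, each per-key list would then be turned into a single batched tensor, which is the "batch dimension" referred to above.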
- _get_attribute(name: str) Any
- Overview:
In order to control the access of the policy attributes, we expose different modes to outside rather than directly use the policy instance. And we also provide a method to get the attribute of the policy in different modes.
- Parameters:
name (-) – The name of the attribute.
- Returns:
The value of the attribute.
- Return type:
value (
Any
)
Note
DI-engine’s policy will first try to access the _get_{name} method, and then try to access the _{name} attribute. If neither of them is found, it will raise a NotImplementedError.
- _get_batch_size() int | Dict[str, int]
- _get_n_episode() int | None
- _get_n_sample() int | None
- _get_train_sample(data)[source]
- Overview:
For a given trajectory (transitions, a list of transitions), process it into a list of samples that can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) or some multi-timestep transitions (DRQN). This method is usually used in collectors to execute the necessary RL data preprocessing before training, which helps the learner amortize the relevant time consumption. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.
- Parameters:
transitions (-) – The trajectory data (a list of transitions); each element has the same format as the return value of the self._process_transition method.
- Returns:
The processed train samples. Each element has a similar format to the input transitions, but may contain more data for training, such as nstep reward, advantage, etc.
- Return type:
samples (
List[Dict[str, Any]]
)
Note
We will vectorize the process_transition and get_train_sample methods in a following release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
- _init_collect() None [source]
- Overview:
Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.
- _init_eval() None [source]
- Overview:
Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.
- _init_learn() None [source]
- Overview:
Initialize the learn mode of policy, including related attributes and modules. This method will be called in the __init__ method if the learn field is in enable_field. Almost every policy has its own learn mode, so this method must be overridden in the subclass.
Note
For the member variables that need to be saved and loaded, please refer to the _state_dict_learn and _load_state_dict_learn methods.
Note
For the member variables that need to be monitored, please refer to the _monitor_vars_learn method.
Note
If you want to set some special member variables in the _init_learn method, you’d better name them with the prefix _learn_ to avoid conflicts with other modes, such as self._learn_attr1.
- _init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
- Overview:
Initialize the multi-GPU data-parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function that allreduces the gradients of the model parameters.
- Parameters:
model (-) – The neural network model to be trained.
bp_update_sync (-) – Whether to update the model parameters synchronously after allreducing their gradients. Asynchronous updates can proceed in parallel across different network layers, like a pipeline, which can save time.
- _load_state_dict_collect(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy collect mode, for example when loading a pretrained state_dict, auto-recovering from a checkpoint, or receiving a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy collect state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument in load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_eval(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy eval mode, for example when auto-recovering from a checkpoint or receiving a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy eval state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument in load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_learn(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy learn mode.
- Parameters:
state_dict (-) – The dict of policy learn state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument in load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _monitor_vars_learn() List[str] [source]
- Overview:
Register the variables to be monitored in learn mode. The registered variables will be logged in tensorboard according to the return value of _forward_learn.
- _process_transition(obs: Dict, model_output: Dict[str, Tensor], timestep: namedtuple) Dict [source]
- Overview:
Generate the dict type transition (one timestep) data from policy learning.
- _reset_collect(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for collect mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different environments/episodes being collected in data_id will have different hidden states in an RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
- _reset_eval(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for eval mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different environments/episodes being evaluated in data_id will have different hidden states in an RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
- _reset_learn(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different trajectories in data_id will have different hidden states in an RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
- _set_attribute(name: str, value: Any) None
- Overview:
In order to control access to the policy attributes, we expose different modes to the outside rather than using the policy instance directly. We also provide a method to set an attribute of the policy in different modes. The new attribute will be named _{name}.
- Parameters:
name (-) – The name of the attribute.
value (-) – The value of the attribute.
- _state_dict_collect() Dict[str, Any]
- Overview:
Return the state_dict of collect mode, which usually includes only the model. It is necessary in distributed training scenarios to auto-recover collectors.
- Returns:
The dict of current policy collect state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
Tip
Not all scenarios need to auto-recover collectors; sometimes we can simply shut down the crashed collector and start a new one.
- _state_dict_eval() Dict[str, Any]
- Overview:
Return the state_dict of eval mode, which usually includes only the model. It is necessary in distributed training scenarios to auto-recover evaluators.
- Returns:
The dict of current policy eval state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
Tip
Not all scenarios need to auto-recover evaluators; sometimes we can simply shut down the crashed evaluator and start a new one.
- _state_dict_learn() Dict[str, Any]
- Overview:
Return the state_dict of learn mode, usually including model and optimizer.
- Returns:
The dict of current policy learn state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
- property cfg: EasyDict
- class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new collect_function object from a sequence or iterable
- _replace(**kwds)
Return a new collect_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- get_train_sample
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 7
- process_transition
Alias for field number 1
- reset
Alias for field number 3
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property collect_mode: collect_function
- Overview:
Return the interfaces of collect mode of policy, which is used to collect data for training. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.
- Returns:
The interfaces of collect mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (
Policy.collect_function
)
Examples
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- config = {'batch_size': 256, 'collector_env_num': 8, 'cuda': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'grad_clip_value': 10, 'learning_rate': 0.2, 'manual_temperature_decay': False, 'mcts': {'action_space_size': 9, 'continuous_action_space': False, 'legal_actions': None, 'max_moves': 512, 'num_of_sampled_actions': 2, 'num_simulations': 50, 'pb_c_base': 19652, 'pb_c_init': 1.25, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25}, 'mcts_ctree': True, 'model': {'num_channels': 32, 'num_res_blocks': 1, 'observation_shape': (3, 6, 6)}, 'momentum': 0.9, 'normalize_prob_of_sampled_actions': False, 'optim_type': 'SGD', 'other': {'replay_buffer': {'replay_buffer_size': 1000000, 'save_episode': False}}, 'piecewise_decay_lr_scheduler': True, 'policy_loss_type': 'cross_entropy', 'replay_ratio': 0.25, 'sampled_algo': False, 'tensor_float_32': False, 'threshold_training_steps_for_final_lr': 500000, 'threshold_training_steps_for_final_temperature': 100000, 'torch_compile': False, 'type': 'alphazero', 'update_per_collect': None, 'value_weight': 1.0, 'weight_decay': 0.0001}
- classmethod default_config() EasyDict
- Overview:
Get the default config of policy. This method is used to create the default config of policy.
- Returns:
The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.
- Return type:
cfg (EasyDict)
Tip
This method deep-copies the config attribute of the class and returns the result, so modifying the returned config does not affect the class defaults.
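A rough sketch of what default_config is described as doing, using plain dicts instead of EasyDict: deep-copy the class-level config and recursively merge it along the class hierarchy, so that a subclass overrides only the keys it sets. The deep_merge helper and the MRO walk are assumptions made for illustration, not the DI-engine code.

```python
import copy


def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge override into a copy of base (override wins on conflicts)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


class BasePolicy:
    config = {'cuda': False, 'learn': {'batch_size': 64, 'learning_rate': 1e-3}}

    @classmethod
    def default_config(cls) -> dict:
        # Merge the class-level configs from the most basic class to the most derived.
        cfg = {}
        for klass in reversed(cls.__mro__):
            if 'config' in vars(klass):
                cfg = deep_merge(cfg, vars(klass)['config'])
        return cfg


class ChildPolicy(BasePolicy):
    config = {'learn': {'batch_size': 256}, 'mcts_ctree': True}


cfg = ChildPolicy.default_config()
cfg['learn']['batch_size'] = 1                    # mutate the returned copy...
print(ChildPolicy.config['learn']['batch_size'])  # ...the class attribute still holds 256
```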
- default_model() Tuple[str, List[str]] [source]
- Overview:
Return the default model setting of this algorithm, for demonstration.
- Returns:
model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.
import_names (List[str]): The list of import paths for the model class used in this algorithm.
- Return type:
model_info (Tuple[str, List[str]])
- class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new eval_function object from a sequence or iterable
- _replace(**kwds)
Return a new eval_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 5
- reset
Alias for field number 1
- set_attribute
Alias for field number 3
- state_dict
Alias for field number 4
- property eval_mode: eval_function
- Overview:
Return the interfaces of the eval mode of the policy, which is used to evaluate the policy. The interfaces are defined as an immutable namedtuple to restrict how the policy is used in different modes. Derived subclasses can override the interfaces to customize their own eval mode.
- Returns:
The interfaces of the eval mode, a namedtuple whose fields are the corresponding internal methods of the policy.
- Return type:
interfaces (Policy.eval_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new learn_function object from a sequence or iterable
- _replace(**kwds)
Return a new learn_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- info
Alias for field number 2
- load_state_dict
Alias for field number 7
- monitor_vars
Alias for field number 3
- reset
Alias for field number 1
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property learn_mode: learn_function
- Overview:
Return the interfaces of the learn mode of the policy, which is used to train the model. The interfaces are defined as an immutable namedtuple to restrict how the policy is used in different modes. Derived subclasses can override the interfaces to customize their own learn mode.
- Returns:
The interfaces of the learn mode, a namedtuple whose fields are the corresponding internal methods of the policy.
- Return type:
interfaces (Policy.learn_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
- sync_gradients(model: Module) None
- Overview:
Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.
- Parameters:
model (-) – The model to synchronize gradients.
Note
This method is only used in multi-GPU training. It should be called after the backward method and before the optimizer's step method. The bp_update_sync config can also be used to control whether the gradient allreduce and the optimizer updates are synchronized.
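Conceptually, an allreduce leaves every worker holding the same combined gradient before the optimizer step. The pure-Python simulation below averages per-parameter gradients across workers. Averaging (sum divided by world size) is an assumption of this sketch; real training would apply a collective such as torch.distributed.all_reduce to GPU tensors.

```python
from typing import Dict, List

Grads = Dict[str, List[float]]


def allreduce_mean(per_worker_grads: List[Grads]) -> List[Grads]:
    """Simulate an allreduce that averages each parameter's gradient across workers."""
    world_size = len(per_worker_grads)
    averaged = {}
    for name in per_worker_grads[0]:
        # Element-wise sum over workers, then divide by the number of workers.
        summed = [sum(vals) for vals in zip(*(g[name] for g in per_worker_grads))]
        averaged[name] = [v / world_size for v in summed]
    # Every worker receives its own copy of the identical result.
    return [{name: list(vals) for name, vals in averaged.items()} for _ in range(world_size)]


# Two workers computed gradients on different mini-batches (backward already done).
grads = [
    {'weight': [1.0, 2.0], 'bias': [0.5]},
    {'weight': [3.0, 4.0], 'bias': [1.5]},
]
synced = allreduce_mean(grads)
print(synced[0])  # {'weight': [2.0, 3.0], 'bias': [1.0]}
# Each worker would now run its optimizer step on identical gradients.
```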
- total_field = {'collect', 'eval', 'learn'}
Sampled MuZeroPolicy
- class lzero.policy.sampled_muzero.SampledMuZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Bases:
MuZeroPolicy
- Overview:
The policy class for Sampled MuZero proposed in the paper https://arxiv.org/abs/2104.06303.
- __init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
- Overview:
Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy: learn, collect, and eval. The learn field is used to train the policy, the collect field is used to collect training data, and the eval field is used to evaluate the policy. enable_field specifies which fields to initialize; if it is None, all fields are initialized.
- Parameters:
cfg (-) – The final merged config used to initialize the policy. For the default config, see the config attribute of the policy class and its comments.
model (-) – The neural network model used to initialize the policy. If it is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model instance supplied by the caller is used.
enable_field (-) – The list of fields to initialize. If it is None, all fields are initialized. Otherwise, only the fields in enable_field are initialized, which saves resources.
Note
A derived policy class should implement the _init_learn, _init_collect, and _init_eval methods to initialize the corresponding fields.
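The enable_field behaviour described above amounts to calling only the _init_{field} methods that were requested. FieldInitSketch is a hypothetical class written for illustration; its _init_* methods only record that they were called.

```python
from typing import List, Optional


class FieldInitSketch:
    """Hypothetical illustration of how enable_field limits which modes are initialized."""

    total_field = {'learn', 'collect', 'eval'}

    def __init__(self, enable_field: Optional[List[str]] = None):
        fields = self.total_field if enable_field is None else set(enable_field)
        self.initialized = []
        # Call _init_learn / _init_collect / _init_eval only for the requested fields.
        for field in sorted(fields):
            getattr(self, '_init_' + field)()

    def _init_learn(self):
        self.initialized.append('learn')    # e.g. build the learn model and optimizer

    def _init_collect(self):
        self.initialized.append('collect')  # e.g. build the collect model and MCTS

    def _init_eval(self):
        self.initialized.append('eval')     # e.g. build the eval model and MCTS


print(FieldInitSketch().initialized)          # ['collect', 'eval', 'learn']
print(FieldInitSketch(['eval']).initialized)  # ['eval']
```

An evaluation-only process could pass ['eval'] and skip building the optimizer and replay-related state entirely.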
- _abc_impl = <_abc._abc_data object>
- _calculate_policy_loss_cont(policy_loss: Tensor, policy_logits: Tensor, target_policy: Tensor, mask_batch: Tensor, child_sampled_actions_batch: Tensor, unroll_step: int) Tuple[Tensor] [source]
- Overview:
Calculate the policy loss for continuous action space.
- Parameters:
policy_loss (-) – The policy loss tensor.
policy_logits (-) – The policy logits tensor.
target_policy (-) – The target policy tensor.
mask_batch (-) – The mask tensor.
child_sampled_actions_batch (-) – The child sampled actions tensor.
unroll_step (-) – The unroll step.
- Returns:
policy_loss (torch.Tensor): The policy loss tensor.
policy_entropy (torch.Tensor): The policy entropy tensor.
policy_entropy_loss (torch.Tensor): The policy entropy loss tensor.
target_policy_entropy (torch.Tensor): The target policy entropy tensor.
target_sampled_actions (torch.Tensor): The target sampled actions tensor.
mu (torch.Tensor): The mean (mu) tensor of the policy distribution.
sigma (torch.Tensor): The standard deviation (sigma) tensor of the policy distribution.
- Return type:
Tuple[torch.Tensor]
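One common form of the sampled policy loss for continuous actions is a cross-entropy between the MCTS target distribution over the K sampled actions and the policy's log-density at those same actions. The sketch below assumes a 1-D Gaussian policy N(mu, sigma^2) and a single state; it omits the entropy terms and the normalize_prob_of_sampled_actions option, and it is not the LightZero implementation.

```python
import math
from typing import List


def gaussian_log_prob(x: float, mu: float, sigma: float) -> float:
    """Log-density of a 1-D Gaussian N(mu, sigma^2) evaluated at x."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def sampled_policy_loss_cont(mu: float, sigma: float, sampled_actions: List[float],
                             target_policy: List[float], mask: float = 1.0) -> float:
    """Cross-entropy between the MCTS target over K sampled actions and the policy.

    target_policy holds the normalized visit-count distribution over the K sampled
    actions; the policy log-density is evaluated only at those actions. mask is 0.0
    for unroll steps that fall beyond the end of the trajectory.
    """
    log_probs = [gaussian_log_prob(a, mu, sigma) for a in sampled_actions]
    loss = -sum(t * lp for t, lp in zip(target_policy, log_probs))
    return loss * mask


# One state, K = 2 sampled actions; MCTS visited the first one more often.
loss = sampled_policy_loss_cont(mu=0.0, sigma=1.0, sampled_actions=[0.0, 1.0],
                                target_policy=[0.75, 0.25])
print(round(loss, 4))  # 1.0439
```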
- _calculate_policy_loss_disc(policy_loss: Tensor, policy_logits: Tensor, target_policy: Tensor, mask_batch: Tensor, child_sampled_actions_batch: Tensor, unroll_step: int) Tuple[Tensor] [source]
- Overview:
Calculate the policy loss for discrete action space.
- Parameters:
policy_loss (-) – The policy loss tensor.
policy_logits (-) – The policy logits tensor.
target_policy (-) – The target policy tensor.
mask_batch (-) – The mask tensor.
child_sampled_actions_batch (-) – The child sampled actions tensor.
unroll_step (-) – The unroll step.
- Returns:
policy_loss (torch.Tensor): The policy loss tensor.
policy_entropy (torch.Tensor): The policy entropy tensor.
policy_entropy_loss (torch.Tensor): The policy entropy loss tensor.
target_policy_entropy (torch.Tensor): The target policy entropy tensor.
target_sampled_actions (torch.Tensor): The target sampled actions tensor.
- Return type:
Tuple[torch.Tensor]
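For the discrete case, a comparable sketch takes the log-softmax of the policy logits and reads off the log-probabilities at the K sampled action indices. The helper names are hypothetical, and normalizing over the full action space (rather than only over the sampled subset) is an assumption of this sketch.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over the full action space."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sampled_policy_loss_disc(policy_logits: List[float], sampled_action_ids: List[int],
                             target_policy: List[float], mask: float = 1.0) -> float:
    """Cross-entropy restricted to the K actions that were sampled at the root."""
    log_probs = log_softmax(policy_logits)
    loss = -sum(t * log_probs[a] for a, t in zip(sampled_action_ids, target_policy))
    return loss * mask


# Action space of size 3; actions 0 and 2 were sampled and MCTS split visits evenly.
loss = sampled_policy_loss_disc([0.0, 0.0, 0.0], sampled_action_ids=[0, 2],
                                target_policy=[0.5, 0.5])
# With uniform logits every log-probability is -log(3), so the loss equals log(3).
```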
- _create_model(cfg: EasyDict, model: Module | None = None) Module
- Overview:
Create or validate the neural network model according to the input config and model. If the input model is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model is verified to be an instance of torch.nn.Module, and the instance supplied by the caller is used.
- Parameters:
cfg (-) – The final merged config used to initialize policy.
model (-) – The neural network model used to initialize the policy. Users can refer to the default model defined in the corresponding policy to customize their own model.
- Returns:
The created neural network model. The different modes of policy will add distinct wrappers and plugins to the model, which is used to train, collect and evaluate.
- Return type:
model (torch.nn.Module)
- Raises:
- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.
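The create-or-validate logic can be sketched without PyTorch by substituting a stand-in Module base class. MODEL_REGISTRY, TinyModel, and the 'type' config key are hypothetical; the real method builds the model from default_model and cfg.model as described above.

```python
from typing import Callable, Dict, Optional


class Module:
    """Stand-in for torch.nn.Module so this sketch runs without PyTorch."""


class TinyModel(Module):
    def __init__(self, hidden_size: int = 32):
        self.hidden_size = hidden_size


# Hypothetical registry mapping a model type string to its class.
MODEL_REGISTRY: Dict[str, Callable[..., Module]] = {'tiny': TinyModel}


def create_model(model_cfg: dict, model: Optional[object] = None) -> Module:
    """Build a model from config when none is given; otherwise validate the given one."""
    if model is None:
        kwargs = dict(model_cfg)          # copy so the caller's config is not mutated
        model_type = kwargs.pop('type')   # which registered class to build
        return MODEL_REGISTRY[model_type](**kwargs)
    if not isinstance(model, Module):
        raise RuntimeError(f'model must be a Module instance, got {type(model).__name__}')
    return model  # a caller-supplied model is used as-is


built = create_model({'type': 'tiny', 'hidden_size': 64})
print(built.hidden_size)  # 64
```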
- _forward_collect(data: Tensor, action_mask: list = None, temperature: ndarray = 1, to_play=-1, epsilon: float = 0.25, ready_env_id: array = None)[source]
- Overview:
The forward function for collecting data in collect mode. It uses the model to run MCTS and selects each action by sampling from the search result.
- Parameters:
data (-) – The input data, i.e. the observation.
action_mask (-) – The action mask, which marks the legal actions of each environment.
temperature (-) – The temperature of the policy.
to_play (-) – The player to play.
ready_env_id (-) – The id of the env that is ready to collect.
- Shape:
- data (torch.Tensor): For Atari, \((N, C \times S, H, W)\), where N is the number of collect environments, C is the number of channels, S is the number of stacked frames, H is the height of the image, and W is the width of the image. For LunarLander, \((N, O)\), where N is the number of collect environments and O is the observation space size.
- action_mask: \((N, \text{action\_space\_size})\), where N is the number of collect environments.
- temperature: \((1, )\).
- to_play: \((N, 1)\), where N is the number of collect environments.
- ready_env_id: None
- Returns:
A dict keyed by environment id; each value is a dict with the keys action, distributions, visit_count_distribution_entropy, value, pred_value, and policy_logits.
- Return type:
output (Dict[int, Any])
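Sampling an action during collection is commonly done by raising the root visit counts to the power 1/T and normalizing. The sketch below assumes that scheme, with visit_counts ordered like the legal actions; the function names are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def visit_count_policy(visit_counts: Sequence[int], temperature: float) -> List[float]:
    """Turn root visit counts into probabilities p_a proportional to N_a ** (1 / T)."""
    scaled = [n ** (1.0 / temperature) for n in visit_counts]
    total = sum(scaled)
    return [s / total for s in scaled]


def sample_action(visit_counts: Sequence[int], temperature: float,
                  rng: random.Random) -> Tuple[int, List[float]]:
    """Sample an index into visit_counts (i.e. into the legal actions) for exploration."""
    probs = visit_count_policy(visit_counts, temperature)
    index = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return index, probs


rng = random.Random(0)
index, probs = sample_action([10, 30, 60], temperature=1.0, rng=rng)
sharp = visit_count_policy([10, 30, 60], temperature=0.25)  # lower T concentrates mass
```

A higher temperature flattens the distribution toward uniform exploration, while a lower one approaches the greedy choice used in eval mode.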
- _forward_eval(data: Tensor, action_mask: list, to_play: -1, ready_env_id: array = None)[source]
- Overview:
The forward function for evaluating the current policy in eval mode. It uses the model to run MCTS and deterministically selects the most-visited action (argmax over the root visit counts) instead of sampling.
- Parameters:
data (-) – The input data, i.e. the observation.
action_mask (-) – The action mask, which marks the legal actions of each environment.
to_play (-) – The player to play.
ready_env_id (-) – The id of the env that is ready to be evaluated.
- Shape:
- data (torch.Tensor): For Atari, \((N, C \times S, H, W)\), where N is the number of evaluator environments, C is the number of channels, S is the number of stacked frames, H is the height of the image, and W is the width of the image. For LunarLander, \((N, O)\), where N is the number of evaluator environments and O is the observation space size.
- action_mask: \((N, \text{action\_space\_size})\), where N is the number of evaluator environments.
- to_play: \((N, 1)\), where N is the number of evaluator environments.
- ready_env_id: None
- Returns:
A dict keyed by environment id; each value is a dict with the keys action, distributions, visit_count_distribution_entropy, value, pred_value, and policy_logits.
- Return type:
output (Dict[int, Any])
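A minimal sketch of the deterministic choice, assuming the root's children correspond one-to-one with the legal actions; greedy_action is a hypothetical helper.

```python
from typing import Sequence


def greedy_action(visit_counts: Sequence[int], legal_actions: Sequence[int]) -> int:
    """Pick the legal action whose root child received the most MCTS visits.

    visit_counts[i] is the visit count of legal_actions[i]; ties go to the first maximum.
    """
    best_index = max(range(len(visit_counts)), key=lambda i: visit_counts[i])
    return legal_actions[best_index]


print(greedy_action([5, 42, 3], legal_actions=[0, 4, 7]))  # 4
```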
- _forward_learn(data: Tensor) Dict[str, float | int] [source]
- Overview:
The forward function for training the policy in learn mode, which is the core of the learning process. The data is sampled from the replay buffer, the loss is calculated by the loss function, and the loss is backpropagated to update the model.
- Parameters:
data (-) – The data sampled from the replay buffer, a tuple of tensors whose first element is current_batch and whose second element is target_batch.
- Returns:
The information dict to be logged, which contains current learning loss and learning statistics.
- Return type:
info_dict (Dict[str, Union[float, int]])
- _get_attribute(name: str) Any
- Overview:
To control access to policy attributes, the policy exposes its different modes to the outside instead of the policy instance itself, and this method gets an attribute of the policy from any mode.
- Parameters:
name (-) – The name of the attribute.
- Returns:
The value of the attribute.
- Return type:
value (Any)
Note
DI-engine’s policy first tries to call the _get_{name} method and then tries to access the _{name} attribute. If neither is found, it raises NotImplementedError.
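The lookup order in the note above can be sketched as follows. AttributeAccessSketch and its _get_device method are hypothetical; only the _get_{name}-then-_{name} order and the NotImplementedError come from the documentation.

```python
from typing import Any


class AttributeAccessSketch:
    """Hypothetical policy fragment showing the _get_{name} / _{name} lookup order."""

    def __init__(self):
        self._cuda = False  # reachable as attribute 'cuda'

    def _get_device(self) -> str:
        # Reachable as attribute 'device' through a getter method.
        return 'cuda' if self._cuda else 'cpu'

    def _get_attribute(self, name: str) -> Any:
        getter = getattr(self, '_get_' + name, None)
        if callable(getter):
            return getter()                   # 1) prefer a _get_{name} method
        if hasattr(self, '_' + name):
            return getattr(self, '_' + name)  # 2) fall back to a _{name} attribute
        raise NotImplementedError(f'attribute {name!r} is not available')


policy = AttributeAccessSketch()
print(policy._get_attribute('device'))  # cpu
print(policy._get_attribute('cuda'))    # False
```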
- _get_batch_size() int | Dict[str, int]
- _get_n_episode() int | None
- _get_n_sample() int | None
- _get_target_obs_index_in_step_k(step)
- Overview:
Get the begin index and end index of the target obs in step k.
- Parameters:
step (-) – The current step k.
- Returns:
beg_index (int): The begin index of the target obs in step k.
end_index (int): The end index of the target obs in step k.
- Return type:
Tuple[int, int]
Examples
>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
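A sketch that reproduces the example above, assuming each stacked frame occupies image_channel consecutive channels (conv) or observation_shape consecutive features (mlp), so the target for step k starts k frames in and spans frame_stack_num frames. The standalone function and the SimpleNamespace config are for illustration only.

```python
from types import SimpleNamespace
from typing import Tuple


def get_target_obs_index_in_step_k(model_cfg, step: int) -> Tuple[int, int]:
    """Slice bounds of the stacked observation used as the target at unroll step k."""
    if model_cfg.model_type == 'conv':
        width = model_cfg.image_channel      # channels per frame
    elif model_cfg.model_type == 'mlp':
        width = model_cfg.observation_shape  # features per frame (assumed to be an int)
    else:
        raise ValueError(f'unsupported model_type: {model_cfg.model_type}')
    beg_index = width * step
    end_index = width * (step + model_cfg.frame_stack_num)
    return beg_index, end_index


cfg = SimpleNamespace(model_type='conv', image_channel=3, frame_stack_num=4)
print(get_target_obs_index_in_step_k(cfg, 0))  # (0, 12)
print(get_target_obs_index_in_step_k(cfg, 2))  # (6, 18)
```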
- _get_train_sample(data)
- Overview:
For a given trajectory (a list of transitions), process it into a list of samples that can be used directly for training. A training sample can be a processed transition (e.g. DQN with n-step TD) or a multi-timestep transition (e.g. DRQN). This method is usually called in collectors to perform the necessary RL data preprocessing before training, which helps the learner amortize the relevant time cost. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.
- Parameters:
data (-) – The trajectory data (a list of transitions); each element has the same format as the return value of the self._process_transition method.
- Returns:
The processed training samples. Each element has a format similar to the input transitions but may contain additional training data, such as n-step rewards or advantages.
- Return type:
samples (List[Dict[str, Any]])
Note
We will vectorize the process_transition and get_train_sample methods in a future release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
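As an illustration of the n-step case mentioned above (not of what SampledMuZeroPolicy itself does), the hypothetical helper below attaches an n-step discounted reward sum to each transition, stopping at the end of an episode. Bootstrapping from a value estimate is deliberately left out.

```python
from typing import Any, Dict, List


def get_train_sample_nstep(transitions: List[Dict[str, Any]], nstep: int,
                           gamma: float) -> List[Dict[str, Any]]:
    """Attach an n-step discounted reward sum to every transition of one trajectory.

    The sum stops early at an episode end (done=True) or at the end of the list.
    """
    samples = []
    for t, tr in enumerate(transitions):
        ret, discount = 0.0, 1.0
        for k in range(t, min(t + nstep, len(transitions))):
            ret += discount * transitions[k]['reward']
            discount *= gamma
            if transitions[k]['done']:
                break
        samples.append({**tr, 'nstep_return': ret})  # new dict; input is not mutated
    return samples


traj = [{'reward': 1.0, 'done': False}, {'reward': 1.0, 'done': False},
        {'reward': 1.0, 'done': True}]
out = get_train_sample_nstep(traj, nstep=2, gamma=0.5)
print([s['nstep_return'] for s in out])  # [1.5, 1.5, 1.0]
```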
- _init_collect() None [source]
- Overview:
Collect mode init method. Called by self.__init__. Initializes the collect model and MCTS utils.
- _init_eval() None [source]
- Overview:
Evaluate mode init method. Called by self.__init__. Initializes the eval model and MCTS utils.
- _init_learn() None [source]
- Overview:
Learn mode init method. Called by self.__init__. Initializes the learn model, optimizer, and MCTS utils.
- _init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
- Overview:
Initialize the multi-GPU data-parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function that allreduces the gradients of the model parameters.
- Parameters:
model (-) – The neural network model to be trained.
bp_update_sync (-) – Whether to update the model parameters synchronously after the gradients have been allreduced. Asynchronous updates can proceed in parallel across different network layers, like a pipeline, which saves time.
- _load_state_dict_collect(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into the policy's collect mode, for example to load a pretrained state_dict, to auto-recover from a checkpoint, or to receive a model replica from the learner in distributed training.
- Parameters:
state_dict (-) – The dict of policy collect state saved before.
Tip
If you only want to load some parts of the model, you can set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_eval(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into the policy's eval mode, for example to auto-recover from a checkpoint or to receive a model replica from the learner in distributed training.
- Parameters:
state_dict (-) – The dict of policy eval state saved before.
Tip
If you only want to load some parts of the model, you can set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_learn(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy learn mode.
- Parameters:
state_dict (-) – The dict of policy learn state saved before.
- _monitor_vars_learn() List[str] [source]
- Overview:
Register the variables to be monitored in learn mode. The registered variables are logged to TensorBoard according to the return value of _forward_learn.
- _process_transition(obs, policy_output, timestep)
- Overview:
Process and pack the transition data of one timestep into a dict, such as <s, a, r, s’, done>. Some policies need special processing to pack their own necessary attributes (e.g. hidden state and logit), so this method is left to be implemented by the subclass.
- Parameters:
obs (-) – The observation of the current timestep.
policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.
timestep (-) – The execution result namedtuple returned by the environment step method, except all the elements have been transformed into tensor data. Usually, it contains the next obs, reward, done, info, etc.
- Returns:
The processed transition data of the current timestep.
- Return type:
transition (Dict[str, torch.Tensor])
- _reset_collect(data_id: List[int] | None = None) None
- Overview:
Reset the observation and action for the collector environment.
- Parameters:
data_id (-) – List of data ids to reset (not used in this implementation).
- _reset_eval(data_id: List[int] | None = None) None
- Overview:
Reset the observation and action for the evaluator environment.
- Parameters:
data_id (-) – List of data ids to reset (not used in this implementation).
- _reset_learn(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset; for example, different trajectories in data_id have different hidden states in an RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
- _set_attribute(name: str, value: Any) None
- Overview:
To control access to policy attributes, the policy exposes its different modes to the outside instead of the policy instance itself, and this method sets an attribute of the policy from any mode. The new attribute is named _{name}.
- Parameters:
name (-) – The name of the attribute.
value (-) – The value of the attribute.
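In its simplest form, the naming rule above is a setattr with a leading underscore; SetAttributeSketch is a hypothetical class used only to illustrate it.

```python
from typing import Any


class SetAttributeSketch:
    def _set_attribute(self, name: str, value: Any) -> None:
        # Stored with a leading underscore, matching the _{name} convention.
        setattr(self, '_' + name, value)


p = SetAttributeSketch()
p._set_attribute('eps', 0.05)
print(p._eps)  # 0.05
```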
- _state_dict_collect() Dict[str, Any]
- Overview:
Return the state_dict of collect mode, which usually only includes the model. It is needed in distributed training scenarios to auto-recover collectors.
- Returns:
The dict of current policy collect state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not all scenarios need to auto-recover collectors; sometimes we can simply shut down the crashed collector and start a new one.
- _state_dict_eval() Dict[str, Any]
- Overview:
Return the state_dict of eval mode, which usually only includes the model. It is needed in distributed training scenarios to auto-recover evaluators.
- Returns:
The dict of current policy eval state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not all scenarios need to auto-recover evaluators; sometimes we can simply shut down the crashed evaluator and start a new one.
- _state_dict_learn() Dict[str, Any]
- Overview:
Return the state_dict of learn mode, usually including model, target_model and optimizer.
- Returns:
The dict of current policy learn state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
- property cfg: EasyDict
- class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new collect_function object from a sequence or iterable
- _replace(**kwds)
Return a new collect_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- get_train_sample
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 7
- process_transition
Alias for field number 1
- reset
Alias for field number 3
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property collect_mode: collect_function
- Overview:
Return the interfaces of the collect mode of the policy, which is used to collect training data. The interfaces are defined as an immutable namedtuple to restrict how the policy is used in different modes. Derived subclasses can override the interfaces to customize their own collect mode.
- Returns:
The interfaces of the collect mode, a namedtuple whose fields are the corresponding internal methods of the policy.
- Return type:
interfaces (Policy.collect_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- config = {'action_type': 'fixed_action_space', 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'collector_env_num': 8, 'cos_lr_scheduler': False, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_offline': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 200, 'grad_clip_value': 10, 'gray_scale': False, 'gumbel_algo': False, 'ignore_done': False, 'init_w': 0.003, 'learning_rate': 0.0001, 'lstm_horizon_len': 5, 'manual_temperature_decay': False, 'mcts_ctree': True, 'model': {'action_space_size': 6, 'bias': True, 'categorical_distribution': True, 'continuous_action_space': False, 'discrete_action_encoding_type': 'one_hot', 'fixed_sigma_value': 0.3, 'frame_stack_num': 1, 'image_channel': 1, 'lstm_hidden_size': 512, 'model_type': 'conv', 'norm_type': 'LN', 'num_res_blocks': 1, 'observation_shape': (4, 96, 96), 'res_connection_in_dynamics': True, 'self_supervised_learning_loss': True, 'sigma_type': 'conditioned', 'support_scale': 300}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'multi_gpu': False, 'n_episode': 8, 'normalize_prob_of_sampled_actions': False, 'num_simulations': 50, 'num_unroll_steps': 5, 'optim_type': 'AdamW', 'piecewise_decay_lr_scheduler': False, 'policy_entropy_weight': 0.005, 'policy_loss_type': 'cross_entropy', 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'reanalyze_noise': True, 'replay_ratio': 0.25, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sampled_algo': True, 'ssl_loss_weight': 2, 'target_update_freq': 100, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 100000, 'transform2string': False, 'update_per_collect': None, 'use_augmentation': False, 
'use_priority': False, 'use_ture_chance_label_in_chance_encoder': False, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
- classmethod default_config() EasyDict
- Overview:
Get the default config of policy. This method is used to create the default config of policy.
- Returns:
The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.
- Return type:
cfg (EasyDict)
Tip
This method deep-copies the config attribute of the class and returns the result, so modifying the returned config does not affect the class defaults.
- default_model() Tuple[str, List[str]] [source]
- Overview:
Return the default model setting of this algorithm.
- Returns:
- model name and model import_names.
model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.
import_names (List[str]): The list of import paths for the model class used in this algorithm.
- Return type:
model_info (Tuple[str, List[str]])
Note
Users can define and use a customized network model, but it must follow the same interface as the model indicated by the import_names path. For Sampled MuZero, the default model is lzero.model.sampled_muzero_model.SampledMuZeroModel.
- class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new eval_function object from a sequence or iterable
- _replace(**kwds)
Return a new eval_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 5
- reset
Alias for field number 1
- set_attribute
Alias for field number 3
- state_dict
Alias for field number 4
- property eval_mode: eval_function
- Overview:
Return the interfaces of the eval mode of the policy, which is used to evaluate the policy. The interfaces are defined as an immutable namedtuple to restrict how the policy is used in different modes. Derived subclasses can override the interfaces to customize their own eval mode.
- Returns:
The interfaces of the eval mode, a namedtuple whose fields are the corresponding internal methods of the policy.
- Return type:
interfaces (Policy.eval_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new learn_function object from a sequence or iterable
- _replace(**kwds)
Return a new learn_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- info
Alias for field number 2
- load_state_dict
Alias for field number 7
- monitor_vars
Alias for field number 3
- reset
Alias for field number 1
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property learn_mode: learn_function
- Overview:
Return the interfaces of the learn mode of the policy, which is used to train the model. The interfaces are defined as an immutable namedtuple to restrict how the policy is used in different modes. Derived subclasses can override the interfaces to customize their own learn mode.
- Returns:
The interfaces of the learn mode, a namedtuple whose fields are the corresponding internal methods of the policy.
- Return type:
interfaces (Policy.learn_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
- set_train_iter_env_step(train_iter, env_step) None
- Overview:
Set the current training iteration (train_iter) and the current environment step count (env_step) for the policy.
- Parameters:
train_iter (-) – The current training iteration.
env_step (-) – The current number of environment steps.
- sync_gradients(model: Module) None
- Overview:
Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.
- Parameters:
model (-) – The model to synchronize gradients.
Note
This method is only used in multi-GPU training. It should be called after the backward method and before the optimizer's step method. The bp_update_sync config can also be used to control whether the gradient allreduce and the optimizer updates are synchronized.
- total_field = {'collect', 'eval', 'learn'}
Sampled EfficientZeroPolicy
- class lzero.policy.sampled_efficientzero.SampledEfficientZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Bases:
MuZeroPolicy
- Overview:
The policy class for Sampled EfficientZero proposed in the paper https://arxiv.org/abs/2104.06303.
- __init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
- Overview:
Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy: learn, collect, and eval. The learn field is used to train the policy, the collect field is used to collect training data, and the eval field is used to evaluate the policy. enable_field specifies which fields to initialize; if it is None, all fields are initialized.
- Parameters:
cfg (-) – The final merged config used to initialize the policy. For the default config, see the config attribute of the policy class and its comments.
model (-) – The neural network model used to initialize the policy. If it is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model instance supplied by the caller is used.
enable_field (-) – The list of fields to initialize. If it is None, all fields are initialized. Otherwise, only the fields in enable_field are initialized, which saves resources.
Note
A derived policy class should implement the _init_learn, _init_collect, and _init_eval methods to initialize the corresponding fields.
- _abc_impl = <_abc._abc_data object>
- _calculate_policy_loss_cont(policy_loss: Tensor, policy_logits: Tensor, target_policy: Tensor, mask_batch: Tensor, child_sampled_actions_batch: Tensor, unroll_step: int) Tuple[Tensor] [source]
- Overview:
Calculate the policy loss for continuous action space.
- Parameters:
policy_loss (-) – The policy loss tensor.
policy_logits (-) – The policy logits tensor.
target_policy (-) – The target policy tensor.
mask_batch (-) – The mask tensor.
child_sampled_actions_batch (-) – The child sampled actions tensor.
unroll_step (-) – The unroll step.
- Returns:
policy_loss (torch.Tensor): The policy loss tensor.
policy_entropy (torch.Tensor): The policy entropy tensor.
policy_entropy_loss (torch.Tensor): The policy entropy loss tensor.
target_policy_entropy (torch.Tensor): The target policy entropy tensor.
target_sampled_actions (torch.Tensor): The target sampled actions tensor.
mu (torch.Tensor): The mean (mu) tensor of the policy distribution.
sigma (torch.Tensor): The standard deviation (sigma) tensor of the policy distribution.
- Return type:
Tuple[torch.Tensor]
- _calculate_policy_loss_disc(policy_loss: Tensor, policy_logits: Tensor, target_policy: Tensor, mask_batch: Tensor, child_sampled_actions_batch: Tensor, unroll_step: int) Tuple[Tensor] [source]
- Overview:
Calculate the policy loss for discrete action space.
- Parameters:
policy_loss (-) – The policy loss tensor.
policy_logits (-) – The policy logits tensor.
target_policy (-) – The target policy tensor.
mask_batch (-) – The mask tensor.
child_sampled_actions_batch (-) – The child sampled actions tensor.
unroll_step (-) – The unroll step.
- Returns:
- policy_loss (torch.Tensor): The policy loss tensor.
- policy_entropy (torch.Tensor): The policy entropy tensor.
- policy_entropy_loss (torch.Tensor): The policy entropy loss tensor.
- target_policy_entropy (torch.Tensor): The target policy entropy tensor.
- target_sampled_actions (torch.Tensor): The target sampled actions tensor.
- Return type:
Tuple[torch.Tensor]
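For the default policy_loss_type of cross_entropy, a discrete policy loss of this kind compares the target distribution (derived from MCTS visit counts) with the softmax of the predicted logits. The plain-Python sketch below shows a masked cross-entropy and an entropy helper. It ignores the bookkeeping over child_sampled_actions_batch, and the helper names are illustrative, not LightZero functions.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def masked_policy_ce(policy_logits: List[List[float]],
                     target_policy: List[List[float]],
                     mask: List[float]) -> List[float]:
    """Per-sample cross-entropy -sum_a target(a) * log softmax(logits)(a), zeroed where mask == 0."""
    losses = []
    for logits, target, m in zip(policy_logits, target_policy, mask):
        log_probs = log_softmax(logits)
        ce = -sum(t * lp for t, lp in zip(target, log_probs))
        losses.append(ce * m)  # padded unroll steps (mask 0) contribute nothing
    return losses


def entropy(probs: List[float]) -> float:
    # Shannon entropy in nats; zero-probability entries are skipped.
    return -sum(p * math.log(p) for p in probs if p > 0)


print(masked_policy_ce([[0.0, 0.0]], [[0.5, 0.5]], [1.0]))  # roughly [0.693], i.e. log(2)
```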
- _create_model(cfg: EasyDict, model: Module | None = None) Module
- Overview:
Create or validate the neural network model according to the input configuration and model. If the input model is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model is verified to be an instance of torch.nn.Module and set to the model instance created by the outside caller.
- Parameters:
cfg (-) – The final merged config used to initialize policy.
model (-) – The neural network model used to initialize the policy. Users can refer to the default model defined in the corresponding policy to customize their own model.
- Returns:
The created neural network model. The different modes of the policy add distinct wrappers and plugins to this model, which are used for training, collecting and evaluating.
- Return type:
model (torch.nn.Module)
- Raises:
- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.
- _forward_collect(data: Tensor, action_mask: list = None, temperature: ndarray = 1, to_play=-1, epsilon: float = 0.25, ready_env_id: array = None)[source]
- Overview:
The forward function for collecting data in collect mode. It uses the model to execute the MCTS search and chooses the action by sampling.
- Parameters:
data (-) – The input data, i.e. the observation.
action_mask (-) – The action mask, which marks the legal actions (1) and the actions that cannot be selected (0).
temperature (-) – The temperature of the policy.
to_play (-) – The player to play.
ready_env_id (-) – The id of the env that is ready to collect.
- Shape:
- data (torch.Tensor): For Atari, \((N, C*S, H, W)\), where N is the number of collect_env, C is the number of channels, S is the number of stacked frames, H is the height of the image and W is the width of the image. For lunarlander, \((N, O)\), where N is the number of collect_env and O is the observation space size.
- action_mask: \((N, action_space_size)\), where N is the number of collect_env.
- temperature: \((1, )\).
- to_play: \((N, 1)\), where N is the number of collect_env.
- ready_env_id: None
- Returns:
A dict mapping each ready env id to its output dict, whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.
- Return type:
output (Dict[int, Any])
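Choosing the action "through sampling" usually means drawing from the MCTS root visit counts, sharpened or flattened by the temperature. A minimal stdlib sketch under that assumption follows; sample_action is a hypothetical helper, not a LightZero function, and the real method also handles action masks, Dirichlet root noise and batching over envs.

```python
import random
from typing import List, Optional


def sample_action(visit_counts: List[int], temperature: float = 1.0,
                  rng: Optional[random.Random] = None) -> int:
    """Sample an action index with probability proportional to N(a) ** (1 / temperature)."""
    if temperature <= 0:
        raise ValueError('temperature must be positive')
    # temperature < 1 sharpens the visit distribution, temperature > 1 flattens it.
    weights = [count ** (1.0 / temperature) for count in visit_counts]
    if rng is None:
        rng = random.Random()
    return rng.choices(range(len(visit_counts)), weights=weights, k=1)[0]


rng = random.Random(0)
print([sample_action([2, 10, 3], temperature=0.25, rng=rng) for _ in range(5)])
```

With a low temperature such as the config's fixed_temperature_value of 0.25, the most visited action dominates; at temperature 1 the choice follows the raw visit proportions.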
- _forward_eval(data: Tensor, action_mask: list, to_play: -1, ready_env_id: array = None)[source]
- Overview:
The forward function for evaluating the current policy in eval mode. It uses the model to execute the MCTS search and, unlike collect mode, chooses the action deterministically (argmax over the root visit-count distribution) rather than by sampling.
- Parameters:
data (-) – The input data, i.e. the observation.
action_mask (-) – The action mask, which marks the legal actions (1) and the actions that cannot be selected (0).
to_play (-) – The player to play.
ready_env_id (-) – The id of the env that is ready to evaluate.
- Shape:
- data (torch.Tensor): For Atari, \((N, C*S, H, W)\), where N is the number of evaluator_env, C is the number of channels, S is the number of stacked frames, H is the height of the image and W is the width of the image. For lunarlander, \((N, O)\), where N is the number of evaluator_env and O is the observation space size.
- action_mask: \((N, action_space_size)\), where N is the number of evaluator_env.
- to_play: \((N, 1)\), where N is the number of evaluator_env.
- ready_env_id: None
- Returns:
A dict mapping each ready env id to its output dict, whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.
- Return type:
output (Dict[int, Any])
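In eval mode the choice is greedy. The sketch below assumes that the greedy choice is an argmax over root visit counts and packs a reduced set of output keys into a dict keyed by env id. build_eval_output and its visit_counts/root_values inputs are hypothetical; the real method produces more keys.

```python
from typing import Any, Dict, List


def build_eval_output(ready_env_id: List[int], visit_counts: List[List[int]],
                      root_values: List[float]) -> Dict[int, Dict[str, Any]]:
    """Greedy (argmax) action per ready env, packed into a dict keyed by env id."""
    output: Dict[int, Dict[str, Any]] = {}
    for env_id, counts, value in zip(ready_env_id, visit_counts, root_values):
        # Most visited action; max() returns the first index on ties.
        action = max(range(len(counts)), key=lambda i: counts[i])
        output[env_id] = {'action': action, 'distributions': counts, 'value': value}
    return output


print(build_eval_output([3, 7], [[1, 9, 0], [4, 4, 2]], [0.5, -0.1]))
# {3: {'action': 1, 'distributions': [1, 9, 0], 'value': 0.5}, 7: {'action': 0, 'distributions': [4, 4, 2], 'value': -0.1}}
```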
- _forward_learn(data: Tensor) Dict[str, float | int] [source]
- Overview:
The forward function for learning the policy in learn mode, which is the core of the learning process. The data is sampled from the replay buffer, the loss is computed by the loss function, and the loss is backpropagated to update the model.
- Parameters:
data (-) – The data sampled from the replay buffer, which is a tuple of tensors. The first tensor is the current_batch and the second tensor is the target_batch.
- Returns:
The information dict to be logged, which contains the current learning loss and learning statistics.
- Return type:
info_dict (Dict[str, Union[float, int]])
- _get_attribute(name: str) Any
- Overview:
To control access to the policy attributes, the policy exposes its different modes to the outside instead of the policy instance itself. This method gets an attribute of the policy in these modes.
- Parameters:
name (-) – The name of the attribute.
- Returns:
The value of the attribute.
- Return type:
value (Any)
Note
DI-engine’s policy first tries to access the _get_{name} method and then the _{name} attribute. If neither is found, it raises a NotImplementedError.
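The lookup order in the Note can be sketched as follows. AttributeAccess is a hypothetical class written for illustration; a _set_attribute is included to show the matching _{name} naming used when setting attributes.

```python
from typing import Any


class AttributeAccess:
    """Hypothetical sketch of the _get_{name} -> _{name} lookup order."""

    def __init__(self) -> None:
        self._batch_size = 256  # plain attribute stored with a leading underscore

    def _get_n_sample(self) -> int:
        return 8  # a dedicated getter method

    def _get_attribute(self, name: str) -> Any:
        getter = getattr(self, '_get_' + name, None)
        if getter is not None:
            return getter()  # 1) a _get_{name} method takes priority
        if hasattr(self, '_' + name):
            return getattr(self, '_' + name)  # 2) fall back to the _{name} attribute
        raise NotImplementedError(name)  # 3) neither exists

    def _set_attribute(self, name: str, value: Any) -> None:
        setattr(self, '_' + name, value)  # stored as _{name}


policy = AttributeAccess()
print(policy._get_attribute('n_sample'), policy._get_attribute('batch_size'))  # 8 256
```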
- _get_batch_size() int | Dict[str, int]
- _get_n_episode() int | None
- _get_n_sample() int | None
- _get_target_obs_index_in_step_k(step)
- Overview:
Get the begin index and end index of the target obs in step k.
- Parameters:
step (-) – The current step k.
- Returns:
- beg_index (int): The begin index of the target obs in step k.
- end_index (int): The end index of the target obs in step k.
- Return type:
Tuple[int, int]
Examples
>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
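The example output is consistent with a channel-stacked layout in which each unroll step shifts the window by image_channel channels and the window spans frame_stack_num frames. The function below is a sketch under that assumption for the conv case only; it is not the library's code.

```python
from typing import Tuple


def target_obs_index_in_step_k(step: int, image_channel: int, frame_stack_num: int) -> Tuple[int, int]:
    """(beg_index, end_index) of the stacked target observation for unroll step k (conv case only)."""
    # Assumption: frames are stacked along the channel axis, so unroll step k starts
    # k frames (k * image_channel channels) later and covers frame_stack_num frames.
    beg_index = image_channel * step
    end_index = image_channel * (step + frame_stack_num)
    return beg_index, end_index


print(target_obs_index_in_step_k(0, image_channel=3, frame_stack_num=4))  # (0, 12)
print(target_obs_index_in_step_k(2, image_channel=3, frame_stack_num=4))  # (6, 18)
```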
- _get_train_sample(data)[source]
- Overview:
For the given trajectory data (transitions, i.e. a list of transitions), process it into a list of samples that can be used directly for training. A train sample can be a processed transition (DQN with n-step TD) or several multi-timestep transitions (DRQN). This method is usually called in collectors to execute the necessary RL data preprocessing before training, which helps the learner amortize the relevant time consumption. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.
- Parameters:
transitions (-) – The trajectory data (a list of transitions), each element of which has the same format as the return value of the self._process_transition method.
- Returns:
The processed train samples. Each element has a format similar to the input transitions but may contain more data for training, such as n-step reward, advantage, etc.
- Return type:
samples (List[Dict[str, Any]])
Note
The process_transition and get_train_sample methods will be vectorized in a future release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
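To make the "processed transition (DQN with n-step TD)" case concrete, here is a generic sketch that attaches a truncated n-step discounted return to each transition. It illustrates the idea only; it is not what this policy's _get_train_sample does, and the reward, done and nstep_reward keys are assumptions.

```python
from typing import Any, Dict, List


def nstep_train_samples(transitions: List[Dict[str, Any]], nstep: int, gamma: float) -> List[Dict[str, Any]]:
    """Attach a truncated n-step discounted return to each transition (DQN-style preprocessing)."""
    samples = []
    for t in range(len(transitions)):
        ret, discount = 0.0, 1.0
        for k in range(t, min(t + nstep, len(transitions))):
            ret += discount * transitions[k]['reward']
            discount *= gamma
            if transitions[k]['done']:
                break  # do not accumulate rewards across an episode boundary
        sample = dict(transitions[t])  # shallow copy keeps the input transition untouched
        sample['nstep_reward'] = ret
        samples.append(sample)
    return samples


trajectory = [{'reward': 1.0, 'done': False}, {'reward': 1.0, 'done': False}, {'reward': 1.0, 'done': True}]
print([s['nstep_reward'] for s in nstep_train_samples(trajectory, nstep=2, gamma=0.5)])  # [1.5, 1.5, 1.0]
```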
- _init_collect() None [source]
- Overview:
Collect mode init method, called by self.__init__. Initialize the collect model and MCTS utils.
- _init_eval() None [source]
- Overview:
Evaluate mode init method, called by self.__init__. Initialize the eval model and MCTS utils.
- _init_learn() None [source]
- Overview:
Learn mode init method, called by self.__init__. Initialize the learn model, optimizer and MCTS utils.
- _init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
- Overview:
Initialize the multi-GPU data-parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function that allreduces the gradients of the model parameters.
- Parameters:
model (-) – The neural network model to be trained.
bp_update_sync (-) – Whether to update the model parameters synchronously after allreducing their gradients. Asynchronous updates can run in parallel across different network layers, like a pipeline, which saves time.
- _load_state_dict_collect(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into the policy collect mode, e.g. when loading a pretrained state_dict, auto-recovering from a checkpoint, or receiving a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy collect state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_eval(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into the policy eval mode, e.g. when auto-recovering from a checkpoint or receiving a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy eval state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_learn(state_dict: Dict[str, Any]) None [source]
- Overview:
Load the state_dict variable into policy learn mode.
- Parameters:
state_dict (-) – The dict of policy learn state saved before.
- _monitor_vars_learn() List[str] [source]
- Overview:
Register the variables to be monitored in learn mode. The registered variables are logged to TensorBoard according to the return value of _forward_learn.
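One plausible way a logger consumes the registered names is to keep only those entries of the info dict returned by _forward_learn. select_monitored is a hypothetical helper, and the variable names below are examples, not this policy's actual list.

```python
from typing import Dict, List, Union

Number = Union[float, int]


def select_monitored(info_dict: Dict[str, Number], monitor_vars: List[str]) -> Dict[str, Number]:
    """Keep only the registered variables that _forward_learn actually returned."""
    return {name: info_dict[name] for name in monitor_vars if name in info_dict}


info = {'total_loss': 1.25, 'cur_lr': 1e-4, 'debug_counter': 3}
print(select_monitored(info, ['total_loss', 'cur_lr', 'value_loss']))  # {'total_loss': 1.25, 'cur_lr': 0.0001}
```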
- _process_transition(obs, policy_output, timestep)[source]
- Overview:
Process and pack the transition data of one timestep into a dict, such as <s, a, r, s’, done>. Some policies need special processing and must pack their own necessary attributes (e.g. hidden state and logit), so this method is left to subclasses to implement.
- Parameters:
obs (-) – The observation of the current timestep.
policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.
timestep (-) – The execution result namedtuple returned by the environment step method, except that all its elements have been converted into tensors. It usually contains the next obs, reward, done, info, etc.
- Returns:
The processed transition data of the current timestep.
- Return type:
transition (Dict[str, torch.Tensor])
- _reset_collect(data_id: List[int] | None = None) None
- Overview:
Reset the observation and action for the collector environment.
- Parameters:
data_id (-) – List of data ids to reset (not used in this implementation).
- _reset_eval(data_id: List[int] | None = None) None
- Overview:
Reset the observation and action for the evaluator environment.
- Parameters:
data_id (-) – List of data ids to reset (not used in this implementation).
- _reset_learn(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset; for example, different trajectories in data_id have different RNN hidden states.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by data_id.
Note
Implementing this method is optional; subclasses can override it if necessary.
- _set_attribute(name: str, value: Any) None
- Overview:
To control access to the policy attributes, the policy exposes its different modes to the outside instead of the policy instance itself. This method sets an attribute of the policy in these modes; the new attribute is named _{name}.
- Parameters:
name (-) – The name of the attribute.
value (-) – The value of the attribute.
- _state_dict_collect() Dict[str, Any]
- Overview:
Return the state_dict of collect mode, which usually contains only the model. It is needed in distributed training scenarios to auto-recover collectors.
- Returns:
The dict of the current policy collect state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not all scenarios need to auto-recover collectors; sometimes the crashed collector can simply be shut down and replaced with a new one.
- _state_dict_eval() Dict[str, Any]
- Overview:
Return the state_dict of eval mode, which usually contains only the model. It is needed in distributed training scenarios to auto-recover evaluators.
- Returns:
The dict of the current policy eval state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not all scenarios need to auto-recover evaluators; sometimes the crashed evaluator can simply be shut down and replaced with a new one.
- _state_dict_learn() Dict[str, Any] [source]
- Overview:
Return the state_dict of learn mode, usually including model and optimizer.
- Returns:
The dict of the current policy learn state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
- property cfg: EasyDict
- class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new collect_function object from a sequence or iterable
- _replace(**kwds)
Return a new collect_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- get_train_sample
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 7
- process_transition
Alias for field number 1
- reset
Alias for field number 3
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property collect_mode: collect_function
- Overview:
Return the interfaces of the collect mode of the policy, which is used to collect training data. A namedtuple is used to define immutable interfaces and restrict the usage of the policy in different modes. Derived subclasses can override these interfaces to customize their own collect mode.
- Returns:
The interfaces of the collect mode of the policy: a namedtuple whose fields map to different internal methods.
- Return type:
interfaces (Policy.collect_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
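The mode-interface pattern can be reproduced with a plain namedtuple: the property returns an immutable tuple of bound methods, so callers use policy.collect_mode.forward rather than reaching into the policy object. TinyPolicy and its two-field CollectFunction are hypothetical simplifications of the eight-field collect_function.

```python
from collections import namedtuple
from typing import Any, Dict, List, Optional

# Hypothetical two-field interface; the real collect_function has eight fields.
CollectFunction = namedtuple('CollectFunction', ['forward', 'reset'])


class TinyPolicy:
    def _forward_collect(self, obs: Any) -> Dict[str, Any]:
        return {'action': 0, 'obs': obs}  # placeholder "inference"

    def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
        pass

    @property
    def collect_mode(self) -> CollectFunction:
        # Expose only the intended bound methods, packed in an immutable tuple.
        return CollectFunction(self._forward_collect, self._reset_collect)


policy_collect = TinyPolicy().collect_mode
print(policy_collect.forward('obs')['action'])  # 0
```

Because namedtuple fields cannot be reassigned, code holding policy_collect cannot accidentally swap out forward.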
- config = {'action_type': 'fixed_action_space', 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'collector_env_num': 8, 'cos_lr_scheduler': False, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_offline': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 200, 'grad_clip_value': 10, 'gray_scale': False, 'gumbel_algo': False, 'ignore_done': False, 'init_w': 0.003, 'learning_rate': 0.0001, 'lstm_horizon_len': 5, 'manual_temperature_decay': False, 'mcts_ctree': True, 'model': {'action_space_size': 6, 'bias': True, 'categorical_distribution': True, 'continuous_action_space': False, 'discrete_action_encoding_type': 'one_hot', 'fixed_sigma_value': 0.3, 'frame_stack_num': 1, 'image_channel': 1, 'lstm_hidden_size': 512, 'model_type': 'conv', 'norm_type': 'LN', 'num_res_blocks': 1, 'observation_shape': (4, 96, 96), 'res_connection_in_dynamics': True, 'self_supervised_learning_loss': True, 'sigma_type': 'conditioned', 'support_scale': 300}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'multi_gpu': False, 'n_episode': 8, 'normalize_prob_of_sampled_actions': False, 'num_simulations': 50, 'num_unroll_steps': 5, 'optim_type': 'AdamW', 'piecewise_decay_lr_scheduler': False, 'policy_entropy_weight': 0.005, 'policy_loss_type': 'cross_entropy', 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'reanalyze_noise': True, 'replay_ratio': 0.25, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sampled_algo': True, 'ssl_loss_weight': 2, 'target_update_freq': 100, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 100000, 'transform2string': False, 'update_per_collect': None, 'use_augmentation': False, 
'use_priority': True, 'use_ture_chance_label_in_chance_encoder': False, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
- classmethod default_config() EasyDict
- Overview:
Get the default config of policy. This method is used to create the default config of policy.
- Returns:
The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.
- Return type:
cfg (EasyDict)
Tip
This method deep-copies the config attribute of the class and returns the result, so users do not need to worry about modifying the returned config.
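The deep copy matters because the class-level config contains nested dicts such as model. A stdlib sketch, assuming a plain dict in place of EasyDict and a hypothetical SketchPolicy class:

```python
import copy
from typing import Any, Dict


class SketchPolicy:
    # Hypothetical class-level default, standing in for the policy's config attribute.
    config: Dict[str, Any] = {'batch_size': 256, 'model': {'action_space_size': 6}}

    @classmethod
    def default_config(cls) -> Dict[str, Any]:
        # deepcopy, so nested edits by the caller never leak back into the class attribute.
        return copy.deepcopy(cls.config)


cfg = SketchPolicy.default_config()
cfg['model']['action_space_size'] = 18
print(SketchPolicy.config['model']['action_space_size'])  # 6
```

A shallow copy (dict.copy()) would share the nested model dict, so the edit above would then change the class default.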
- default_model() Tuple[str, List[str]] [source]
- Overview:
Return the default model setting of this algorithm.
- Returns:
model_info (Tuple[str, List[str]]): The model name and the model import_names.
- model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.
- import_names (List[str]): The list of model class paths used in this algorithm.
- Return type:
Tuple[str, List[str]]
Note
Users can define and use a customized network model, but it must follow the same interface definition as the model indicated by the import_names path. For Sampled EfficientZero, the default model is lzero.model.sampled_efficientzero_model.SampledEfficientZeroModel.
- class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new eval_function object from a sequence or iterable
- _replace(**kwds)
Return a new eval_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 5
- reset
Alias for field number 1
- set_attribute
Alias for field number 3
- state_dict
Alias for field number 4
- property eval_mode: eval_function
- Overview:
Return the interfaces of the eval mode of the policy, which is used to evaluate the model. A namedtuple is used to define immutable interfaces and restrict the usage of the policy in different modes. Derived subclasses can override these interfaces to customize their own eval mode.
- Returns:
The interfaces of the eval mode of the policy: a namedtuple whose fields map to different internal methods.
- Return type:
interfaces (Policy.eval_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new learn_function object from a sequence or iterable
- _replace(**kwds)
Return a new learn_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- info
Alias for field number 2
- load_state_dict
Alias for field number 7
- monitor_vars
Alias for field number 3
- reset
Alias for field number 1
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property learn_mode: learn_function
- Overview:
Return the interfaces of the learn mode of the policy, which is used to train the model. A namedtuple is used to define immutable interfaces and restrict the usage of the policy in different modes. Derived subclasses can override these interfaces to customize their own learn mode.
- Returns:
The interfaces of the learn mode of the policy: a namedtuple whose fields map to different internal methods.
- Return type:
interfaces (Policy.learn_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
- set_train_iter_env_step(train_iter, env_step) None
- Overview:
Set the train_iter and env_step for the policy.
- Parameters:
train_iter (-) – The train_iter for the policy.
env_step (-) – The env_step for the policy.
- sync_gradients(model: Module) None
- Overview:
Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.
- Parameters:
model (-) – The model to synchronize gradients.
Note
This method is only used in multi-GPU training and should be called after the backward method and before the step method. Users can also use the bp_update_sync config to control whether the gradient allreduce and the optimizer updates are synchronized.
- total_field = {'collect', 'eval', 'learn'}
Stochastic MuZeroPolicy
- class lzero.policy.stochastic_muzero.StochasticMuZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Bases:
MuZeroPolicy
- Overview:
The policy class for Stochastic MuZero proposed in the paper https://openreview.net/pdf?id=X6D9bAHhBQ1.
- __init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
- Overview:
Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy: learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. enable_field specifies which fields to initialize; if it is None, all fields are initialized.
- Parameters:
cfg (-) – The final merged config used to initialize the policy. For the default config, see the config attribute of the policy class and its comments.
model (-) – The neural network model used to initialize the policy. If it is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model is set to the model instance created by the outside caller.
enable_field (-) – The list of fields to initialize. If it is None, all fields are initialized. Otherwise, only the fields in enable_field are initialized, which saves resources.
Note
A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.
- _abc_impl = <_abc._abc_data object>
- _create_model(cfg: EasyDict, model: Module | None = None) Module
- Overview:
Create or validate the neural network model according to the input configuration and model. If the input model is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model is verified to be an instance of torch.nn.Module and set to the model instance created by the outside caller.
- Parameters:
cfg (-) – The final merged config used to initialize policy.
model (-) – The neural network model used to initialize the policy. Users can refer to the default model defined in the corresponding policy to customize their own model.
- Returns:
The created neural network model. The different modes of the policy add distinct wrappers and plugins to this model, which are used for training, collecting and evaluating.
- Return type:
model (torch.nn.Module)
- Raises:
- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.
- _forward_collect(data: Tensor, action_mask: list = None, temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, ready_env_id: array = None) Dict [source]
- Overview:
The forward function for collecting data in collect mode. It uses the model to execute the MCTS search and chooses the action by sampling.
- Parameters:
data (-) – The input data, i.e. the observation.
action_mask (-) – The action mask, which marks the legal actions (1) and the actions that cannot be selected (0).
temperature (-) – The temperature of the policy.
to_play (-) – The player to play.
ready_env_id (-) – The id of the env that is ready to collect.
- Shape:
- data (torch.Tensor): For Atari, its shape is \((N, C*S, H, W)\), where N is the number of collect_env, C is the number of channels, S is the number of stacked frames, H is the height of the image and W is the width of the image. For lunarlander, its shape is \((N, O)\), where N is the number of collect_env and O is the observation space size.
- action_mask: \((N, action_space_size)\), where N is the number of collect_env.
- temperature: \((1, )\).
- to_play: \((N, 1)\), where N is the number of collect_env.
- ready_env_id: None
- Returns:
A dict mapping each ready env id to its output dict, whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.
- Return type:
output (Dict[int, Any])
- _forward_eval(data: Tensor, action_mask: list, to_play: int = -1, ready_env_id: array = None) Dict [source]
- Overview:
The forward function for evaluating the current policy in eval mode. It uses the model to execute the MCTS search and, unlike collect mode, chooses the action deterministically (argmax over the root visit-count distribution) rather than by sampling.
- Parameters:
data (-) – The input data, i.e. the observation.
action_mask (-) – The action mask, which marks the legal actions (1) and the actions that cannot be selected (0).
to_play (-) – The player to play.
ready_env_id (-) – The id of the env that is ready to evaluate.
- Shape:
- data (torch.Tensor): For Atari, \((N, C*S, H, W)\), where N is the number of evaluator_env, C is the number of channels, S is the number of stacked frames, H is the height of the image and W is the width of the image. For lunarlander, \((N, O)\), where N is the number of evaluator_env and O is the observation space size.
- action_mask: \((N, action_space_size)\), where N is the number of evaluator_env.
- to_play: \((N, 1)\), where N is the number of evaluator_env.
- ready_env_id: None
- Returns:
A dict mapping each ready env id to its output dict, whose keys include action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.
- Return type:
output (Dict[int, Any])
- _forward_learn(data: Tuple[Tensor]) Dict[str, float | int] [source]
- Overview:
The forward function for learning the policy in learn mode, which is the core of the learning process. The data is sampled from the replay buffer, the loss is computed by the loss function, and the loss is backpropagated to update the model.
- Parameters:
data (-) – The data sampled from the replay buffer, which is a tuple of tensors. The first tensor is the current_batch and the second tensor is the target_batch.
- Returns:
The information dict to be logged, which contains the current learning loss and learning statistics.
- Return type:
info_dict (Dict[str, Union[float, int]])
- _get_attribute(name: str) Any
- Overview:
To control access to the policy attributes, the policy exposes its different modes to the outside instead of the policy instance itself. This method gets an attribute of the policy in these modes.
- Parameters:
name (-) – The name of the attribute.
- Returns:
The value of the attribute.
- Return type:
value (Any)
Note
DI-engine’s policy first tries to access the _get_{name} method and then the _{name} attribute. If neither is found, it raises a NotImplementedError.
- _get_batch_size() int | Dict[str, int]
- _get_n_episode() int | None
- _get_n_sample() int | None
- _get_target_obs_index_in_step_k(step)
- Overview:
Get the begin index and end index of the target obs in step k.
- Parameters:
step (-) – The current step k.
- Returns:
- beg_index (int): The begin index of the target obs in step k.
- end_index (int): The end index of the target obs in step k.
- Return type:
Tuple[int, int]
Examples
>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
- _get_train_sample(data)[source]
- Overview:
For the given trajectory data (transitions, i.e. a list of transitions), process it into a list of samples that can be used directly for training. A train sample can be a processed transition (DQN with n-step TD) or several multi-timestep transitions (DRQN). This method is usually called in collectors to execute the necessary RL data preprocessing before training, which helps the learner amortize the relevant time consumption. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method.
- Parameters:
transitions (-) – The trajectory data (a list of transitions), each element of which has the same format as the return value of the self._process_transition method.
- Returns:
The processed train samples. Each element has a format similar to the input transitions but may contain more data for training, such as n-step reward, advantage, etc.
- Return type:
samples (List[Dict[str, Any]])
Note
The process_transition and get_train_sample methods will be vectorized in a future release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
- _init_collect() None [source]
- Overview:
Collect mode init method, called by self.__init__. Initialize the collect model and MCTS utils.
- _init_eval() None [source]
- Overview:
Evaluate mode init method, called by self.__init__. Initialize the eval model and MCTS utils.
- _init_learn() None [source]
- Overview:
Learn mode init method, called by self.__init__. Initialize the learn model, optimizer and MCTS utils.
- _init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
- Overview:
Initialize the multi-GPU data-parallel training setting, including broadcasting the model parameters at the beginning of training and preparing the hook function that allreduces the gradients of the model parameters.
- Parameters:
model (-) – The neural network model to be trained.
bp_update_sync (-) – Whether to update the model parameters synchronously after allreducing their gradients. Asynchronous updates can run in parallel across different network layers, like a pipeline, which saves time.
- _load_state_dict_collect(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy collect mode, for example to load a pretrained state_dict, auto-recover from a checkpoint, or receive a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy collect state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
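Besides strict=False, a common way to load only part of a model is to filter the saved dict by key prefix first. A minimal plain-Python sketch; the helper name and the 'encoder.' prefix are hypothetical, and the filtered dict would then be passed to the matching submodule's load_state_dict.

```python
from typing import Any, Dict


def select_submodule_state(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """Keep only entries whose keys start with `prefix`, and strip that prefix from the keys."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
```

Stripping the prefix makes the keys line up with the submodule's own parameter names, so the submodule can be loaded with strict=True and still catch genuine mismatches.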
- _load_state_dict_eval(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy eval mode, for example to auto-recover from a checkpoint or receive a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy eval state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_learn(state_dict: Dict[str, Any]) None [source]
- Overview:
Load the state_dict variable into policy learn mode.
- Parameters:
state_dict (-) – The dict of policy learn state saved before.
- _monitor_vars_learn() List[str] [source]
- Overview:
Register the variables to be monitored in learn mode. The registered variables are logged to TensorBoard based on the values returned by _forward_learn.
- _process_transition(obs, policy_output, timestep)[source]
- Overview:
Process and pack the transition data of one timestep into a dict, such as <s, a, r, s’, done>. Some policies need special processing and must pack their own necessary attributes (e.g. hidden state and logit), so this method is left to be implemented by the subclass.
- Parameters:
obs (-) – The observation of the current timestep.
policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.
timestep (-) – The execution result namedtuple returned by the environment step method, except that all elements have been converted to tensors. Usually, it contains the next obs, reward, done, info, etc.
- Returns:
The processed transition data of the current timestep.
- Return type:
transition (Dict[str, torch.Tensor])
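A minimal sketch of packing one transition, assuming a policy output dict with an 'action' key and a timestep namedtuple whose obs field holds the next observation. Timestep and the output field names are hypothetical, and plain Python values stand in for tensors; real policies add their own attributes, as the overview says.

```python
from collections import namedtuple
from typing import Any, Dict

# Hypothetical stand-in for the environment's timestep namedtuple.
Timestep = namedtuple('Timestep', ['obs', 'reward', 'done', 'info'])


def process_transition(obs: Any, policy_output: Dict[str, Any], timestep: Timestep) -> Dict[str, Any]:
    """Pack one <s, a, r, s', done> transition into a dict."""
    return {
        'obs': obs,
        'action': policy_output['action'],
        'reward': timestep.reward,
        'next_obs': timestep.obs,  # the timestep carries the *next* observation
        'done': timestep.done,
    }
```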
- _reset_collect(data_id: List[int] | None = None) None
- Overview:
Reset the observation and action for the collector environment.
- Parameters:
data_id (-) – List of data ids to reset (not used in this implementation).
- _reset_eval(data_id: List[int] | None = None) None
- Overview:
Reset the observation and action for the evaluator environment.
- Parameters:
data_id (-) – List of data ids to reset (not used in this implementation).
- _reset_learn(data_id: List[int] | None = None) None
- Overview:
Reset stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset; for example, different trajectories in data_id have different RNN hidden states.
- Parameters:
data_id (-) – The ids of the data whose stateful variables should be reset.
Note
Implementing this method is optional; subclasses can override it if necessary.
- _set_attribute(name: str, value: Any) None
- Overview:
To control access to policy attributes, the policy exposes different modes to the outside instead of the policy instance itself. This method sets an attribute of the policy in these modes; the new attribute is stored as _{name}.
- Parameters:
name (-) – The name of the attribute.
value (-) – The value of the attribute.
- _state_dict_collect() Dict[str, Any]
- Overview:
Return the state_dict of collect mode, which usually includes only the model. It is needed in distributed training scenarios to auto-recover collectors.
- Returns:
The dict of current policy collect state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not all scenarios need to auto-recover collectors; sometimes it is simpler to shut down the crashed collector and start a new one.
- _state_dict_eval() Dict[str, Any]
- Overview:
Return the state_dict of eval mode, which usually includes only the model. It is needed in distributed training scenarios to auto-recover evaluators.
- Returns:
The dict of current policy eval state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not all scenarios need to auto-recover evaluators; sometimes it is simpler to shut down the crashed evaluator and start a new one.
- _state_dict_learn() Dict[str, Any] [source]
- Overview:
Return the state_dict of learn mode, usually including model, target_model and optimizer.
- Returns:
The dict of current policy learn state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
- property cfg: EasyDict
- class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new collect_function object from a sequence or iterable
- _replace(**kwds)
Return a new collect_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- get_train_sample
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 7
- process_transition
Alias for field number 1
- reset
Alias for field number 3
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property collect_mode: collect_function
- Overview:
Return the interfaces of collect mode of the policy, which is used to collect training data. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.
- Returns:
The interfaces of collect mode of the policy, a namedtuple whose fields are different internal methods.
- Return type:
interfaces (Policy.collect_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
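The collect_function namedtuple documented above is what collect_mode returns. A minimal sketch of the pattern, using a hypothetical ToyPolicy whose methods are no-ops except forward, shows why a namedtuple is used: callers only see the bound methods listed in the fields, and those fields cannot be reassigned.

```python
from collections import namedtuple

# Same field order as the collect_function namedtuple documented above.
CollectFunction = namedtuple(
    'CollectFunction',
    ['forward', 'process_transition', 'get_train_sample', 'reset',
     'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict'],
)


class ToyPolicy:
    """Hypothetical stand-in for a policy; only forward does anything."""

    def _forward_collect(self, obs):
        # One output per ready environment id.
        return {env_id: {'action': 0} for env_id in obs}

    def _noop(self, *args, **kwargs):
        return None

    @property
    def collect_mode(self) -> CollectFunction:
        # Expose bound methods through an immutable tuple instead of the policy itself.
        return CollectFunction(
            forward=self._forward_collect,
            process_transition=self._noop,
            get_train_sample=self._noop,
            reset=self._noop,
            get_attribute=self._noop,
            set_attribute=self._noop,
            state_dict=self._noop,
            load_state_dict=self._noop,
        )
```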
- config = {'action_type': 'fixed_action_space', 'afterstate_policy_loss_weight': 1, 'afterstate_value_loss_weight': 0.25, 'analyze_chance_distribution': False, 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'collector_env_num': 8, 'commitment_loss_weight': 1.0, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_offline': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 200, 'grad_clip_value': 10, 'gray_scale': False, 'gumbel_algo': False, 'ignore_done': False, 'learning_rate': 0, 'manual_temperature_decay': False, 'mcts_ctree': True, 'model': {'bias': True, 'categorical_distribution': True, 'chance_space_size': 2, 'continuous_action_space': False, 'frame_stack_num': 1, 'image_channel': 1, 'model_type': 'conv', 'num_channels': 64, 'num_res_blocks': 1, 'observation_shape': (4, 96, 96), 'self_supervised_learning_loss': False, 'support_scale': 300}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'n_episode': 8, 'num_simulations': 50, 'num_unroll_steps': 5, 'optim_type': 'Adam', 'piecewise_decay_lr_scheduler': False, 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'reanalyze_noise': True, 'replay_ratio': 0.25, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sampled_algo': False, 'ssl_loss_weight': 0, 'target_update_freq': 100, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 100000, 'transform2string': False, 'update_per_collect': 100, 'use_augmentation': False, 'use_max_priority_for_new_data': True, 'use_priority': True, 'use_ture_chance_label_in_chance_encoder': False, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
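The eps entry in this config describes a linear exploration schedule ('start': 1.0, 'end': 0.05, 'decay': 100000, 'type': 'linear'). Assuming 'linear' means straight-line interpolation from start to end over decay steps and then holding at end, the schedule can be sketched as below; the function name is hypothetical, and the schedule only matters when eps_greedy_exploration_in_collect is enabled.

```python
def linear_epsilon(step: int, start: float = 1.0, end: float = 0.05, decay: int = 100000) -> float:
    """Anneal epsilon linearly from `start` to `end` over `decay` steps, then hold at `end`."""
    if step >= decay:
        return end
    return start - (start - end) * step / decay
```

Halfway through the decay window (step 50000) this gives 0.525.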
- classmethod default_config() EasyDict
- Overview:
Get the default config of the policy.
- Returns:
The default config of the corresponding policy. For a derived policy class, the default config of the base class is recursively merged with its own default config.
- Return type:
cfg (EasyDict)
Tip
This method deep-copies the config attribute of the class and returns the result, so users don’t need to worry about modifying the returned config.
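The deep copy mentioned in the tip matters because the default config is nested. A minimal sketch with a plain dict standing in for EasyDict (an assumption, to keep it dependency-free) and a hypothetical ToyPolicy: editing a nested field of the returned config leaves the class-level config untouched, which a shallow copy would not guarantee.

```python
import copy


class ToyPolicy:
    # Class-level default config, shared by every instance (plain dict instead of EasyDict).
    config = {'batch_size': 256, 'model': {'num_channels': 64}}

    @classmethod
    def default_config(cls) -> dict:
        # Deep copy so callers can edit nested fields without touching `cls.config`.
        return copy.deepcopy(cls.config)


cfg = ToyPolicy.default_config()
cfg['model']['num_channels'] = 128  # only the copy changes
```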
- default_model() Tuple[str, List[str]] [source]
- Overview:
Return the default model setting of this algorithm.
- Returns:
The model name and the model import_names:
- model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.
- import_names (List[str]): The list of module paths to import for the model class used in this algorithm.
- Return type:
model_info (Tuple[str, List[str]])
Note
Users can define and use a customized network model, but it must follow the same interface as the model indicated by the import_names path. For example, MuZero uses lzero.model.muzero_model.MuZeroModel.
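default_model only returns names; the model class is resolved later. Since the model type is registered in ModelRegistry, the modules in import_names presumably have to be imported first so that registration runs. The sketch below skips the registry entirely and assumes the registered name equals the class name, which is a simplification; resolve_model_class is hypothetical, and the usage example uses standard-library classes so it runs anywhere.

```python
import importlib
from typing import List, Tuple


def resolve_model_class(model_info: Tuple[str, List[str]]):
    """Import each module in import_names and return the first attribute named model_type."""
    model_type, import_names = model_info
    for module_name in import_names:
        module = importlib.import_module(module_name)
        if hasattr(module, model_type):
            return getattr(module, model_type)
    raise ValueError(f'{model_type} not found in {import_names}')


# Usage with standard-library stand-ins for a model class.
counter_cls = resolve_model_class(('Counter', ['json', 'collections']))
```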
- class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new eval_function object from a sequence or iterable
- _replace(**kwds)
Return a new eval_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 5
- reset
Alias for field number 1
- set_attribute
Alias for field number 3
- state_dict
Alias for field number 4
- property eval_mode: eval_function
- Overview:
Return the interfaces of eval mode of the policy, which is used to evaluate the model. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.
- Returns:
The interfaces of eval mode of the policy, a namedtuple whose fields are different internal methods.
- Return type:
interfaces (Policy.eval_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new learn_function object from a sequence or iterable
- _replace(**kwds)
Return a new learn_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- info
Alias for field number 2
- load_state_dict
Alias for field number 7
- monitor_vars
Alias for field number 3
- reset
Alias for field number 1
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property learn_mode: learn_function
- Overview:
Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own learn mode.
- Returns:
The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (Policy.learn_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
- set_train_iter_env_step(train_iter, env_step) None
- Overview:
Set the train_iter and env_step for the policy.
- Parameters:
train_iter (-) – The current training iteration of the policy.
env_step (-) – The current number of environment steps of the policy.
- sync_gradients(model: Module) None
- Overview:
Synchronize (allreduce) the gradients of the model parameters in data-parallel multi-GPU training.
- Parameters:
model (-) – The model whose gradients are synchronized.
Note
This method is only used in multi-GPU training, and it should be called after the backward method and before the optimizer step method. Users can also use the bp_update_sync config to control whether gradient allreduce and optimizer updates are synchronized.
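What the allreduce in sync_gradients computes can be illustrated without GPUs. The pure-Python simulation below assumes the gradients are averaged (summed across workers, then divided by the number of workers), the usual convention in data-parallel training; allreduce_mean is a hypothetical name and plain lists stand in for gradient tensors. In real training the order is loss.backward(), then gradient synchronization, then optimizer.step(), as the note says.

```python
from typing import List


def allreduce_mean(per_worker_grads: List[List[float]]) -> List[List[float]]:
    """Simulate an allreduce-mean: every worker ends up with the element-wise average."""
    world_size = len(per_worker_grads)
    summed = [sum(values) for values in zip(*per_worker_grads)]
    averaged = [total / world_size for total in summed]
    # After allreduce, each rank holds its own identical copy of the averaged gradient.
    return [list(averaged) for _ in range(world_size)]
```

Because every rank applies the same averaged gradient, the model replicas stay identical after each optimizer step, which is why the parameters only need to be broadcast once at the start.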
- total_field = {'collect', 'eval', 'learn'}
UniZeroPolicy
- class lzero.policy.unizero.UniZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Bases:
MuZeroPolicy
- Overview:
The policy class for UniZero, the official implementation of the paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models. UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations of MuZero-style algorithms, particularly in environments that require capturing long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.
- __init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
- Overview:
Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy, including learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. enable_field specifies which fields to initialize; if it is None, all fields are initialized.
- Parameters:
cfg (-) – The final merged config used to initialize the policy. For the default config, see the config attribute of the policy class and its comments.
model (-) – The neural network model used to initialize the policy. If it is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model is set to the model instance created by the outside caller.
enable_field (-) – The list of fields to initialize. If it is None, all fields are initialized. Otherwise, only the fields in enable_field are initialized, which helps save resources.
Note
A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.
- _abc_impl = <_abc._abc_data object>
- _create_model(cfg: EasyDict, model: Module | None = None) Module
- Overview:
Create or validate the neural network model according to the input config and model. If the input model is None, the model is created according to the default_model method and the cfg.model field. Otherwise, the model is verified to be an instance of torch.nn.Module and set to the model instance created by the outside caller.
- Parameters:
cfg (-) – The final merged config used to initialize the policy.
model (-) – The neural network model used to initialize the policy. Users can refer to the default model defined in the corresponding policy to customize their own model.
- Returns:
The created neural network model. The different modes of the policy add distinct wrappers and plugins to the model, which are used to train, collect and evaluate.
- Return type:
model (torch.nn.Module)
- Raises:
- RuntimeError – If the input model is not None and is not an instance of torch.nn.Module.
- _forward_collect(data: Tensor, action_mask: list = None, temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, ready_env_id: array = None) Dict [source]
- Overview:
The forward function for collecting data in collect mode. It uses the model to run MCTS search and, in collect mode, chooses actions by sampling.
- Parameters:
data (-) – The input data, i.e. the observation.
action_mask (-) – The action mask, which marks the legal actions (1) and the actions that cannot be selected (0).
temperature (-) – The temperature of the policy.
to_play (-) – The player to play.
epsilon (-) – The probability of choosing a random legal action when epsilon-greedy exploration is enabled (see eps.eps_greedy_exploration_in_collect in the config).
ready_env_id (-) – The ids of the environments that are ready to collect.
- Shape:
- data (torch.Tensor): For Atari, \((N, C*S, H, W)\), where N is the number of collect_env, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.
For lunarlander, \((N, O)\), where N is the number of collect_env, O is the observation space size.
- action_mask: \((N, action_space_size)\), where N is the number of collect_env.
- temperature: \((1, )\).
- to_play: \((N, 1)\), where N is the number of collect_env.
- ready_env_id: None
- Returns:
Dict type data, with keys including action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.
- Return type:
output (Dict[int, Any])
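Sampling in collect mode is typically done from the MCTS visit counts at the root, sharpened or flattened by the temperature. The sketch below assumes visit counts indexed over the full action space and an action_mask in which 1 marks a legal action; sample_action is hypothetical, it is a simplified illustration of temperature sampling rather than LightZero's own action-selection code, and it leaves out the epsilon-greedy branch.

```python
import random
from typing import Optional, Sequence


def sample_action(
    visit_counts: Sequence[int],
    action_mask: Sequence[int],
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Sample a legal action with probability proportional to visit_count ** (1 / temperature)."""
    rng = rng or random.Random()
    legal = [a for a, flag in enumerate(action_mask) if flag == 1]
    max_count = max(visit_counts[a] for a in legal)
    if max_count == 0:
        return rng.choice(legal)  # no visits yet: fall back to a uniform legal action
    # Normalizing by the max count keeps the power from overflowing at low temperatures.
    weights = [(visit_counts[a] / max_count) ** (1.0 / temperature) for a in legal]
    return rng.choices(legal, weights=weights, k=1)[0]
```

A temperature of 1 samples in proportion to the visit counts; as the temperature approaches 0 the choice approaches the argmax used in eval mode.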
- _forward_eval(data: Tensor, action_mask: list, to_play: int = -1, ready_env_id: array = None) Dict [source]
- Overview:
The forward function for evaluating the current policy in eval mode. It uses the model to run MCTS search and, in eval mode, selects the action with the highest visit count (argmax) instead of sampling.
- Parameters:
data (-) – The input data, i.e. the observation.
action_mask (-) – The action mask, which marks the legal actions (1) and the actions that cannot be selected (0).
to_play (-) – The player to play.
ready_env_id (-) – The ids of the environments that are ready to be evaluated.
- Shape:
- data (torch.Tensor): For Atari, \((N, C*S, H, W)\), where N is the number of evaluator envs, C is the number of channels, S is the number of stacked frames, H is the height of the image, W is the width of the image.
For lunarlander, \((N, O)\), where N is the number of evaluator envs, O is the observation space size.
- action_mask: \((N, action_space_size)\), where N is the number of evaluator envs.
- to_play: \((N, 1)\), where N is the number of evaluator envs.
- ready_env_id: None
- Returns:
Dict type data, with keys including action, distributions, visit_count_distribution_entropy, value, pred_value and policy_logits.
- Return type:
output (Dict[int, Any])
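In eval mode the sampling step is replaced by an argmax over visit counts. A matching sketch under the same assumptions (visit counts over the full action space, 1 = legal in action_mask); greedy_action is hypothetical, and ties go to the lowest action index here, which may differ from the library.

```python
from typing import Sequence


def greedy_action(visit_counts: Sequence[int], action_mask: Sequence[int]) -> int:
    """Return the legal action with the most MCTS visits (ties go to the lowest index)."""
    legal = [a for a, flag in enumerate(action_mask) if flag == 1]
    # max() keeps the first maximal element, and `legal` is in ascending order.
    return max(legal, key=lambda a: visit_counts[a])
```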
- _forward_learn(data: Tuple[Tensor]) Dict[str, float | int] [source]
- Overview:
The forward function for learning policy in learn mode, which is the core of the learning process. The data is sampled from replay buffer. The loss is calculated by the loss function and the loss is backpropagated to update the model.
- Parameters:
data (-) – The data sampled from replay buffer, which is a tuple of tensors. The first tensor is the current_batch, the second tensor is the target_batch.
- Returns:
The information dict to be logged, which contains current learning loss and learning statistics.
- Return type:
info_dict (Dict[str, Union[float, int]])
- _get_attribute(name: str) Any
- Overview:
To control access to policy attributes, the policy exposes different modes to the outside instead of the policy instance itself. This method gets an attribute of the policy in these modes.
- Parameters:
name (-) – The name of the attribute.
- Returns:
The value of the attribute.
- Return type:
value (Any)
Note
DI-engine’s policy first tries to access the _get_{name} method, and then the _{name} attribute. If neither is found, it raises a NotImplementedError.
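The lookup order in the note (first _get_{name}, then _{name}, else NotImplementedError), together with the _{name} naming used by _set_attribute, can be sketched as a small mixin. AttributeAccessMixin and ToyPolicy are hypothetical and this is not DI-engine's actual code.

```python
from typing import Any


class AttributeAccessMixin:
    """Sketch of the _get_attribute / _set_attribute convention."""

    def _set_attribute(self, name: str, value: Any) -> None:
        # New attributes are stored with a leading underscore, i.e. as `_{name}`.
        setattr(self, '_' + name, value)

    def _get_attribute(self, name: str) -> Any:
        # 1) Prefer a dedicated getter `_get_{name}`.
        getter = getattr(self, '_get_' + name, None)
        if callable(getter):
            return getter()
        # 2) Fall back to the plain `_{name}` attribute.
        if hasattr(self, '_' + name):
            return getattr(self, '_' + name)
        raise NotImplementedError(name)


class ToyPolicy(AttributeAccessMixin):
    def _get_batch_size(self) -> int:
        return 256
```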
- _get_batch_size() int | Dict[str, int]
- _get_n_episode() int | None
- _get_n_sample() int | None
- _get_target_obs_index_in_step_k(step)
- Overview:
Get the begin index and end index of the target obs in step k.
- Parameters:
step (-) – The current step k.
- Returns:
- beg_index (int): The begin index of the target obs in step k.
- end_index (int): The end index of the target obs in step k.
- Return type:
Tuple[int, int]
Examples
>>> self._cfg.model.model_type = 'conv'
>>> self._cfg.model.image_channel = 3
>>> self._cfg.model.frame_stack_num = 4
>>> self._get_target_obs_index_in_step_k(0)
(0, 12)
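The example above is consistent with indices that advance by one frame's worth of channels per unroll step: with 3 channels per frame and 4 stacked frames, step 0 covers the half-open channel range [0, 12). A minimal sketch of that arithmetic; the function name is hypothetical, and the 'mlp' branch, which treats observation_shape as the flat size of one frame, is an assumption.

```python
from typing import Tuple


def target_obs_index_in_step_k(
    step: int,
    model_type: str,
    frame_stack_num: int,
    image_channel: int = 3,
    observation_shape: int = 8,
) -> Tuple[int, int]:
    """Return (beg_index, end_index) of the stacked target observation for unroll step k."""
    if model_type == 'conv':
        width = image_channel  # channels per frame in a channel-stacked image
    elif model_type == 'mlp':
        width = observation_shape  # features per frame in a flat vector observation
    else:
        raise ValueError(f'unsupported model_type: {model_type}')
    beg_index = width * step
    end_index = width * (step + frame_stack_num)
    return beg_index, end_index
```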
- _get_train_sample(data)
- Overview:
For a given trajectory (a list of transitions), process it into a list of samples that can be used directly for training. A train sample can be a processed transition (e.g. DQN with n-step TD) or a group of multi-timestep transitions (e.g. DRQN). This method is usually called in collectors to run the necessary RL data preprocessing before training, which helps the learner amortize the relevant time cost. Alternatively, you can implement this method as an identity function and do the data processing in the self._forward_learn method instead.
- Parameters:
transitions (-) – The trajectory data (a list of transitions); each element has the same format as the return value of the self._process_transition method.
- Returns:
The processed train samples. Each element has a format similar to the input transitions but may contain additional training data, such as the n-step reward or advantage.
- Return type:
samples (List[Dict[str, Any]])
Note
We will vectorize the process_transition and get_train_sample methods in a future release. Users can customize this data processing procedure by overriding these two methods and the collector itself.
- _init_collect() None [source]
- Overview:
Collect mode init method. Called by self.__init__. Initialize the collect model and MCTS utils.
- _init_eval() None [source]
- Overview:
Evaluate mode init method. Called by self.__init__. Initialize the eval model and MCTS utils.
- _init_learn() None [source]
- Overview:
Learn mode init method. Called by self.__init__. Initialize the learn model, optimizer and MCTS utils.
- _init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
- Overview:
Initialize the multi-GPU data-parallel training setting: broadcast the model parameters at the beginning of training, and prepare the hook functions that allreduce the gradients of the model parameters.
- Parameters:
model (-) – The neural network model to be trained.
bp_update_sync (-) – Whether to update the model parameters synchronously after the gradients have been allreduced. Asynchronous updates can overlap across different network layers, like a pipeline, which can save time.
- _load_state_dict_collect(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy collect mode, for example to load a pretrained state_dict, auto-recover from a checkpoint, or receive a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy collect state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_eval(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy eval mode, for example to auto-recover from a checkpoint or receive a model replica from the learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy eval state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_learn(state_dict: Dict[str, Any]) None [source]
- Overview:
Load the state_dict variable into policy learn mode.
- Parameters:
state_dict (-) – The dict of policy learn state saved before.
- _monitor_vars_learn() List[str] [source]
- Overview:
Register the variables to be monitored in learn mode. The registered variables are logged to TensorBoard based on the values returned by _forward_learn.
- _process_transition(obs, policy_output, timestep)
- Overview:
Process and pack the transition data of one timestep into a dict, such as <s, a, r, s’, done>. Some policies need special processing and must pack their own necessary attributes (e.g. hidden state and logit), so this method is left to be implemented by the subclass.
- Parameters:
obs (-) – The observation of the current timestep.
policy_output (-) – The output of the policy network with the observation as input. Usually, it contains the action and the logit of the action.
timestep (-) – The execution result namedtuple returned by the environment step method, except that all elements have been converted to tensors. Usually, it contains the next obs, reward, done, info, etc.
- Returns:
The processed transition data of the current timestep.
- Return type:
transition (Dict[str, torch.Tensor])
- _reset_collect(env_id: int = None, current_steps: int = None, reset_init_data: bool = True) None [source]
- Overview:
This method resets the collection process for a specific environment. When certain conditions on the step count are met, it clears caches and frees memory to keep memory usage under control. If reset_init_data is True, the initial data is also reset.
- Parameters:
env_id (-) – The ID of the environment to reset. If None or list, the function returns immediately.
current_steps (-) – The current step count in the environment. Used to determine whether to clear caches.
reset_init_data (-) – Whether to reset the initial data. If True, the initial data will be reset.
- _reset_eval(env_id: int = None, current_steps: int = None, reset_init_data: bool = True) None [source]
- Overview:
This method resets the evaluation process for a specific environment. When certain conditions on the step count are met, it clears caches and frees memory to keep memory usage under control. If reset_init_data is True, the initial data is also reset.
- Parameters:
env_id (-) – The ID of the environment to reset. If None or list, the function returns immediately.
current_steps (-) – The current step count in the environment. Used to determine whether to clear caches.
reset_init_data (-) – Whether to reset the initial data. If True, the initial data will be reset.
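The reset behaviour described for _reset_collect and _reset_eval (early return for None or a list, optional reset of the initial data, and cache clearing that depends on current_steps) can be sketched as follows. The documentation does not spell out the exact clearing condition; clearing every clear_interval steps, the class name, and the attribute names are all assumptions made for illustration.

```python
from typing import Any, Dict, List, Optional, Union


class ResetSketch:
    """Hypothetical sketch of periodic cache clearing on per-environment reset."""

    def __init__(self, clear_interval: int = 200) -> None:
        self.clear_interval = clear_interval
        self.cache: Dict[Any, Any] = {}  # stands in for model-side caches
        self.last_obs: Dict[int, Any] = {}  # stands in for per-env "initial data"

    def reset(
        self,
        env_id: Union[int, List[int], None] = None,
        current_steps: Optional[int] = None,
        reset_init_data: bool = True,
    ) -> None:
        # Documented early exit: nothing to do for None or a list of ids.
        if env_id is None or isinstance(env_id, list):
            return
        if reset_init_data:
            self.last_obs[env_id] = None
        # Clear caches only periodically, bounding memory without paying the cost every step.
        if current_steps is not None and current_steps % self.clear_interval == 0:
            self.cache.clear()
```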
- _reset_learn(data_id: List[int] | None = None) None
- Overview:
Reset stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset; for example, different trajectories in data_id have different RNN hidden states.
- Parameters:
data_id (-) – The ids of the data whose stateful variables should be reset.
Note
Implementing this method is optional; subclasses can override it if necessary.
- _set_attribute(name: str, value: Any) None
- Overview:
To control access to policy attributes, the policy exposes different modes to the outside instead of the policy instance itself. This method sets an attribute of the policy in these modes; the new attribute is stored as _{name}.
- Parameters:
name (-) – The name of the attribute.
value (-) – The value of the attribute.
- _state_dict_collect() Dict[str, Any]
- Overview:
Return the state_dict of collect mode, which usually includes only the model. It is needed in distributed training scenarios to auto-recover collectors.
- Returns:
The dict of current policy collect state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not all scenarios need to auto-recover collectors; sometimes it is simpler to shut down the crashed collector and start a new one.
- _state_dict_eval() Dict[str, Any]
- Overview:
Return the state_dict of eval mode, which usually includes only the model. It is needed in distributed training scenarios to auto-recover evaluators.
- Returns:
The dict of current policy eval state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
Tip
Not all scenarios need to auto-recover evaluators; sometimes it is simpler to shut down the crashed evaluator and start a new one.
- _state_dict_learn() Dict[str, Any] [source]
- Overview:
Return the state_dict of learn mode, usually including model, target_model and optimizer.
- Returns:
The dict of current policy learn state, for saving and restoring.
- Return type:
state_dict (Dict[str, Any])
- property cfg: EasyDict
- class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new collect_function object from a sequence or iterable
- _replace(**kwds)
Return a new collect_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- get_train_sample
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 7
- process_transition
Alias for field number 1
- reset
Alias for field number 3
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property collect_mode: collect_function
- Overview:
Return the interfaces of collect mode of the policy, which is used to collect training data. Here we use a namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own collect mode.
- Returns:
The interfaces of collect mode of the policy, a namedtuple whose fields are different internal methods.
- Return type:
interfaces (Policy.collect_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- config = {'action_type': 'fixed_action_space', 'analysis_sim_norm': False, 'augmentation': ['shift', 'intensity'], 'batch_size': 256, 'battle_mode': 'play_with_bot_mode', 'collect_with_pure_policy': False, 'collector_env_num': 8, 'cuda': True, 'discount_factor': 0.997, 'env_type': 'not_board_games', 'eps': {'decay': 100000, 'end': 0.05, 'eps_greedy_exploration_in_collect': False, 'start': 1.0, 'type': 'linear'}, 'eval_freq': 2000, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'game_segment_length': 400, 'grad_clip_value': 20, 'gray_scale': False, 'gumbel_algo': False, 'ignore_done': False, 'learning_rate': 0.0001, 'manual_temperature_decay': False, 'mcts_ctree': True, 'model': {'analysis_sim_norm': False, 'bias': True, 'categorical_distribution': True, 'continuous_action_space': False, 'frame_stack_num': 1, 'image_channel': 3, 'learn': {'learner': {'hook': {'save_ckpt_after_iter': 10000}}}, 'model_type': 'conv', 'norm_type': 'BN', 'num_channels': 64, 'num_res_blocks': 1, 'observation_shape': (3, 64, 64), 'res_connection_in_dynamics': True, 'self_supervised_learning_loss': True, 'support_scale': 50, 'world_model_cfg': {'action_space_size': 6, 'analysis_dormant_ratio': False, 'analysis_sim_norm': False, 'attention': 'causal', 'attn_pdrop': 0.1, 'context_length': 8, 'continuous_action_space': False, 'device': 'cpu', 'dormant_threshold': 0.025, 'embed_dim': 768, 'embed_pdrop': 0.1, 'env_num': 8, 'gamma': 1, 'group_size': 8, 'gru_gating': False, 'latent_recon_loss_weight': 0.0, 'max_blocks': 10, 'max_cache_size': 5000, 'max_tokens': 20, 'num_heads': 8, 'num_layers': 2, 'obs_type': 'image', 'perceptual_loss_weight': 0.0, 'policy_entropy_weight': 0, 'predict_latent_loss_type': 'group_kl', 'resid_pdrop': 0.1, 'support_size': 101, 'tokens_per_block': 2}}, 'momentum': 0.9, 'monitor_extra_statistics': True, 'multi_gpu': False, 'n_episode': 8, 'num_segments': 8, 'num_simulations': 50, 'num_unroll_steps': 10, 'optim_type': 'AdamW', 'piecewise_decay_lr_scheduler': False, 'policy_loss_weight': 1, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'random_collect_episode_num': 0, 'replay_ratio': 0.25, 'reward_loss_weight': 1, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'sample_type': 'transition', 'sampled_algo': False, 'ssl_loss_weight': 0, 'target_update_freq': 100, 'target_update_freq_for_intrinsic_reward': 1000, 'target_update_theta': 0.05, 'td_steps': 5, 'threshold_training_steps_for_final_lr': 50000, 'threshold_training_steps_for_final_temperature': 50000, 'train_start_after_envsteps': 0, 'transform2string': False, 'type': 'unizero', 'update_per_collect': None, 'use_augmentation': False, 'use_priority': False, 'use_rnd_model': False, 'use_ture_chance_label_in_chance_encoder': False, 'value_loss_weight': 0.25, 'weight_decay': 0.0001}
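The nested eps entry above describes an exploration schedule (start 1.0, end 0.05, decay 100000, type 'linear'). As a hedged illustration of what a linear schedule with these values would produce, the sketch below anneals epsilon linearly over decay steps and then holds it at the end value. The helper name linear_eps is hypothetical, and treating the step counter as an environment step is an assumption; note also that eps_greedy_exploration_in_collect is False in this config, so presumably the schedule is not applied during collection by default.

```python
# Hypothetical sketch of a linear epsilon schedule using the 'eps' values above.
# Which counter drives it (env_step vs. train_iter) is an assumption.


def linear_eps(step: int, start: float, end: float, decay: int) -> float:
    """Linearly anneal epsilon from `start` to `end` over `decay` steps, then hold at `end`."""
    if step >= decay:
        return end
    return start + (end - start) * step / decay


eps_cfg = {'decay': 100000, 'end': 0.05, 'start': 1.0, 'type': 'linear'}

for step in (0, 50000, 100000, 200000):
    eps = linear_eps(step, eps_cfg['start'], eps_cfg['end'], eps_cfg['decay'])
    print(step, round(eps, 4))
```

Holding at the end value after decay steps keeps a small, constant amount of exploration for the rest of training rather than letting epsilon go negative.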
- classmethod default_config() EasyDict
- Overview:
Get the default config of the policy. This method is used to create the default config of the policy.
- Returns:
The default config of the corresponding policy. For a derived policy class, it will recursively merge the default config of the base class with its own default config.
- Return type:
cfg (EasyDict)
Tip
This method will deepcopy the config attribute of the class and return the result, so users don’t need to worry about modifying the returned config.
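A minimal sketch of the two behaviours described above, deep-copying and recursive merging along the class hierarchy, using plain dict in place of EasyDict. BasePolicy, DerivedPolicy and deep_merge are hypothetical names; this illustrates the described behaviour and is not the library's implementation.

```python
import copy


def deep_merge(base: dict, override: dict) -> dict:
    """Return a new dict: `base` recursively updated with `override` (inputs are not mutated)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


class BasePolicy:
    config = {'cuda': False, 'model': {'num_channels': 32, 'bias': True}}

    @classmethod
    def default_config(cls) -> dict:
        merged = {}
        # Walk the MRO from the most basic class to the most derived one,
        # so values defined in subclasses override those of their bases.
        for klass in reversed(cls.__mro__):
            if 'config' in vars(klass):
                merged = deep_merge(merged, vars(klass)['config'])
        return merged


class DerivedPolicy(BasePolicy):
    config = {'cuda': True, 'model': {'num_channels': 64}}


cfg = DerivedPolicy.default_config()
print(cfg)  # {'cuda': True, 'model': {'num_channels': 64, 'bias': True}}

cfg['model']['num_channels'] = 128       # mutate the returned copy only
print(DerivedPolicy.config['model'])     # {'num_channels': 64}
```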
- default_model() Tuple[str, List[str]] [source]
- Overview:
Return this algorithm's default model setting for demonstration.
- Returns:
- model name and model import_names.
model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.
import_names (List[str]): The model class path list used in this algorithm.
- Return type:
model_info (Tuple[str, List[str]])
Note
The user can define and use a customized network model, but it must obey the same interface definition as the model indicated by the import_names path. For MuZero, this is lzero.model.unizero_model.MuZeroModel.
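The returned tuple pairs a model name with the module paths to import. The sketch below shows one plausible way a caller could turn that tuple into a class. It is a simplification under stated assumptions: the real code resolves model_type through ModelRegistry after importing the modules, whereas here the hypothetical helper resolve_model_class simply looks the name up with getattr, and the standard-library collections.OrderedDict stands in for a model class so the example runs without LightZero installed.

```python
import importlib
from typing import List, Tuple


def default_model() -> Tuple[str, List[str]]:
    # Hypothetical stand-in: a real policy returns its registered model name
    # and the module path(s) that define that model.
    return 'OrderedDict', ['collections']


def resolve_model_class(model_info: Tuple[str, List[str]]):
    """Import each module in import_names and return the first attribute named model_type."""
    model_type, import_names = model_info
    for module_path in import_names:
        module = importlib.import_module(module_path)
        if hasattr(module, model_type):
            return getattr(module, model_type)
    raise KeyError(f'{model_type} not found in {import_names}')


model_cls = resolve_model_class(default_model())
print(model_cls)  # <class 'collections.OrderedDict'>
```

Keeping the module paths separate from the model name lets a user point import_names at their own module, which is why a customized model must follow the same interface as the default one.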
- class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new eval_function object from a sequence or iterable
- _replace(**kwds)
Return a new eval_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 5
- reset
Alias for field number 1
- set_attribute
Alias for field number 3
- state_dict
Alias for field number 4
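The "Alias for field number N" entries are the standard behaviour of collections.namedtuple: each field name is a read-only property for one tuple position. The sketch below rebuilds a tuple type with the same field names as eval_function, filled with placeholder lambdas (assumptions, not the policy's real methods), to show the field indices, positional access, and _replace.

```python
from collections import namedtuple

# Same field names, in the same order, as eval_function above.
eval_function = namedtuple(
    'eval_function',
    ['forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict'],
)

# Placeholder callables standing in for the policy's internal methods.
interfaces = eval_function(
    forward=lambda obs: {'action': 0},
    reset=lambda: None,
    get_attribute=lambda name: None,
    set_attribute=lambda name, value: None,
    state_dict=lambda: {},
    load_state_dict=lambda sd: None,
)

print(interfaces._fields.index('state_dict'))  # 4, i.e. "Alias for field number 4"
print(interfaces[0] is interfaces.forward)     # True, field 0 is `forward`

# _replace returns a new tuple; the original interfaces are left untouched.
patched = interfaces._replace(reset=lambda: 'reset called')
print(patched.reset(), interfaces.reset())     # reset called None
```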
- property eval_mode: eval_function
- Overview:
Return the interfaces of eval mode of policy, which is used to evaluate the policy. Here we use namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.
- Returns:
The interfaces of eval mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (Policy.eval_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
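To illustrate how such a property can bundle internal methods into an immutable interface, here is a toy sketch. ToyPolicy, _forward_eval and _reset_eval are hypothetical, the interface is cut down to two fields, and the output format (a dict keyed by environment id) is an assumption chosen for the example.

```python
from collections import namedtuple

# Reduced, hypothetical two-field version of the eval interface.
eval_function = namedtuple('eval_function', ['forward', 'reset'])


class ToyPolicy:
    """Hypothetical policy exposing a reduced eval-mode interface."""

    def _forward_eval(self, obs: dict) -> dict:
        # Greedy placeholder: pick action 0 for every environment id.
        return {env_id: {'action': 0} for env_id in obs}

    def _reset_eval(self, data_id=None) -> None:
        pass

    @property
    def eval_mode(self) -> eval_function:
        # Bundle bound internal methods into an immutable namedtuple.
        return eval_function(self._forward_eval, self._reset_eval)


policy_eval = ToyPolicy().eval_mode
output = policy_eval.forward({0: 'obs_0', 1: 'obs_1'})
print(output)  # {0: {'action': 0}, 1: {'action': 0}}

try:
    policy_eval.forward = None  # namedtuple fields cannot be reassigned
except AttributeError as err:
    print('immutable:', err)
```

Returning a fresh namedtuple from the property means callers only see the methods intended for evaluation, and cannot accidentally rebind them.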
- class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new learn_function object from a sequence or iterable
- _replace(**kwds)
Return a new learn_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- info
Alias for field number 2
- load_state_dict
Alias for field number 7
- monitor_vars
Alias for field number 3
- reset
Alias for field number 1
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property learn_mode: learn_function
- Overview:
Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of the policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own learn mode.
- Returns:
The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (Policy.learn_function)
Examples
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
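Because the interfaces are namedtuples, a derived subclass can customize learn mode by overriding the property and swapping individual fields with _replace. The sketch below is hypothetical throughout (BasePolicy, CountingPolicy, and the reduced three-field learn_function are illustrative, not LightZero classes): the subclass wraps forward to count updates and extends monitor_vars, while reset is inherited unchanged.

```python
from collections import namedtuple

# Reduced, hypothetical three-field version of the learn interface.
learn_function = namedtuple('learn_function', ['forward', 'reset', 'monitor_vars'])


class BasePolicy:
    def _forward_learn(self, data: list) -> dict:
        # Placeholder training step: report how many samples were consumed.
        return {'batch_size': len(data)}

    def _reset_learn(self) -> None:
        pass

    def _monitor_vars_learn(self) -> list:
        return ['batch_size']

    @property
    def learn_mode(self) -> learn_function:
        return learn_function(self._forward_learn, self._reset_learn, self._monitor_vars_learn)


class CountingPolicy(BasePolicy):
    """Hypothetical subclass that customizes learn mode by wrapping `forward`."""

    def __init__(self) -> None:
        self.num_updates = 0

    @property
    def learn_mode(self) -> learn_function:
        base = super().learn_mode

        def forward(data: list) -> dict:
            info = base.forward(data)
            self.num_updates += 1
            info['num_updates'] = self.num_updates
            return info

        # Replace two fields; `reset` is inherited unchanged from the base interface.
        return base._replace(forward=forward,
                             monitor_vars=lambda: base.monitor_vars() + ['num_updates'])


policy_learn = CountingPolicy().learn_mode
first = policy_learn.forward([1, 2, 3])
second = policy_learn.forward([4])
print(first, second)  # {'batch_size': 3, 'num_updates': 1} {'batch_size': 1, 'num_updates': 2}
print(policy_learn.monitor_vars())  # ['batch_size', 'num_updates']
```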
- recompute_pos_emb_diff_and_clear_cache() None [source]
- Overview:
Clear the caches and precompute positional embedding matrices in the model.
- set_train_iter_env_step(train_iter, env_step) None
- Overview:
Set the train_iter and env_step for the policy.
- Parameters:
train_iter (-) – The current training iteration of the policy.
env_step (-) – The current number of environment steps.
- sync_gradients(model: Module) None
- Overview:
Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.
- Parameters:
model (-) – The model to synchronize gradients.
Note
This method is only used in multi-gpu training, and it should be called after the backward method and before the step method. The user can also use the bp_update_sync config to control whether gradient allreduce and optimizer updates are synchronized.
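As a dependency-free illustration of what the allreduce step achieves, the sketch below simulates it in pure Python: each worker contributes its gradient vector and every worker ends up holding the element-wise mean. This is a simulation, not torch.distributed, and allreduce_mean is a hypothetical helper; whether the real allreduce averages or sums the gradients is an assumption here (the sketch averages).

```python
from typing import List


def allreduce_mean(worker_grads: List[List[float]]) -> List[List[float]]:
    """Simulate an averaging allreduce: every worker ends up with the element-wise mean."""
    world_size = len(worker_grads)
    summed = [sum(values) for values in zip(*worker_grads)]
    mean = [value / world_size for value in summed]
    # After the allreduce, all ranks hold identical gradients.
    return [list(mean) for _ in range(world_size)]


# Two hypothetical workers computed gradients on different mini-batches.
# In real training the order would be: loss.backward(), then gradient sync, then optimizer.step().
grads = [[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]
synced = allreduce_mean(grads)
print(synced)  # [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]
```

Because every rank applies the same synchronized gradient with the same optimizer state, the model replicas stay identical across GPUs after each step.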
- total_field = {'collect', 'eval', 'learn'}