Policy
AlphaZeroPolicy
- class lzero.policy.alphazero.AlphaZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Bases:
Policy
- Overview:
The policy class for AlphaZero.
- __init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
- Overview:
Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy, including learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. enable_field specifies which fields to initialize; if it is None, all fields are initialized.
- Parameters:
cfg (-) – The final merged config used to initialize the policy. For the default config, see the config attribute of the policy class and its comments.
model (-) – The neural network model used to initialize the policy. If it is None, the model will be created according to the default_model method and the cfg.model field. Otherwise, the model is set to the instance created by the outside caller.
enable_field (-) – The list of fields to initialize. If it is None, all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which helps to save resources.
Note
A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.
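A minimal construction sketch follows. It is illustrative only: the config is taken straight from the class defaults here, whereas in practice LightZero's training entry points merge in environment-specific fields before the policy is built, and enable_field is used to skip the learn field for a pure data-collection worker.
>>> from lzero.policy.alphazero import AlphaZeroPolicy
>>> cfg = AlphaZeroPolicy.default_config()
>>> # initialize only the collect and eval fields to save resources
>>> policy = AlphaZeroPolicy(cfg, model=None, enable_field=['collect', 'eval'])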
- _abc_impl = <_abc._abc_data object>
- _create_model(cfg: EasyDict, model: Module | None = None) Module
- Overview:
Create or validate the neural network model according to the input configuration and model. If the input model is None, the model will be created according to the default_model method and the cfg.model field. Otherwise, the model will be verified as an instance of torch.nn.Module and set to the model instance created by the outside caller.
- Parameters:
cfg (-) – The final merged config used to initialize policy.
model (-) – The neural network model used to initialize policy. User can refer to the default model defined in the corresponding policy to customize its own model.
- Returns:
The created neural network model. The different modes of policy will add distinct wrappers and plugins to the model, which is used to train, collect and evaluate.
- Return type:
model (
torch.nn.Module
)
- Raises:
- RuntimeError – If the input model is not None and is not an instance of
torch.nn.Module
.
- _forward_collect(obs: Dict, temperature: float = 1) Dict[str, Tensor] [source]
- Overview:
The forward function for collecting data in collect mode. Use real env to execute MCTS search.
- Parameters:
obs (-) – The dict of obs, the key is env_id and the value is the corresponding obs in this timestep.
temperature (-) – The temperature for MCTS search.
- Returns:
The dict of output, the key is env_id and the value is the corresponding policy output in this timestep, including action, probs and so on.
- Return type:
output (
Dict[str, torch.Tensor]
)
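A hedged usage sketch of this method; the helper name collect_step and the handling of the 'action' key are illustrative (the docstring above only states that the output contains items such as action and probs).
>>> def collect_step(policy, ready_obs, temperature=1.0):
...     # ready_obs: {env_id: obs}, as described above
...     output = policy._forward_collect(ready_obs, temperature=temperature)
...     return {env_id: out['action'] for env_id, out in output.items()}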
- _forward_eval(obs: Dict) Dict[str, Tensor] [source]
- Overview:
The forward function for evaluating the current policy in eval mode, similar to
self._forward_collect
.
- Parameters:
obs (-) – The dict of obs, the key is env_id and the value is the corresponding obs in this timestep.
- Returns:
The dict of output, the key is env_id and the value is the corresponding policy output in this timestep, including action, probs and so on.
- Return type:
output (
Dict[str, torch.Tensor]
)
- _forward_learn(inputs: Dict[str, Tensor]) Dict[str, float] [source]
- Overview:
Policy forward function of learn mode (training the policy and updating parameters). Forward means that the policy takes some training batch data from the replay buffer as input and returns the output result, including various training information such as loss value, policy entropy, q value, priority, and so on. This method is left to be implemented by the subclass, and more arguments can be added to the data item if necessary.
- Parameters:
data (-) – The input data used for policy forward, including a batch of training samples. For each element in the list, the key of the dict is the name of the data item and the value is the corresponding data. Usually, in the _forward_learn method, data should be stacked in the batch dimension by some utility functions such as default_preprocess_learn.
- Returns:
The training information of policy forward, including some metrics for monitoring training such as loss, priority, q value, policy entropy, and some data for next step training such as priority. Note the output data item should be Python native scalar rather than PyTorch tensor, which is convenient for the outside to use.
- Return type:
output (
Dict[int, Any]
)
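Because the return value must contain native Python scalars, a typical subclass converts its loss tensors with .item() before returning. A sketch with illustrative key names (not necessarily the exact ones this policy reports):
>>> def summarize_learn_output(total_loss, policy_loss, value_loss, cur_lr):
...     # detach tensors into plain floats so loggers can consume them directly
...     return {
...         'total_loss': total_loss.item(),
...         'policy_loss': policy_loss.item(),
...         'value_loss': value_loss.item(),
...         'cur_lr': cur_lr,
...     }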
- _get_attribute(name: str) Any
- Overview:
In order to control the access of the policy attributes, we expose different modes to outside rather than directly use the policy instance. And we also provide a method to get the attribute of the policy in different modes.
- Parameters:
name (-) – The name of the attribute.
- Returns:
The value of the attribute.
- Return type:
value (
Any
)
Note
DI-engine’s policy will first try to access _get_{name} method, and then try to access _{name} attribute. If both of them are not found, it will raise a
NotImplementedError
.
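The lookup order described in the note can be sketched as follows; this is an illustrative reimplementation, not the actual DI-engine code.
>>> def get_attribute(policy, name):
...     if hasattr(policy, '_get_' + name):       # 1. dedicated getter method
...         return getattr(policy, '_get_' + name)()
...     if hasattr(policy, '_' + name):           # 2. plain private attribute
...         return getattr(policy, '_' + name)
...     raise NotImplementedError(name)           # 3. neither exists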
- _get_batch_size() int | Dict[str, int]
- _get_n_episode() int | None
- _get_n_sample() int | None
- _get_train_sample(data)[source]
- Overview:
For a given trajectory (transitions, a list of transitions), process it into a list of samples that can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) or some multi-timestep transitions (DRQN). This method is usually used in collectors to execute the necessary RL data preprocessing before training, which helps the learner amortize the relevant time consumption. Alternatively, you can also implement this method as an identity function and do the data processing in the self._forward_learn method.
- Parameters:
transitions (-) – The trajectory data (a list of transitions); each element is in the same format as the return value of the self._process_transition method.
- Returns:
The processed train samples; each element is in a similar format to the input transitions, but may contain more data for training, such as nstep reward, advantage, etc.
- Return type:
samples (
List[Dict[str, Any]]
)
Note
We will vectorize the process_transition and get_train_sample methods in a following release. Users can also customize this data processing procedure by overriding these two methods and the collector itself.
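As the overview allows, the simplest valid override is an identity pass-through that defers all preprocessing to _forward_learn; a sketch:
>>> def _get_train_sample(self, transitions):
...     # transitions: list of dicts produced by self._process_transition; returned unchanged
...     return transitions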
- _init_collect() None [source]
- Overview:
Collect mode init method. Called by
self.__init__
. Initialize the collect model and MCTS utils.
- _init_eval() None [source]
- Overview:
Evaluate mode init method. Called by
self.__init__
. Initialize the eval model and MCTS utils.
- _init_learn() None [source]
- Overview:
Initialize the learn mode of the policy, including related attributes and modules. This method will be called in the __init__ method if the learn field is in enable_field. Almost every policy has its own learn mode, so this method must be overridden in the subclass.
Note
For the member variables that need to be saved and loaded, please refer to the _state_dict_learn and _load_state_dict_learn methods.
Note
For the member variables that need to be monitored, please refer to the _monitor_vars_learn method.
Note
If you want to set some special member variables in the _init_learn method, you’d better name them with the prefix _learn_ to avoid conflicts with other modes, such as self._learn_attr1.
- _init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
- Overview:
Initialize multi-gpu data parallel training setting, including broadcast model parameters at the beginning of the training, and prepare the hook function to allreduce the gradients of model parameters.
- Parameters:
model (-) – The neural network model to be trained.
bp_update_sync (-) – Whether to synchronously update the model parameters after allreducing their gradients. Asynchronous updates can be pipelined across different network layers, which saves time.
- _load_state_dict_collect(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy collect mode, such as load pretrained state_dict, auto-recover checkpoint, or model replica from learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy collect state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
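A standard PyTorch sketch of the partial-loading tip above; the checkpoint path is illustrative and model is the torch.nn.Module held by the policy.
>>> import torch
>>> def load_partial(model, ckpt_path):
...     state_dict = torch.load(ckpt_path, map_location='cpu')
...     incompatible = model.load_state_dict(state_dict, strict=False)  # tolerate mismatches
...     return incompatible.missing_keys, incompatible.unexpected_keys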
- _load_state_dict_eval(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy eval mode, such as load auto-recover checkpoint, or model replica from learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy eval state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_learn(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy learn mode.
- Parameters:
state_dict (-) – The dict of policy learn state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _monitor_vars_learn() List[str] [source]
- Overview:
Register the variables to be monitored in learn mode. The registered variables will be logged to tensorboard according to the return value of _forward_learn.
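A sketch of a typical override; the variable names are illustrative and must match the keys returned by _forward_learn.
>>> def _monitor_vars_learn(self):
...     return ['cur_lr', 'total_loss', 'policy_loss', 'value_loss', 'entropy']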
- _process_transition(obs: Dict, model_output: Dict[str, Tensor], timestep: namedtuple) Dict [source]
- Overview:
Generate the dict type transition (one timestep) data from policy learning.
- _reset_collect(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for collect mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different environments/episodes in collecting, indexed by data_id, will have different hidden states in the RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by
data_id
.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
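A sketch of the reset pattern described above, assuming the policy keeps per-environment state (for example an MCTS tree or an RNN hidden state) in a hypothetical dict self._collect_state keyed by env_id.
>>> def _reset_collect(self, data_id=None):
...     if data_id is None:
...         self._collect_state = {}              # reset everything
...     else:
...         for env_id in data_id:                # reset only the specified envs
...             self._collect_state.pop(env_id, None)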
- _reset_eval(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for eval mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different environments/episodes in evaluation, indexed by data_id, will have different hidden states in the RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by
data_id
.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
- _reset_learn(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different trajectories, indexed by data_id, will have different hidden states in the RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by
data_id
.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
- _set_attribute(name: str, value: Any) None
- Overview:
In order to control the access of the policy attributes, we expose different modes to outside rather than directly use the policy instance. We also provide a method to set the attribute of the policy in different modes, and the new attribute will be named _{name}.
- Parameters:
name (-) – The name of the attribute.
value (-) – The value of the attribute.
- _state_dict_collect() Dict[str, Any]
- Overview:
Return the state_dict of collect mode, usually only including the model, which is necessary for distributed training scenarios to auto-recover collectors.
- Returns:
The dict of current policy collect state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
Tip
Not all scenarios need to auto-recover collectors; sometimes we can directly shut down the crashed collector and start a new one.
- _state_dict_eval() Dict[str, Any]
- Overview:
Return the state_dict of eval mode, usually only including the model, which is necessary for distributed training scenarios to auto-recover evaluators.
- Returns:
The dict of current policy eval state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
Tip
Not all scenarios need to auto-recover evaluators; sometimes we can directly shut down the crashed evaluator and start a new one.
- _state_dict_learn() Dict[str, Any]
- Overview:
Return the state_dict of learn mode, usually including model and optimizer.
- Returns:
The dict of current policy learn state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
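A checkpointing sketch that pairs this method with _load_state_dict_learn; the file name is illustrative.
>>> import torch
>>> def save_and_restore(policy, path='iteration_10000.pth.tar'):
...     torch.save(policy._state_dict_learn(), path)   # model + optimizer
...     policy._load_state_dict_learn(torch.load(path, map_location='cpu'))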
- property cfg: EasyDict
- class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new collect_function object from a sequence or iterable
- _replace(**kwds)
Return a new collect_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- get_train_sample
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 7
- process_transition
Alias for field number 1
- reset
Alias for field number 3
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property collect_mode: collect_function
- Overview:
Return the interfaces of collect mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own collect mode.
- Returns:
The interfaces of collect mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (
Policy.collect_function
)
Examples
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- config = {'batch_size': 256, 'collector_env_num': 8, 'cuda': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'grad_clip_value': 10, 'gumbel_algo': False, 'learning_rate': 0.2, 'manual_temperature_decay': False, 'mcts': {'max_moves': 512, 'num_simulations': 50, 'pb_c_base': 19652, 'pb_c_init': 1.25, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25}, 'model': {'num_channels': 32, 'num_res_blocks': 1, 'observation_shape': (3, 6, 6)}, 'momentum': 0.9, 'multi_gpu': False, 'optim_type': 'SGD', 'other': {'replay_buffer': {'replay_buffer_size': 1000000, 'save_episode': False}}, 'piecewise_decay_lr_scheduler': True, 'replay_ratio': 0.25, 'sampled_algo': False, 'tensor_float_32': False, 'threshold_training_steps_for_final_lr': 500000, 'threshold_training_steps_for_final_temperature': 100000, 'torch_compile': False, 'update_per_collect': None, 'value_weight': 1.0, 'weight_decay': 0.0001}
- classmethod default_config() EasyDict
- Overview:
Get the default config of policy. This method is used to create the default config of policy.
- Returns:
The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.
- Return type:
cfg (
EasyDict
)
Tip
This method will deepcopy the
config
attribute of the class and return the result. So users don’t need to worry about the modification of the returned config.
- default_model() Tuple[str, List[str]] [source]
- Overview:
Return the default model setting of this algorithm for demonstration.
- Returns:
model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.
import_names (List[str]): The model class path list used in this algorithm.
- Return type:
Tuple[str, List[str]]
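A sketch of how the returned pair is typically consumed when no model is passed in: import the listed modules so that the model class is registered, then build it from cfg.model (the registry call itself is DI-engine-specific and omitted here); the helper name is illustrative.
>>> import importlib
>>> def resolve_default_model(policy):
...     model_type, import_names = policy.default_model()
...     for module_path in import_names:
...         importlib.import_module(module_path)   # registers the model class
...     return model_type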
- class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new eval_function object from a sequence or iterable
- _replace(**kwds)
Return a new eval_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 5
- reset
Alias for field number 1
- set_attribute
Alias for field number 3
- state_dict
Alias for field number 4
- property eval_mode: eval_function
- Overview:
Return the interfaces of eval mode of policy, which is used to evaluate the performance of the policy. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.
- Returns:
The interfaces of eval mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (
Policy.eval_function
)
Examples
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new learn_function object from a sequence or iterable
- _replace(**kwds)
Return a new learn_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- info
Alias for field number 2
- load_state_dict
Alias for field number 7
- monitor_vars
Alias for field number 3
- reset
Alias for field number 1
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property learn_mode: learn_function
- Overview:
Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own learn mode.
- Returns:
The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (
Policy.learn_function
)
Examples
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
- sync_gradients(model: Module) None
- Overview:
Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.
- Parameters:
model (-) – The model to synchronize gradients.
Note
This method is only used in multi-gpu training, and it should be called after the backward method and before the step method. The user can also use the bp_update_sync config to control whether gradient allreduce and optimizer updates are synchronized.
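The ordering from the note can be sketched as a single training step, assuming bp_update_sync=True and an already initialized multi-gpu setup:
>>> def train_step(policy, model, optimizer, loss):
...     optimizer.zero_grad()
...     loss.backward()
...     policy.sync_gradients(model)   # allreduce gradients across data-parallel workers
...     optimizer.step()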
- total_field = {'collect', 'eval', 'learn'}
MuZeroPolicy
EfficientZeroPolicy
Gumbel AlphaZeroPolicy
- class lzero.policy.gumbel_alphazero.GumbelAlphaZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Bases:
Policy
- Overview:
The policy class for GumbelAlphaZero.
- __init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
- Overview:
Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy, including learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. enable_field specifies which fields to initialize; if it is None, all fields are initialized.
- Parameters:
cfg (-) – The final merged config used to initialize the policy. For the default config, see the config attribute of the policy class and its comments.
model (-) – The neural network model used to initialize the policy. If it is None, the model will be created according to the default_model method and the cfg.model field. Otherwise, the model is set to the instance created by the outside caller.
enable_field (-) – The list of fields to initialize. If it is None, all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which helps to save resources.
Note
A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.
- _abc_impl = <_abc._abc_data object>
- _create_model(cfg: EasyDict, model: Module | None = None) Module
- Overview:
Create or validate the neural network model according to the input configuration and model. If the input model is None, the model will be created according to the default_model method and the cfg.model field. Otherwise, the model will be verified as an instance of torch.nn.Module and set to the model instance created by the outside caller.
- Parameters:
cfg (-) – The final merged config used to initialize policy.
model (-) – The neural network model used to initialize policy. User can refer to the default model defined in the corresponding policy to customize its own model.
- Returns:
The created neural network model. The different modes of policy will add distinct wrappers and plugins to the model, which is used to train, collect and evaluate.
- Return type:
model (
torch.nn.Module
)
- Raises:
- RuntimeError – If the input model is not None and is not an instance of
torch.nn.Module
.
- _forward_collect(obs: Dict, temperature: float = 1) Dict[str, Tensor] [source]
- Overview:
The forward function for collecting data in collect mode. Use real env to execute MCTS search.
- Parameters:
obs (-) – The dict of obs, the key is env_id and the value is the corresponding obs in this timestep.
temperature (-) – The temperature for MCTS search.
- Returns:
The dict of output, the key is env_id and the value is the corresponding policy output in this timestep, including action, probs and so on.
- Return type:
output (
Dict[str, torch.Tensor]
)
- _forward_eval(obs: Dict) Dict[str, Tensor] [source]
- Overview:
The forward function for evaluating the current policy in eval mode, similar to
self._forward_collect
.
- Parameters:
obs (-) – The dict of obs, the key is env_id and the value is the corresponding obs in this timestep.
- Returns:
The dict of output, the key is env_id and the value is the corresponding policy output in this timestep, including action, probs and so on.
- Return type:
output (
Dict[str, torch.Tensor]
)
- _forward_learn(inputs: Dict[str, Tensor]) Dict[str, float] [source]
- Overview:
Policy forward function of learn mode (training the policy and updating parameters). Forward means that the policy takes some training batch data from the replay buffer as input and returns the output result, including various training information such as loss value, policy entropy, q value, priority, and so on. This method is left to be implemented by the subclass, and more arguments can be added to the data item if necessary.
- Parameters:
data (-) – The input data used for policy forward, including a batch of training samples. For each element in the list, the key of the dict is the name of the data item and the value is the corresponding data. Usually, in the _forward_learn method, data should be stacked in the batch dimension by some utility functions such as default_preprocess_learn.
- Returns:
The training information of policy forward, including some metrics for monitoring training such as loss, priority, q value, policy entropy, and some data for next step training such as priority. Note the output data item should be Python native scalar rather than PyTorch tensor, which is convenient for the outside to use.
- Return type:
output (
Dict[int, Any]
)
- _get_attribute(name: str) Any
- Overview:
In order to control the access of the policy attributes, we expose different modes to outside rather than directly use the policy instance. And we also provide a method to get the attribute of the policy in different modes.
- Parameters:
name (-) – The name of the attribute.
- Returns:
The value of the attribute.
- Return type:
value (
Any
)
Note
DI-engine’s policy will first try to access _get_{name} method, and then try to access _{name} attribute. If both of them are not found, it will raise a
NotImplementedError
.
- _get_batch_size() int | Dict[str, int]
- _get_n_episode() int | None
- _get_n_sample() int | None
- _get_train_sample(data)[source]
- Overview:
For a given trajectory (transitions, a list of transitions), process it into a list of samples that can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) or some multi-timestep transitions (DRQN). This method is usually used in collectors to execute the necessary RL data preprocessing before training, which helps the learner amortize the relevant time consumption. Alternatively, you can also implement this method as an identity function and do the data processing in the self._forward_learn method.
- Parameters:
transitions (-) – The trajectory data (a list of transitions); each element is in the same format as the return value of the self._process_transition method.
- Returns:
The processed train samples; each element is in a similar format to the input transitions, but may contain more data for training, such as nstep reward, advantage, etc.
- Return type:
samples (
List[Dict[str, Any]]
)
Note
We will vectorize the process_transition and get_train_sample methods in a following release. Users can also customize this data processing procedure by overriding these two methods and the collector itself.
- _init_collect() None [source]
- Overview:
Collect mode init method. Called by
self.__init__
. Initialize the collect model and MCTS utils.
- _init_eval() None [source]
- Overview:
Evaluate mode init method. Called by
self.__init__
. Initialize the eval model and MCTS utils.
- _init_learn() None [source]
- Overview:
Initialize the learn mode of the policy, including related attributes and modules. This method will be called in the __init__ method if the learn field is in enable_field. Almost every policy has its own learn mode, so this method must be overridden in the subclass.
Note
For the member variables that need to be saved and loaded, please refer to the _state_dict_learn and _load_state_dict_learn methods.
Note
For the member variables that need to be monitored, please refer to the _monitor_vars_learn method.
Note
If you want to set some special member variables in the _init_learn method, you’d better name them with the prefix _learn_ to avoid conflicts with other modes, such as self._learn_attr1.
- _init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
- Overview:
Initialize multi-gpu data parallel training setting, including broadcast model parameters at the beginning of the training, and prepare the hook function to allreduce the gradients of model parameters.
- Parameters:
model (-) – The neural network model to be trained.
bp_update_sync (-) – Whether to synchronously update the model parameters after allreducing their gradients. Asynchronous updates can be pipelined across different network layers, which saves time.
- _load_state_dict_collect(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy collect mode, such as load pretrained state_dict, auto-recover checkpoint, or model replica from learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy collect state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_eval(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy eval mode, such as load auto-recover checkpoint, or model replica from learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy eval state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_learn(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy learn mode.
- Parameters:
state_dict (-) – The dict of policy learn state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _monitor_vars_learn() List[str] [source]
- Overview:
Register the variables to be monitored in learn mode. The registered variables will be logged to tensorboard according to the return value of _forward_learn.
- _process_transition(obs: Dict, model_output: Dict[str, Tensor], timestep: namedtuple) Dict [source]
- Overview:
Generate the dict type transition (one timestep) data from policy learning.
- _reset_collect(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for collect mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different environments/episodes in collecting, indexed by data_id, will have different hidden states in the RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by
data_id
.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
- _reset_eval(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for eval mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different environments/episodes in evaluation, indexed by data_id, will have different hidden states in the RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by
data_id
.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
- _reset_learn(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different trajectories, indexed by data_id, will have different hidden states in the RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by
data_id
.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
- _set_attribute(name: str, value: Any) None
- Overview:
In order to control the access of the policy attributes, we expose different modes to outside rather than directly use the policy instance. We also provide a method to set the attribute of the policy in different modes, and the new attribute will be named _{name}.
- Parameters:
name (-) – The name of the attribute.
value (-) – The value of the attribute.
- _state_dict_collect() Dict[str, Any]
- Overview:
Return the state_dict of collect mode, usually only including the model, which is necessary for distributed training scenarios to auto-recover collectors.
- Returns:
The dict of current policy collect state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
Tip
Not all scenarios need to auto-recover collectors; sometimes we can directly shut down the crashed collector and start a new one.
- _state_dict_eval() Dict[str, Any]
- Overview:
Return the state_dict of eval mode, usually only including the model, which is necessary for distributed training scenarios to auto-recover evaluators.
- Returns:
The dict of current policy eval state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
Tip
Not all scenarios need to auto-recover evaluators; sometimes we can directly shut down the crashed evaluator and start a new one.
- _state_dict_learn() Dict[str, Any]
- Overview:
Return the state_dict of learn mode, usually including model and optimizer.
- Returns:
The dict of current policy learn state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
- property cfg: EasyDict
- class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new collect_function object from a sequence or iterable
- _replace(**kwds)
Return a new collect_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- get_train_sample
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 7
- process_transition
Alias for field number 1
- reset
Alias for field number 3
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property collect_mode: collect_function
- Overview:
Return the interfaces of collect mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own collect mode.
- Returns:
The interfaces of collect mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (
Policy.collect_function
)
Examples
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- config = {'batch_size': 256, 'collector_env_num': 8, 'cuda': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'grad_clip_value': 10, 'learning_rate': 0.2, 'manual_temperature_decay': False, 'mcts': {'action_space_size': 9, 'continuous_action_space': False, 'gumbel_rng': 0.0, 'gumbel_scale': 10.0, 'legal_actions': None, 'max_moves': 512, 'max_num_considered_actions': 6, 'maxvisit_init': 50, 'num_of_sampled_actions': 2, 'num_simulations': 50, 'pb_c_base': 19652, 'pb_c_init': 1.25, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25, 'value_scale': 0.1}, 'mcts_ctree': True, 'model': {'num_channels': 32, 'num_res_blocks': 1, 'observation_shape': (3, 6, 6)}, 'momentum': 0.9, 'optim_type': 'SGD', 'other': {'replay_buffer': {'replay_buffer_size': 1000000, 'save_episode': False}}, 'piecewise_decay_lr_scheduler': True, 'replay_ratio': 0.25, 'sampled_algo': False, 'tensor_float_32': False, 'threshold_training_steps_for_final_lr': 500000, 'threshold_training_steps_for_final_temperature': 100000, 'torch_compile': False, 'type': 'gumbel_alphazero', 'update_per_collect': None, 'value_weight': 1.0, 'weight_decay': 0.0001}
- classmethod default_config() EasyDict
- Overview:
Get the default config of policy. This method is used to create the default config of policy.
- Returns:
The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.
- Return type:
cfg (
EasyDict
)
Tip
This method will deepcopy the
config
attribute of the class and return the result. So users don’t need to worry about the modification of the returned config.
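A sketch of the usual workflow: start from the class defaults (a deep copy, per the tip above) and override a few fields; the keys shown match the config attribute listed above.
>>> from lzero.policy.gumbel_alphazero import GumbelAlphaZeroPolicy
>>> cfg = GumbelAlphaZeroPolicy.default_config()
>>> cfg.cuda = True
>>> cfg.mcts.num_simulations = 100
>>> cfg.mcts.max_num_considered_actions = 16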
- default_model() Tuple[str, List[str]] [source]
- Overview:
Return the default model setting of this algorithm for demonstration.
- Returns:
model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.
import_names (List[str]): The model class path list used in this algorithm.
- Return type:
Tuple[str, List[str]]
- class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new eval_function object from a sequence or iterable
- _replace(**kwds)
Return a new eval_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 5
- reset
Alias for field number 1
- set_attribute
Alias for field number 3
- state_dict
Alias for field number 4
- property eval_mode: eval_function
- Overview:
Return the interfaces of eval mode of policy, which is used to evaluate the performance of the policy. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.
- Returns:
The interfaces of eval mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (
Policy.eval_function
)
Examples
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new learn_function object from a sequence or iterable
- _replace(**kwds)
Return a new learn_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- info
Alias for field number 2
- load_state_dict
Alias for field number 7
- monitor_vars
Alias for field number 3
- reset
Alias for field number 1
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property learn_mode: learn_function
- Overview:
Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own learn mode.
- Returns:
The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (
Policy.learn_function
)
Examples
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
- sync_gradients(model: Module) None
- Overview:
Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.
- Parameters:
model (-) – The model to synchronize gradients.
Note
This method is only used in multi-gpu training, and it should be called after the backward method and before the step method. The user can also use the bp_update_sync config to control whether gradient allreduce and optimizer updates are synchronized.
- total_field = {'collect', 'eval', 'learn'}
Gumbel MuZeroPolicy
Sampled AlphaZeroPolicy
- class lzero.policy.sampled_alphazero.SampledAlphaZeroPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Bases:
Policy
- Overview:
The policy class for Sampled AlphaZero.
- __init__(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None) None
- Overview:
Initialize the policy instance according to the input config and model. This method initializes the different fields of the policy, including learn, collect and eval. The learn field is used to train the policy, the collect field is used to collect data for training, and the eval field is used to evaluate the policy. enable_field specifies which fields to initialize; if it is None, all fields are initialized.
- Parameters:
cfg (-) – The final merged config used to initialize the policy. For the default config, see the config attribute of the policy class and its comments.
model (-) – The neural network model used to initialize the policy. If it is None, the model will be created according to the default_model method and the cfg.model field. Otherwise, the model is set to the instance created by the outside caller.
enable_field (-) – The list of fields to initialize. If it is None, all fields will be initialized. Otherwise, only the fields in enable_field will be initialized, which helps to save resources.
Note
A derived policy class should implement the _init_learn, _init_collect and _init_eval methods to initialize the corresponding fields.
- _abc_impl = <_abc._abc_data object>
- _calculate_policy_loss_disc(policy_probs: Tensor, target_policy: Tensor, target_sampled_actions: Tensor, valid_action_lengths: Tensor) Tensor [source]
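This method has no docstring; the following is a hedged sketch of one plausible implementation consistent with the signature above, i.e. a cross-entropy between the target distribution over the sampled actions and the network probabilities, with padded action slots masked out by valid_action_lengths. It is illustrative only, not the library's actual loss.
>>> import torch
>>> def sampled_policy_loss(policy_probs, target_policy, target_sampled_actions, valid_action_lengths):
...     # policy_probs: (B, action_space); target_policy / target_sampled_actions: (B, K)
...     probs = torch.gather(policy_probs, 1, target_sampled_actions.long())
...     log_probs = torch.log(probs.clamp_min(1e-9))
...     slots = torch.arange(target_sampled_actions.shape[1], device=probs.device).unsqueeze(0)
...     mask = (slots < valid_action_lengths.unsqueeze(1)).float()   # 1 for valid action slots
...     return -(target_policy * log_probs * mask).sum(dim=1).mean()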
- _create_model(cfg: EasyDict, model: Module | None = None) Module
- Overview:
Create or validate the neural network model according to the input configuration and model. If the input model is None, the model will be created according to the default_model method and the cfg.model field. Otherwise, the model will be verified as an instance of torch.nn.Module and set to the model instance created by the outside caller.
- Parameters:
cfg (-) – The final merged config used to initialize policy.
model (-) – The neural network model used to initialize policy. User can refer to the default model defined in the corresponding policy to customize its own model.
- Returns:
The created neural network model. The different modes of policy will add distinct wrappers and plugins to the model, which is used to train, collect and evaluate.
- Return type:
model (
torch.nn.Module
)
- Raises:
- RuntimeError – If the input model is not None and is not an instance of
torch.nn.Module
.
- _forward_collect(obs: Dict, temperature: float = 1) Dict[str, Tensor] [source]
- Overview:
The forward function for collecting data in collect mode. Use real env to execute MCTS search.
- Parameters:
obs (-) – The dict of obs, the key is env_id and the value is the corresponding obs in this timestep.
temperature (-) – The temperature for MCTS search.
- Returns:
The dict of output, the key is env_id and the value is the corresponding policy output in this timestep, including action, probs and so on.
- Return type:
output (
Dict[str, torch.Tensor]
)
- _forward_eval(obs: Dict) Dict[str, Tensor] [source]
- Overview:
The forward function for evaluating the current policy in eval mode, similar to
self._forward_collect
.
- Parameters:
obs (-) – The dict of obs, the key is env_id and the value is the corresponding obs in this timestep.
- Returns:
The dict of output, the key is env_id and the value is the corresponding policy output in this timestep, including action, probs and so on.
- Return type:
output (
Dict[str, torch.Tensor]
)
- _forward_learn(inputs: Dict[str, Tensor]) Dict[str, float] [source]
- Overview:
Policy forward function of learn mode (training the policy and updating parameters). Forward means that the policy takes some training batch data from the replay buffer as input and returns the output result, including various training information such as loss value, policy entropy, q value, priority, and so on. This method is left to be implemented by the subclass, and more arguments can be added to the data item if necessary.
- Parameters:
data (-) – The input data used for policy forward, including a batch of training samples. For each element in the list, the key of the dict is the name of the data item and the value is the corresponding data. Usually, in the _forward_learn method, data should be stacked in the batch dimension by some utility functions such as default_preprocess_learn.
- Returns:
The training information of policy forward, including some metrics for monitoring training such as loss, priority, q value, policy entropy, and some data for next step training such as priority. Note the output data item should be Python native scalar rather than PyTorch tensor, which is convenient for the outside to use.
- Return type:
output (
Dict[int, Any]
)
- _get_attribute(name: str) Any
- Overview:
In order to control the access of the policy attributes, we expose different modes to outside rather than directly use the policy instance. And we also provide a method to get the attribute of the policy in different modes.
- Parameters:
name (-) – The name of the attribute.
- Returns:
The value of the attribute.
- Return type:
value (
Any
)
Note
DI-engine’s policy will first try to access _get_{name} method, and then try to access _{name} attribute. If both of them are not found, it will raise a
NotImplementedError
.
- _get_batch_size() int | Dict[str, int]
- _get_n_episode() int | None
- _get_n_sample() int | None
- _get_train_sample(data)[source]
- Overview:
For a given trajectory (transitions, a list of transitions), process it into a list of samples that can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) or some multi-timestep transitions (DRQN). This method is usually used in collectors to execute the necessary RL data preprocessing before training, which helps the learner amortize the relevant time consumption. Alternatively, you can also implement this method as an identity function and do the data processing in the self._forward_learn method.
- Parameters:
transitions (-) – The trajectory data (a list of transitions); each element is in the same format as the return value of the self._process_transition method.
- Returns:
The processed train samples; each element is in a similar format to the input transitions, but may contain more data for training, such as nstep reward, advantage, etc.
- Return type:
samples (
List[Dict[str, Any]]
)
Note
We will vectorize the process_transition and get_train_sample methods in a following release. Users can also customize this data processing procedure by overriding these two methods and the collector itself.
- _init_collect() None [source]
- Overview:
Collect mode init method. Called by
self.__init__
. Initialize the collect model and MCTS utils.
- _init_eval() None [source]
- Overview:
Evaluate mode init method. Called by
self.__init__
. Initialize the eval model and MCTS utils.
- _init_learn() None [source]
- Overview:
Initialize the learn mode of the policy, including related attributes and modules. This method will be called in the __init__ method if the learn field is in enable_field. Almost every policy has its own learn mode, so this method must be overridden in the subclass.
Note
For the member variables that need to be saved and loaded, please refer to the _state_dict_learn and _load_state_dict_learn methods.
Note
For the member variables that need to be monitored, please refer to the _monitor_vars_learn method.
Note
If you want to set some special member variables in the _init_learn method, you’d better name them with the prefix _learn_ to avoid conflicts with other modes, such as self._learn_attr1.
- _init_multi_gpu_setting(model: Module, bp_update_sync: bool) None
- Overview:
Initialize multi-gpu data parallel training setting, including broadcast model parameters at the beginning of the training, and prepare the hook function to allreduce the gradients of model parameters.
- Parameters:
model (-) – The neural network model to be trained.
bp_update_sync (-) – Whether to synchronously update the model parameters after allreducing their gradients. Asynchronous updates can be pipelined across different network layers, which saves time.
- _load_state_dict_collect(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy collect mode, such as load pretrained state_dict, auto-recover checkpoint, or model replica from learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy collect state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_eval(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy eval mode, such as load auto-recover checkpoint, or model replica from learner in distributed training scenarios.
- Parameters:
state_dict (-) – The dict of policy eval state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _load_state_dict_learn(state_dict: Dict[str, Any]) None
- Overview:
Load the state_dict variable into policy learn mode.
- Parameters:
state_dict (-) – The dict of policy learn state saved before.
Tip
If you want to load only some parts of the model, you can simply set the strict argument of load_state_dict to False, or refer to ding.torch_utils.checkpoint_helper for more complicated operations.
- _monitor_vars_learn() List[str] [source]
- Overview:
Register the variables to be monitored in learn mode. The registered variables will be logged to tensorboard according to the return value of _forward_learn.
- _process_transition(obs: Dict, model_output: Dict[str, Tensor], timestep: namedtuple) Dict [source]
- Overview:
Generate the dict type transition (one timestep) data from policy learning.
- _reset_collect(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for collect mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different environments/episodes in collecting, indexed by data_id, will have different hidden states in the RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by
data_id
.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
- _reset_eval(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for eval mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different environments/episodes in evaluation, indexed by data_id, will have different hidden states in the RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by
data_id
.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
- _reset_learn(data_id: List[int] | None = None) None
- Overview:
Reset some stateful variables for learn mode when necessary, such as the hidden state of an RNN or the memory bank of some special algorithms. If data_id is None, all the stateful variables are reset. Otherwise, only the stateful variables specified by data_id are reset. For example, different trajectories, indexed by data_id, will have different hidden states in the RNN.
- Parameters:
data_id (-) – The id of the data, which is used to reset the stateful variables specified by
data_id
.
Note
This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
- _set_attribute(name: str, value: Any) None
- Overview:
In order to control the access of the policy attributes, we expose different modes to outside rather than directly use the policy instance. We also provide a method to set the attribute of the policy in different modes, and the new attribute will be named _{name}.
- Parameters:
name (-) – The name of the attribute.
value (-) – The value of the attribute.
- _state_dict_collect() Dict[str, Any]
- Overview:
Return the state_dict of collect mode, usually only including the model, which is necessary for distributed training scenarios to auto-recover collectors.
- Returns:
The dict of current policy collect state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
Tip
Not all scenarios need to auto-recover collectors; sometimes we can directly shut down the crashed collector and start a new one.
- _state_dict_eval() Dict[str, Any]
- Overview:
Return the state_dict of eval mode, usually only including the model, which is necessary for distributed training scenarios to auto-recover evaluators.
- Returns:
The dict of current policy eval state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
Tip
Not all scenarios need to auto-recover evaluators; sometimes we can directly shut down the crashed evaluator and start a new one.
- _state_dict_learn() Dict[str, Any]
- Overview:
Return the state_dict of learn mode, usually including model and optimizer.
- Returns:
The dict of current policy learn state, for saving and restoring.
- Return type:
state_dict (
Dict[str, Any]
)
- property cfg: EasyDict
- class collect_function(forward, process_transition, get_train_sample, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'process_transition', 'get_train_sample', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new collect_function object from a sequence or iterable
- _replace(**kwds)
Return a new collect_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- get_train_sample
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 7
- process_transition
Alias for field number 1
- reset
Alias for field number 3
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property collect_mode: collect_function
- Overview:
Return the interfaces of collect mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own collect mode.
- Returns:
The interfaces of collect mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (
Policy.collect_function
)
Examples
>>> policy = Policy(cfg, model)
>>> policy_collect = policy.collect_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_collect.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- config = {'batch_size': 256, 'collector_env_num': 8, 'cuda': False, 'evaluator_env_num': 3, 'fixed_temperature_value': 0.25, 'grad_clip_value': 10, 'learning_rate': 0.2, 'manual_temperature_decay': False, 'mcts': {'action_space_size': 9, 'continuous_action_space': False, 'legal_actions': None, 'max_moves': 512, 'num_of_sampled_actions': 2, 'num_simulations': 50, 'pb_c_base': 19652, 'pb_c_init': 1.25, 'root_dirichlet_alpha': 0.3, 'root_noise_weight': 0.25}, 'mcts_ctree': True, 'model': {'num_channels': 32, 'num_res_blocks': 1, 'observation_shape': (3, 6, 6)}, 'momentum': 0.9, 'normalize_prob_of_sampled_actions': False, 'optim_type': 'SGD', 'other': {'replay_buffer': {'replay_buffer_size': 1000000, 'save_episode': False}}, 'piecewise_decay_lr_scheduler': True, 'policy_loss_type': 'cross_entropy', 'replay_ratio': 0.25, 'sampled_algo': False, 'tensor_float_32': False, 'threshold_training_steps_for_final_lr': 500000, 'threshold_training_steps_for_final_temperature': 100000, 'torch_compile': False, 'type': 'alphazero', 'update_per_collect': None, 'value_weight': 1.0, 'weight_decay': 0.0001}
- classmethod default_config() EasyDict
- Overview:
Get the default config of policy. This method is used to create the default config of policy.
- Returns:
The default config of corresponding policy. For the derived policy class, it will recursively merge the default config of base class and its own default config.
- Return type:
cfg (
EasyDict
)
Tip
This method will deepcopy the
config
attribute of the class and return the result. So users don’t need to worry about the modification of the returned config.
- default_model() Tuple[str, List[str]] [source]
- Overview:
Return the default model setting of this algorithm for demonstration.
- Returns:
model_type (str): The model type used in this algorithm, which is registered in ModelRegistry.
import_names (List[str]): The model class path list used in this algorithm.
- Return type:
Tuple[str, List[str]]
- class eval_function(forward, reset, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new eval_function object from a sequence or iterable
- _replace(**kwds)
Return a new eval_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 2
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- load_state_dict
Alias for field number 5
- reset
Alias for field number 1
- set_attribute
Alias for field number 3
- state_dict
Alias for field number 4
- property eval_mode: eval_function
- Overview:
Return the interfaces of eval mode of policy, which is used to evaluate the performance of the policy. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclasses can override the interfaces to customize their own eval mode.
- Returns:
The interfaces of eval mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (
Policy.eval_function
)
Examples
>>> policy = Policy(cfg, model)
>>> policy_eval = policy.eval_mode
>>> obs = env_manager.ready_obs
>>> inference_output = policy_eval.forward(obs)
>>> next_obs, rew, done, info = env_manager.step(inference_output.action)
- class learn_function(forward, reset, info, monitor_vars, get_attribute, set_attribute, state_dict, load_state_dict)
Bases:
tuple
- _asdict()
Return a new dict which maps field names to their values.
- _field_defaults = {}
- _fields = ('forward', 'reset', 'info', 'monitor_vars', 'get_attribute', 'set_attribute', 'state_dict', 'load_state_dict')
- classmethod _make(iterable)
Make a new learn_function object from a sequence or iterable
- _replace(**kwds)
Return a new learn_function object replacing specified fields with new values
- count(value, /)
Return number of occurrences of value.
- forward
Alias for field number 0
- get_attribute
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- info
Alias for field number 2
- load_state_dict
Alias for field number 7
- monitor_vars
Alias for field number 3
- reset
Alias for field number 1
- set_attribute
Alias for field number 5
- state_dict
Alias for field number 6
- property learn_mode: learn_function
- Overview:
Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple to define immutable interfaces and restrict the usage of policy in different modes. Moreover, derived subclass can override the interfaces to customize its own learn mode.
- Returns:
The interfaces of learn mode of policy, it is a namedtuple whose values of distinct fields are different internal methods.
- Return type:
interfaces (
Policy.learn_function
)
Examples
>>> policy = Policy(cfg, model)
>>> policy_learn = policy.learn_mode
>>> train_output = policy_learn.forward(data)
>>> state_dict = policy_learn.state_dict()
- sync_gradients(model: Module) None
- Overview:
Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.
- Parameters:
model (-) – The model to synchronize gradients.
Note
This method is only used in multi-gpu training, and it should be called after the backward method and before the step method. The user can also use the bp_update_sync config to control whether gradient allreduce and optimizer updates are synchronized.
- total_field = {'collect', 'eval', 'learn'}