MuZeroBuffer
- class lzero.mcts.buffer.game_buffer_muzero.MuZeroGameBuffer(cfg: dict)[source]
Bases:
GameBuffer
- Overview:
The specific game buffer for the MuZero policy.
- _compute_target_policy_non_reanalyzed(policy_non_re_context: List[Any], policy_shape: int | None) ndarray [source]
- Overview:
Prepare policy targets from the non-reanalyzed policy context.
- Parameters:
policy_non_re_context (list) – list containing pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment and to_play_segment
policy_shape (int) – the shape of the policy target, i.e. self._cfg.model.action_space_size
- Returns:
batch_target_policies_non_re
- _compute_target_policy_reanalyzed(policy_re_context: List[Any], model: Any) ndarray [source]
- Overview:
Prepare policy targets from the reanalyzed policy context.
- Parameters:
policy_re_context (list) – the policy context to be reanalyzed
model (Any) – the model used to recompute the policy targets
- Returns:
batch_target_policies_re
- _compute_target_reward_value(reward_value_context: List[Any], model: Any) Tuple[Any, Any] [source]
- Overview:
Prepare reward and value targets from the context of rewards and values.
- Parameters:
reward_value_context (list) – the reward value context
model (Any) – the target model used to compute the value targets
- Returns:
- batch_value_prefixs (np.ndarray): batch of value prefixes
- batch_target_values (np.ndarray): batch of value estimations
- Return type:
Tuple[Any, Any]
- _make_batch(batch_size: int, reanalyze_ratio: float) Tuple[Any] [source]
- Overview:
First sample orig_data through _sample_orig_data(), then prepare the context of a batch:
- reward_value_context: the context of reanalyzed value targets
- policy_re_context: the context of reanalyzed policy targets
- policy_non_re_context: the context of non-reanalyzed policy targets
- current_batch: the inputs of the batch
- Parameters:
batch_size (int) – the batch size of orig_data sampled from the replay buffer.
reanalyze_ratio (float) – the ratio of reanalyzed policy targets (values are always 100% reanalyzed); see the sketch after this entry.
- Returns:
reward_value_context, policy_re_context, policy_non_re_context, current_batch
- Return type:
context (Tuple)
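For intuition, the sketch below shows how reanalyze_ratio splits a sampled batch between reanalyzed and non-reanalyzed policy targets; the variable names are illustrative, not the buffer's internals::

    # Illustrative only: how reanalyze_ratio divides the sampled batch.
    batch_size = 256
    reanalyze_ratio = 0.25
    reanalyze_num = int(batch_size * reanalyze_ratio)   # policy targets recomputed by running MCTS with the target model
    non_reanalyze_num = batch_size - reanalyze_num      # policy targets taken directly from the self-play visit counts
    # Value targets are recomputed for the whole batch regardless of this split.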
- _make_batch_for_reanalyze(batch_size: int) Tuple[Any] [source]
- Overview:
First sample orig_data through _sample_orig_data(), then prepare the context of a batch:
- reward_value_context: the context of reanalyzed value targets
- policy_re_context: the context of reanalyzed policy targets
- policy_non_re_context: the context of non-reanalyzed policy targets
- current_batch: the inputs of the batch
- Parameters:
batch_size (int) – the batch size of orig_data sampled from the replay buffer.
- Returns:
reward_value_context, policy_re_context, policy_non_re_context, current_batch
- Return type:
context (Tuple)
- _prepare_policy_non_reanalyzed_context(batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int]) List[Any] [source]
- Overview:
Prepare the policy context for calculating policy targets in the non-reanalyzed part; this simply returns the policies recorded during self-play.
- Parameters:
batch_index_list (list) – the indices of the start transitions of the sampled mini-batch in the replay buffer
game_segment_list (list) – list of game segments
pos_in_game_segment_list (list) – list of transition indices within each game segment
- Returns:
pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment
- Return type:
policy_non_re_context (list)
- _prepare_policy_reanalyzed_context(batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str]) List[Any] [source]
- Overview:
Prepare the policy context for calculating policy targets in the reanalyzed part.
- Parameters:
batch_index_list (list) – start transition indices in the replay buffer
game_segment_list (list) – list of game segments
pos_in_game_segment_list (list) – positions of the transition indices within each game history
- Returns:
policy_obs_list, policy_mask, pos_in_game_segment_list, indices, child_visits, game_segment_lens, action_mask_segment, to_play_segment
- Return type:
policy_re_context (list)
- _prepare_reward_value_context(batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], total_transitions: int) List[Any] [source]
- Overview:
Prepare the context of rewards and values for calculating TD value targets in the reanalyzed part.
- Parameters:
batch_index_list (list) – the indices of the start transitions of the sampled mini-batch in the replay buffer
game_segment_list (list) – list of game segments
pos_in_game_segment_list (list) – list of transition indices within each game_segment
total_transitions (int) – number of collected transitions
- Returns:
value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, to_play_segment
- Return type:
reward_value_context (list)
- _preprocess_to_play_and_action_mask(game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, unroll_steps=None)
- Overview:
Prepare the to_play and action_mask for the target observations in value_obs_list (see the shape example after this entry):
- to_play: {list: game_segment_batch_size * (num_unroll_steps+1)}
- action_mask: {list: game_segment_batch_size * (num_unroll_steps+1)}
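A quick shape check for the flattened lists mentioned above; the counts are arbitrary examples::

    # With 8 sampled game segments and num_unroll_steps = 5, the flattened
    # to_play and action_mask lists each hold 8 * (5 + 1) = 48 entries,
    # one per (segment, unroll step) pair including the root step.
    game_segment_batch_size = 8
    num_unroll_steps = 5
    flat_len = game_segment_batch_size * (num_unroll_steps + 1)
    assert flat_len == 48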
- _push_game_segment(data: Any, meta: dict | None = None) None
- Overview:
Push data and its meta information into the buffer, saving one game segment.
- Parameters:
data (Any) – the data (a game segment) which will be pushed into the buffer.
meta (dict) – meta information, e.g. priority, count, staleness (see the sketch after this entry):
- done (bool): whether the game is finished.
- unroll_plus_td_steps (int): if the game is not finished, only the transitions that can be computed are saved.
- priorities (list): the priorities corresponding to the transitions in the game history.
- Returns:
The pushed data.
- Return type:
buffered_data (BufferedData)
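A minimal sketch of the meta dict described above; `segment` is assumed to be a completed game segment produced by the collector, and the concrete values are illustrative::

    # Sketch: pushing one finished game segment together with its meta information.
    meta = {
        'done': True,                # the episode finished, so all of its transitions are usable
        'unroll_plus_td_steps': 10,  # e.g. num_unroll_steps + td_steps; only relevant when done is False
        'priorities': None,          # assumption: None lets the buffer assign a default priority
    }
    buffer._push_game_segment(segment, meta)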
- _remove(excess_game_segment_index: List[int]) None
- Overview:
Delete the game segments with indices in [0: excess_game_segment_index].
- Parameters:
excess_game_segment_index (List[int]) – index of the data to remove.
- _sample_orig_data(batch_size: int) Tuple
- Overview:
Sample orig_data that contains:
- game_segment_list: a list of game segments
- pos_in_game_segment_list: transition index in game (relative index)
- batch_index_list: the indices of the start transitions of the sampled mini-batch in the replay buffer
- weights_list: the weights concerning the priority
- make_time: the time the batch is made (for correctly updating the replay buffer when data is deleted)
- Parameters:
batch_size (int) – batch size
beta (float) – the PER parameter used when computing the importance-sampling weights; see the sketch after this entry.
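A minimal sketch of the PER weighting that beta controls, written directly in NumPy rather than via the buffer's internals; the priority values are made up::

    import numpy as np

    priorities = np.array([0.5, 1.0, 2.0, 4.0])               # per-transition priorities (illustrative)
    probs = priorities / priorities.sum()                      # sampling probability proportional to priority
    indices = np.random.choice(len(priorities), size=2, p=probs)
    beta = 0.4
    weights = (len(priorities) * probs[indices]) ** (-beta)    # importance-sampling correction
    weights /= weights.max()                                   # normalize so the largest weight is 1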
- _sample_orig_data_episode(batch_size: int) Tuple
- Overview:
- Sample original data for a training batch, which includes:
game_segment_list: A list of game segments.
pos_in_game_segment_list: Indices of transitions within the game segments.
batch_index_list: Indices of the start transitions of the sampled mini-batch in the replay buffer.
weights_list: Weights for each sampled transition, used for prioritization.
make_time: Timestamps indicating when the batch was created (useful for managing replay buffer updates).
- Parameters:
batch_size (int) – The number of samples to draw for the batch.
beta (float) – Parameter for Prioritized Experience Replay (PER) that adjusts the importance of samples.
- _sample_orig_reanalyze_batch(batch_size: int) Tuple
- Overview:
This function samples a batch of game segments for reanalysis from the replay buffer. It uses priority sampling based on the reanalyze_time of each game segment, with segments that have been reanalyzed more frequently receiving lower priority.
The function returns a tuple containing information about the sampled game segments, including their positions within each segment and the time the batch was created.
- Parameters:
batch_size (int) – The number of samples to draw in this batch.
- Returns:
A tuple containing the following elements:
- game_segment_list: a list of the sampled game segments.
- pos_in_game_segment_list: a list of indices representing the position of each transition within its corresponding game segment.
- batch_index_list: the indices of the sampled game segments in the replay buffer.
- make_time: a list of timestamps (set to 0 in this implementation) indicating when the batch was created.
- Return type:
Tuple
- Key Details:
Priority Sampling: Game segments are sampled based on a probability distribution calculated from the reanalyze_time of each segment. Segments that have been reanalyzed more frequently are less likely to be selected (see the sketch after this entry).
Segment Slicing: Each selected game segment is sampled at regular intervals determined by the num_unroll_steps parameter. Up to samples_per_segment transitions are sampled from each selected segment.
Handling Extra Samples: If the batch_size is not perfectly divisible by the number of samples per segment, additional segments are sampled to make up the difference.
Reanalyze Time Update: The reanalyze_time attribute of each sampled game segment is incremented to reflect that it has been selected for reanalysis again.
- Raises:
- ValueError – If the game_segment_length is too small to accommodate the num_unroll_steps.
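The sketch below illustrates the reanalyze_time-based weighting described under Key Details; the exact weighting function used inside the buffer may differ, so treat this as a conceptual example::

    import numpy as np

    reanalyze_times = np.array([0, 1, 3, 0])   # how often each segment has already been reanalyzed
    weights = 1.0 / (reanalyze_times + 1)      # more reanalyzed -> lower sampling weight
    probs = weights / weights.sum()
    segment_ids = np.random.choice(len(reanalyze_times), size=2, replace=False, p=probs)
    # After sampling, the buffer increments reanalyze_time for the chosen segments.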
- _sample_orig_reanalyze_data(batch_size: int) Tuple
- Overview:
Sample orig_data that contains:
- game_segment_list: a list of game segments
- pos_in_game_segment_list: transition index in game (relative index)
- batch_index_list: the indices of the start transitions of the sampled mini-batch in the replay buffer
- weights_list: the weights concerning the priority
- make_time: the time the batch is made (for correctly updating the replay buffer when data is deleted)
- Parameters:
batch_size (int) – batch size
beta (float) – the PER parameter used when computing the importance-sampling weights.
- config = {'mini_infer_size': 10240, 'reanalyze_outdated': True, 'reanalyze_ratio': 0, 'replay_buffer_size': 1000000, 'sample_type': 'transition', 'use_root_value': False}
- classmethod default_config() EasyDict
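The snippet below shows how the buffer-level defaults listed in config can be read and overridden; in a full training run the buffer normally receives the complete MuZero policy config, so this is only a sketch of the fields documented here::

    from lzero.mcts.buffer.game_buffer_muzero import MuZeroGameBuffer

    cfg = MuZeroGameBuffer.default_config()   # EasyDict holding the defaults shown in `config` above
    cfg.replay_buffer_size = 100000           # keep fewer transitions in memory
    cfg.reanalyze_ratio = 0.25                # reanalyze a quarter of the policy targets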
- get_num_of_episodes() int
- get_num_of_game_segments() int
- get_num_of_transitions() int
- push_game_segments(data_and_meta: Any) None
- Overview:
Push game_segments data and its meta information into the buffer, saving the game segments.
- Parameters:
data_and_meta (Any) –
- data (Any): the data (game segments) which will be pushed into the buffer.
- meta (dict): meta information, e.g. priority, count, staleness.
- reanalyze_buffer(batch_size: int, policy: MuZeroPolicy | EfficientZeroPolicy | SampledEfficientZeroPolicy) List[Any] [source]
- Overview:
Sample data from GameBuffer and prepare the current and target batch for training.
- Parameters:
batch_size (int) – batch size.
policy (Union[MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy]) – the policy used to reanalyze the sampled data.
- Returns:
List of train data, including current_batch and target_batch.
- Return type:
train_data (List)
- remove_oldest_data_to_fit() None
- Overview:
Remove the oldest data if the replay buffer is full.
- sample(batch_size: int, policy: MuZeroPolicy | EfficientZeroPolicy | SampledEfficientZeroPolicy) List[Any] [source]
- Overview:
Sample data from GameBuffer and prepare the current and target batch for training (see the usage sketch after this entry).
- Parameters:
batch_size (int) – batch size.
policy (Union[MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy]) – the policy used to prepare the reanalyzed targets.
- Returns:
List of train data, including current_batch and target_batch.
- Return type:
train_data (List)
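A typical call from a training loop, assuming `buffer` already holds enough game segments and `policy` is an initialized MuZero-style policy; the unpacking follows the note under update_priority below::

    train_data = buffer.sample(batch_size=256, policy=policy)
    current_batch, target_batch = train_data   # current inputs plus the prepared value/policy targets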
- update_priority(train_data: List[ndarray], batch_priorities: Any) None [source]
- Overview:
Update the priority of training data.
- Parameters:
train_data (List[np.ndarray]) – the training data whose priorities are to be updated.
batch_priorities (Any) – the new priorities to assign.
Note
train_data = [current_batch, target_batch]
current_batch = [obs_list, action_list, improved_policy_list (only in Gumbel MuZero), mask_list, batch_index_list, weights, make_time_list]
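A sketch of feeding fresh priorities back after a learner step; how the new priorities are computed (here, absolute value errors) and the variable names are assumptions, not part of this API::

    import numpy as np

    # Hypothetical per-sample value errors reported by the latest learner step.
    predicted_values = np.array([0.7, -0.2, 1.1])
    target_values = np.array([0.5, 0.0, 1.0])
    new_priorities = np.abs(predicted_values - target_values)   # larger error -> higher priority
    buffer.update_priority(train_data, new_priorities)          # `buffer` and `train_data` as in the sample() sketch above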