MuZeroBuffer

class lzero.mcts.buffer.game_buffer_muzero.MuZeroGameBuffer(cfg: dict)[source]

Bases: GameBuffer

Overview:

The specific game buffer for MuZero policy.

__init__(cfg: dict)[source]
_compute_target_policy_non_reanalyzed(policy_non_re_context: List[Any], policy_shape: int | None) ndarray[source]
Overview:

Prepare policy targets from the non-reanalyzed policy context, i.e. the child-visit distributions recorded during self-play.

Parameters:
  • policy_non_re_context (list) – list containing pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment.

  • policy_shape (int) – the action space size, i.e. self._cfg.model.action_space_size.

Returns:

  • batch_target_policies_non_re
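
Example:

The non-reanalyzed policy target at a given position is the child-visit distribution recorded during self-play, normalized to sum to one. The helper below is only an illustrative sketch of that normalization (the function name and the uniform fallback for padded positions are assumptions, not the buffer's actual code):

    import numpy as np

    def visit_counts_to_policy_target(child_visits, action_space_size):
        """Normalize recorded MCTS child-visit counts into a policy target."""
        visits = np.asarray(child_visits, dtype=np.float32)
        if visits.sum() == 0:
            # Padded / absorbing position: fall back to a uniform distribution.
            return np.ones(action_space_size, dtype=np.float32) / action_space_size
        return visits / visits.sum()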

_compute_target_policy_reanalyzed(policy_re_context: List[Any], model: Any) ndarray[source]
Overview:

Prepare policy targets from the reanalyzed policy context.

Parameters:
  • policy_re_context (list) – the policy context to be reanalyzed.

  • model (Any) – the model used to recompute the policy targets.

Returns:

  • batch_target_policies_re

_compute_target_reward_value(reward_value_context: List[Any], model: Any) Tuple[Any, Any][source]
Overview:

Prepare reward and value targets from the context of rewards and values.

Parameters:
  • reward_value_context (list) – the reward and value context.

  • model (Any) – the target model.

Returns:

batch_value_prefixs – batch of value prefixes; batch_target_values – batch of value estimations

Return type:

  • batch_value_prefixs (np.ndarray)

  • batch_target_values (np.ndarray)
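
Example:

For reference, the value target follows the standard MuZero n-step bootstrapped return: the discounted sum of the next td_steps rewards plus a discounted value estimate from the target model. The standalone sketch below (hypothetical helper name and arguments) shows the computation for a single position:

    def n_step_value_target(rewards, bootstrap_value, td_steps, discount):
        """Illustrative n-step TD target for one position:
        sum_i discount**i * reward_i  +  discount**td_steps * bootstrap_value."""
        target = (discount ** td_steps) * bootstrap_value
        for i, reward in enumerate(rewards[:td_steps]):
            target += (discount ** i) * reward
        return target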

_make_batch(batch_size: int, reanalyze_ratio: float) Tuple[Any][source]
Overview:

First sample orig_data through _sample_orig_data(), then prepare the contexts of a batch:

  • reward_value_context: the context of reanalyzed value targets

  • policy_re_context: the context of reanalyzed policy targets

  • policy_non_re_context: the context of non-reanalyzed policy targets

  • current_batch: the inputs of the batch

Parameters:
  • batch_size (int) – the batch size of orig_data sampled from the replay buffer.

  • reanalyze_ratio (float) – the ratio of reanalyzed policy targets (value targets are always 100% reanalyzed).

Returns:

reward_value_context, policy_re_context, policy_non_re_context, current_batch

Return type:

  • context (Tuple)
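
Example:

The reanalyze_ratio determines how the sampled batch is split into a reanalyzed part and a non-reanalyzed part. A minimal sketch of that split (illustrative numbers; the actual bookkeeping lives inside _make_batch):

    # Illustrative split following the reanalyze_ratio semantics described above.
    batch_size = 256
    reanalyze_ratio = 0.25
    reanalyze_num = int(batch_size * reanalyze_ratio)  # policies recomputed with the current model
    non_reanalyze_num = batch_size - reanalyze_num     # policies taken from recorded self-play visit counts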

_prepare_policy_non_reanalyzed_context(batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int]) List[Any][source]
Overview:

Prepare the policy context for calculating the policy target in the non-reanalyzed part; this simply returns the policies recorded during self-play.

Parameters:
  • batch_index_list (list) – the indices of the start transitions of the sampled minibatch in the replay buffer.

  • game_segment_list (list) – list of game segments.

  • pos_in_game_segment_list (list) – list of transition indices within each game segment.

Returns:

pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment

Return type:

  • policy_non_re_context (list)

_prepare_policy_reanalyzed_context(batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str]) List[Any][source]
Overview:

Prepare the policy context for calculating the policy target in the reanalyzed part.

Parameters:
  • batch_index_list (list) – start transition indices in the replay buffer.

  • game_segment_list (list) – list of game segments.

  • pos_in_game_segment_list (list) – position of the transition index in one game history.

Returns:

policy_obs_list, policy_mask, pos_in_game_segment_list, indices, child_visits, game_segment_lens, action_mask_segment, to_play_segment

Return type:

  • policy_re_context (list)

_prepare_reward_value_context(batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], total_transitions: int) List[Any][source]
Overview:

Prepare the reward and value context for calculating the TD value target in the reanalyzed part.

Parameters:
  • batch_index_list (list) – the indices of the start transitions of the sampled minibatch in the replay buffer.

  • game_segment_list (list) – list of game segments.

  • pos_in_game_segment_list (list) – list of transition indices within each game segment.

  • total_transitions (int) – the number of collected transitions.

Returns:

value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, to_play_segment

Return type:

  • reward_value_context (list)

_preprocess_to_play_and_action_mask(game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, unroll_steps=None)
Overview:
Prepare the to_play and action_mask entries for the target observations in value_obs_list.
  • to_play: a list of length game_segment_batch_size * (num_unroll_steps + 1)

  • action_mask: a list of length game_segment_batch_size * (num_unroll_steps + 1)
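
Example:

As a rough illustration of where those lengths come from, each sampled position contributes the num_unroll_steps + 1 target steps starting at its index. The slicing below is only a simplified sketch for a single position (it ignores the padding needed near the end of a segment):

    # Simplified sketch for one sampled position; names follow the docs above.
    unroll_plus_one = num_unroll_steps + 1
    to_play_slice = to_play_segment[pos_in_game_segment:pos_in_game_segment + unroll_plus_one]
    action_mask_slice = action_mask_segment[pos_in_game_segment:pos_in_game_segment + unroll_plus_one]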

_push_game_segment(data: Any, meta: dict | None = None) None
Overview:

Push data and its meta information into the buffer, saving a game segment.

Parameters:
  • data (Any) – the data (a game segment) which will be pushed into the buffer.

  • meta (dict) – meta information, e.g. priority, count, staleness:

      - done (bool): whether the game is finished.

      - unroll_plus_td_steps (int): if the game is not finished, we only save the transitions that can be computed.

      - priorities (list): the priorities corresponding to the transitions in the game history.

Returns:

The pushed data.

Return type:

  • buffered_data (BufferedData)
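
Example:

A hedged usage sketch with the meta keys listed above; buffer, game_segment and per_transition_priorities are placeholders for a MuZeroGameBuffer, a collected GameSegment and its priority list:

    # Hypothetical call; the values below are purely illustrative.
    meta = {
        'done': True,                              # whether the episode finished
        'unroll_plus_td_steps': 10,                # illustrative value: num_unroll_steps + td_steps
        'priorities': per_transition_priorities,   # priorities for the transitions in this segment
    }
    buffer._push_game_segment(game_segment, meta)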

_remove(excess_game_segment_index: List[int]) None
Overview:

Delete the game segments in the index range [0 : excess_game_segment_index].

Parameters:

excess_game_segment_index (List[int]) – index of the data to delete.

_sample_orig_data(batch_size: int) Tuple
Overview:
Sample orig_data, which contains:

  • game_segment_list: a list of game segments

  • pos_in_game_segment_list: transition index in game (relative index)

  • batch_index_list: the index of the start transition of the sampled minibatch in the replay buffer

  • weights_list: the weights concerning the priority

  • make_time: the time the batch is made (for correctly updating the replay buffer when data is deleted)

Parameters:
  • batch_size (int) – batch size.

  • beta (float) – the parameter in PER used when computing the importance-sampling weights.
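
Example:

The beta parameter is the usual Prioritized Experience Replay exponent. The buffer's actual sampling also respects game-segment boundaries and reanalyze bookkeeping; the standalone sketch below only illustrates the generic PER probabilities and importance-sampling weights that beta controls:

    import numpy as np

    def per_sample(priorities, batch_size, alpha, beta, rng=None):
        """Generic PER sketch: P(i) ~ p_i**alpha, w_i = (N * P(i)) ** (-beta)."""
        rng = rng or np.random.default_rng()
        p = np.asarray(priorities, dtype=np.float64) ** alpha
        probs = p / p.sum()
        indices = rng.choice(len(probs), size=batch_size, p=probs)
        weights = (len(probs) * probs[indices]) ** (-beta)
        return indices, weights / weights.max()  # normalize weights for stability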

_sample_orig_data_episode(batch_size: int) Tuple
Overview:
Sample original data for a training batch, which includes:
  • game_segment_list: A list of game segments.

  • pos_in_game_segment_list: Indices of transitions within the game segments.

  • batch_index_list: Indices of the start transitions of the sampled mini-batch in the replay buffer.

  • weights_list: Weights for each sampled transition, used for prioritization.

  • make_time: Timestamps indicating when the batch was created (useful for managing replay buffer updates).

Parameters:
  • batch_size (-) – The number of samples to draw for the batch.

  • beta (-) – Parameter for Prioritized Experience Replay (PER) that adjusts the importance of samples.

_sample_orig_reanalyze_data(batch_size: int) Tuple
Overview:
Sample orig_data, which contains:

  • game_segment_list: a list of game segments

  • pos_in_game_segment_list: transition index in game (relative index)

  • batch_index_list: the index of the start transition of the sampled minibatch in the replay buffer

  • weights_list: the weights concerning the priority

  • make_time: the time the batch is made (for correctly updating the replay buffer when data is deleted)

Parameters:
  • batch_size (int) – batch size.

  • beta (float) – the parameter in PER used when computing the importance-sampling weights.

config = {'mini_infer_size': 10240, 'reanalyze_outdated': True, 'reanalyze_ratio': 0, 'replay_buffer_size': 1000000, 'sample_type': 'transition', 'use_root_value': False}
classmethod default_config() EasyDict
get_num_of_episodes() int
get_num_of_game_segments() int
get_num_of_transitions() int
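
Example:

A hedged construction sketch using the documented defaults; in a real training run the buffer normally receives the full merged policy config, and some LightZero versions may require keys beyond those listed in config above:

    from lzero.mcts.buffer.game_buffer_muzero import MuZeroGameBuffer

    cfg = MuZeroGameBuffer.default_config()   # EasyDict with the defaults listed above
    cfg.replay_buffer_size = 500000           # override a documented key if desired
    replay_buffer = MuZeroGameBuffer(cfg)
    print(replay_buffer.get_num_of_transitions())  # 0 for an empty buffer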
push_game_segments(data_and_meta: Any) None
Overview:

Push game_segments data and its meta information into the buffer, saving the game segments.

Parameters:

data_and_meta (Any) –

  • data (Any): The data (game segments) which will be pushed into buffer.

  • meta (dict): Meta information, e.g. priority, count, staleness.

remove_oldest_data_to_fit() None
Overview:

Remove the oldest data if the replay buffer is full.

reset_runtime_metrics()[source]
Overview:

Reset the runtime metrics of the buffer.

sample(batch_size: int, policy: MuZeroPolicy | EfficientZeroPolicy | SampledEfficientZeroPolicy) List[Any][source]
Overview:

Sample data from the GameBuffer and prepare the current and target batches for training.

Parameters:
  • batch_size (int) – batch size.

  • policy (MuZeroPolicy | EfficientZeroPolicy | SampledEfficientZeroPolicy) – the policy used when preparing the target batch.

Returns:

List of train data, including current_batch and target_batch.

Return type:

  • train_data (List)
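
Example:

A hedged sketch of the typical call site; the batch size is illustrative and policy stands for the MuZero-style policy object passed in by the training entry point:

    # Hypothetical call site; `replay_buffer` is a MuZeroGameBuffer with collected data.
    train_data = replay_buffer.sample(256, policy)
    current_batch, target_batch = train_data  # structure described in the Note under update_priority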

update_priority(train_data: List[ndarray], batch_priorities: Any) None[source]
Overview:

Update the priority of training data.

Parameters:
  • train_data (List[np.ndarray]) – the training data whose priorities are to be updated.

  • batch_priorities (Any) – the new priorities to assign.

Note

train_data = [current_batch, target_batch]
current_batch = [obs_list, action_list, improved_policy_list (only in Gumbel MuZero), mask_list, batch_index_list, weights, make_time_list]
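
Example:

A hedged sketch of the priority update after a training step; value_errors is a placeholder for whatever per-sample priority signal (e.g. absolute value error) the training loop produces:

    import numpy as np

    # Hypothetical post-training update: `train_data` is the batch returned by sample()
    # and `value_errors` is a placeholder array of per-sample priorities.
    value_errors = np.abs(np.random.randn(256)).astype(np.float32)
    replay_buffer.update_priority(train_data, value_errors)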