import copy
from collections import namedtuple
from typing import List, Dict, Tuple
import numpy as np
import torch.distributions
import torch.nn.functional as F
import torch.optim as optim
from ding.policy.base_policy import Policy
from ding.torch_utils import to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate
from easydict import EasyDict
from lzero.policy import configure_optimizers
from lzero.policy.utils import pad_and_get_lengths, compute_entropy
[docs]@POLICY_REGISTRY.register('sampled_alphazero')
class SampledAlphaZeroPolicy(Policy):
"""
Overview:
The policy class for Sampled AlphaZero.
"""
# The default config for the Sampled AlphaZero policy.
config = dict(
    # (str) The type of policy, as the key of the policy registry.
    # NOTE(review): the class is registered under 'sampled_alphazero' while this default is 'alphazero'
    # - confirm which key is intended before changing it.
    type='alphazero',
    # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled AlphaZero)
    # this variable is used in ``collector``.
    # NOTE(review): defaults to False even though this is the sampled policy; it also selects the
    # Python MCTS implementation in ``_init_collect``/``_init_eval`` - confirm the intended default.
    sampled_algo=False,
    # (bool) NOTE(review): not read anywhere in this file; presumably consumed elsewhere - confirm.
    normalize_prob_of_sampled_actions=False,
    # (str) The type of policy loss, options={'cross_entropy', 'KL'}.
    policy_loss_type='cross_entropy',
    # (bool) Whether to use torch.compile method to speed up our model, which required torch>=2.0.
    torch_compile=False,
    # (bool) Whether to use TF32 for our model.
    tensor_float_32=False,
    model=dict(
        # (tuple) The stacked obs shape.
        observation_shape=(3, 6, 6),
        # (int) The number of res blocks in AlphaZero model.
        num_res_blocks=1,
        # (int) The number of channels of hidden states in AlphaZero model.
        num_channels=32,
    ),
    # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation.
    mcts_ctree=True,
    # (bool) Whether to use cuda for network.
    cuda=False,
    # (int) How many updates(iterations) to train after collector's one collection.
    # Bigger "update_per_collect" means bigger off-policy.
    # collect data -> update policy-> collect data -> ...
    # For different env, we have different episode_length,
    # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor.
    # If we set update_per_collect=None, we will set
    # update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically.
    update_per_collect=None,
    # (float) The ratio of the collected data used for training.
    # Only effective when ``update_per_collect`` is None (see the note above).
    replay_ratio=0.25,
    # (int) Minibatch size for one gradient descent.
    batch_size=256,
    # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW']
    optim_type='SGD',
    # (float) Learning rate for training policy network. Initial lr for manually decay schedule.
    learning_rate=0.2,
    # (float) Weight decay for training policy network.
    weight_decay=1e-4,
    # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction).
    momentum=0.9,
    # (float) The maximum constraint value of gradient norm clipping.
    grad_clip_value=10,
    # (float) The weight of value loss.
    value_weight=1.0,
    # (float) The weight of the policy entropy term in the total loss. It is read by ``_init_learn``,
    # so it needs a default here; 0.0 leaves the loss unregularized.
    entropy_weight=0.0,
    # (int) The number of environments used in collecting data.
    collector_env_num=8,
    # (int) The number of environments used in evaluating policy.
    evaluator_env_num=3,
    # (bool) Whether to use piecewise constant learning rate decay.
    # i.e. lr: 0.2 -> 0.02 -> 0.002
    piecewise_decay_lr_scheduler=True,
    # (int) The number of final training iterations to control lr decay, which is only used for manually decay.
    threshold_training_steps_for_final_lr=int(5e5),
    # (bool) Whether to use manually temperature decay.
    # i.e. temperature: 1 -> 0.5 -> 0.25
    manual_temperature_decay=False,
    # (int) The number of final training iterations to control temperature, which is only used for manually decay.
    threshold_training_steps_for_final_temperature=int(1e5),
    # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration.
    # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False.
    fixed_temperature_value=0.25,
    mcts=dict(
        # (int) The number of simulations to perform at each move.
        num_simulations=50,
        # (int) The maximum number of moves to make in a game.
        max_moves=512,  # for chess and shogi, 722 for Go.
        # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree.
        root_dirichlet_alpha=0.3,
        # (float) The noise weight at the root node of the search tree.
        root_noise_weight=0.25,
        # (int) The base constant used in the PUCT formula for balancing exploration and exploitation during tree search.
        pb_c_base=19652,
        # (float) The initialization constant used in the PUCT formula for balancing exploration and exploitation during tree search.
        pb_c_init=1.25,
        # NOTE(review): undocumented in the original; presumably an optional list of legal actions
        # handed to the Python MCTS - confirm against the MCTS implementation.
        legal_actions=None,
        # (int) The action space size.
        action_space_size=9,
        # (int) The number of sampled actions for each state.
        num_of_sampled_actions=2,
        # NOTE(review): undocumented in the original; presumably whether the action space is
        # continuous - confirm against the MCTS implementation.
        continuous_action_space=False,
    ),
    other=dict(replay_buffer=dict(
        replay_buffer_size=int(1e6),
        save_episode=False,
    )),
)
def default_model(self) -> Tuple[str, List[str]]:
    """
    Overview:
        Name the default network of this algorithm together with the modules that must be imported for it.
    Returns:
        - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry.
        - import_names (:obj:`List[str]`): The model class path list used in this algorithm.
    """
    model_type = 'AlphaZeroModel'
    import_names = ['lzero.model.alphazero_model']
    return model_type, import_names
def _init_learn(self) -> None:
    """
    Overview:
        Learn mode init method. Build the optimizer selected by ``cfg.optim_type``, the optional
        piecewise-constant learning rate scheduler, the loss weights and the learn model.
    """
    optim_type = self._cfg.optim_type
    assert optim_type in ['SGD', 'Adam', 'AdamW'], optim_type
    if optim_type == 'SGD':
        self._optimizer = optim.SGD(
            self._model.parameters(),
            lr=self._cfg.learning_rate,
            momentum=self._cfg.momentum,
            weight_decay=self._cfg.weight_decay,
        )
    elif optim_type == 'Adam':
        self._optimizer = optim.Adam(
            self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay
        )
    elif optim_type == 'AdamW':
        self._optimizer = configure_optimizers(
            model=self._model,
            weight_decay=self._cfg.weight_decay,
            learning_rate=self._cfg.learning_rate,
            device_type=self._cfg.device
        )

    if self._cfg.piecewise_decay_lr_scheduler:
        from torch.optim.lr_scheduler import LambdaLR
        max_step = self._cfg.threshold_training_steps_for_final_lr

        def lr_lambda(step: int) -> float:
            # NOTE: the 1, 0.1, 0.01 is the decay rate (a multiplier of the initial lr), not the lr.
            if step < max_step * 0.33:
                return 1
            if step < max_step * 0.66:
                return 0.1
            return 0.01

        self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda)

    # Algorithm config
    self._value_weight = self._cfg.value_weight
    # ``entropy_weight`` is not part of the default config of this policy, so fall back to 0.0
    # (no entropy regularization) instead of failing with an AttributeError when it is not set.
    self._entropy_weight = getattr(self._cfg, 'entropy_weight', 0.0)

    # The learn model is the policy network itself (no separate target network is created here).
    self._learn_model = self._model

    # TODO(pu): test the effect of torch 2.0
    if self._cfg.torch_compile:
        self._learn_model = torch.compile(self._learn_model)
def _forward_learn(self, inputs: List[Dict]) -> Dict[str, float]:
    """
    Overview:
        The forward function of learn mode: collate a list of transitions into a batch, compute the
        policy, value and entropy losses, and run one optimizer step.
    Arguments:
        - inputs (:obj:`List[Dict]`): The sampled transitions. Each item is indexed with the keys ``obs`` \
            and ``next_obs`` here; after padding and collation the batch is read through the keys \
            ``obs.observation``, ``probs``, ``reward``, ``root_sampled_actions`` and ``action_length``.
    Returns:
        - info (:obj:`Dict[str, float]`): The scalars to be logged: current lr, losses, gradient norm \
            before clipping and the collect temperature.
    """
    for input_dict in inputs:
        # Remove 'katago_game_state' from 'obs' and 'next_obs' if it exists, before the batch is collated.
        for obs_key in ('obs', 'next_obs'):
            if 'katago_game_state' in input_dict[obs_key]:
                del input_dict[obs_key]['katago_game_state']

    # list of dict -> dict of list
    # only for env with variable legal actions
    inputs = pad_and_get_lengths(inputs, self._cfg.mcts.num_of_sampled_actions)
    inputs = default_collate(inputs)
    # Read before the device transfer below, so the lengths stay where ``default_collate`` put them.
    valid_action_length = inputs['action_length']

    if self._cuda:
        inputs = to_device(inputs, self._device)
    self._learn_model.train()

    state_batch = inputs['obs']['observation']
    mcts_visit_count_probs = inputs['probs']
    reward = inputs['reward']
    root_sampled_actions = inputs['root_sampled_actions']

    if len(root_sampled_actions.shape) == 1:
        print(f"root_sampled_actions.shape: {root_sampled_actions.shape}")

    state_batch = state_batch.to(device=self._device, dtype=torch.float)
    mcts_visit_count_probs = mcts_visit_count_probs.to(device=self._device, dtype=torch.float)
    reward = reward.to(device=self._device, dtype=torch.float)

    policy_probs, values = self._learn_model.compute_policy_value(state_batch)

    # Policy entropy. It is part of ``total_loss`` (weighted by ``entropy_weight``), so it must stay
    # finite: a probability of exactly 0 would give 0 * log(0) = NaN, and NaN survives a zero weight.
    # Clamping only the argument of the log keeps the 0 * log(eps) = 0 contribution exact.
    policy_entropy = -(policy_probs * policy_probs.clamp_min(1e-8).log()).sum(-1).mean()
    policy_entropy_loss = -policy_entropy

    # ==============================================================
    # policy loss
    # ==============================================================
    policy_loss = self._calculate_policy_loss_disc(
        policy_probs, mcts_visit_count_probs, root_sampled_actions, valid_action_length
    )

    # ==============================================================
    # value loss
    # ==============================================================
    value_loss = F.mse_loss(values.view(-1), reward)

    total_loss = self._value_weight * value_loss + policy_loss + self._entropy_weight * policy_entropy_loss
    self._optimizer.zero_grad()
    total_loss.backward()

    total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_(
        list(self._model.parameters()),
        max_norm=self._cfg.grad_clip_value,
    )
    self._optimizer.step()
    # Same truthiness test as in ``_init_learn``, so a scheduler that was created is always stepped.
    if self._cfg.piecewise_decay_lr_scheduler:
        self.lr_scheduler.step()

    # =============
    # after update
    # =============
    return {
        'cur_lr': self._optimizer.param_groups[0]['lr'],
        'total_loss': total_loss.item(),
        'policy_loss': policy_loss.item(),
        'value_loss': value_loss.item(),
        'policy_entropy_loss': policy_entropy_loss.item(),
        'total_grad_norm_before_clip': total_grad_norm_before_clip.item(),
        'collect_mcts_temperature': self.collect_mcts_temperature,
    }
def _calculate_policy_loss_disc(
        self, policy_probs: torch.Tensor, target_policy: torch.Tensor,
        target_sampled_actions: torch.Tensor, valid_action_lengths: torch.Tensor
) -> torch.Tensor:
    """
    Overview:
        Compute the policy loss restricted to the actions that were sampled at the MCTS root.
        Both distributions are gathered at the sampled actions, padded slots are masked out, the rows
        are re-normalized, and the masked element-wise loss is averaged over the valid slots.
    Arguments:
        - policy_probs (:obj:`torch.Tensor`): The action probabilities predicted by the network; \
            indexed along dim 1 by ``target_sampled_actions``.
        - target_policy (:obj:`torch.Tensor`): The target probabilities (MCTS visit distribution in the \
            caller); indexed along dim 1 by ``target_sampled_actions``.
        - target_sampled_actions (:obj:`torch.Tensor`): The 2-dim integer tensor of sampled action \
            indices, padded along dim 1.
        - valid_action_lengths (:obj:`torch.Tensor`): For each row, the number of leading slots of \
            ``target_sampled_actions`` that are real (not padding).
    Returns:
        - loss (:obj:`torch.Tensor`): The scalar policy loss.
    Raises:
        - ValueError: If ``cfg.policy_loss_type`` is neither ``'KL'`` nor ``'cross_entropy'``.
    """
    loss_type = self._cfg.policy_loss_type
    if loss_type not in ('KL', 'cross_entropy'):
        raise ValueError(f"Invalid policy_loss_type: {self._cfg.policy_loss_type}")

    # For each row and each sampled action, pick the corresponding probability from both distributions.
    sampled_policy_probs = policy_probs.gather(1, target_sampled_actions)
    sampled_target_policy = target_policy.gather(1, target_sampled_actions)

    # mask[b, k] is True iff slot k of row b holds a real sampled action (k < valid length of row b).
    # The index range is built on the device of the lengths so the comparison never mixes devices.
    max_length = target_sampled_actions.size(1)
    slot_index = torch.arange(max_length, device=valid_action_lengths.device)
    mask = slot_index.unsqueeze(0) < valid_action_lengths.unsqueeze(1)
    mask = mask.to(device=self._device)
    mask_float = mask.float()

    # Zero the padded slots, then re-normalize every row over its valid slots.
    sampled_policy_probs = sampled_policy_probs * mask_float
    sampled_target_policy = sampled_target_policy * mask_float
    sampled_policy_probs = sampled_policy_probs / (sampled_policy_probs.sum(dim=1, keepdim=True) + 1e-6)
    sampled_target_policy = sampled_target_policy / (sampled_target_policy.sum(dim=1, keepdim=True) + 1e-6)

    # Use torch.where so that no gradient flows through the padded slots.
    sampled_policy_probs = torch.where(mask, sampled_policy_probs, torch.zeros_like(sampled_policy_probs))
    sampled_target_policy = torch.where(mask, sampled_target_policy, torch.zeros_like(sampled_target_policy))

    if loss_type == 'KL':
        # Element-wise KL(target || policy) = target * (log target - log policy).
        # ``kl_div`` expects the prediction in log space and the target as probabilities.
        loss = F.kl_div(sampled_policy_probs.log(), sampled_target_policy, reduction='none')
        # Padded slots give 0 * log(0) = NaN; set them to 0.
        loss = torch.nan_to_num(loss)
    else:
        # Element-wise cross entropy H(target, policy) = -target * log policy.
        # The inputs are already probabilities, so they must not go through
        # ``F.cross_entropy`` (which applies log_softmax to logits and, with class-index targets,
        # returns one value per row that cannot be multiplied by the per-slot mask).
        # The clamp keeps log() finite on padded slots and on exact zeros.
        loss = -sampled_target_policy * sampled_policy_probs.clamp_min(1e-8).log()

    # Average over the valid slots; the clamp avoids 0 / 0 when a batch has no valid slot at all.
    loss = loss * mask_float
    return loss.sum() / mask_float.sum().clamp_min(1.0)
def _init_collect(self) -> None:
    """
    Overview:
        Collect mode init method. Called by ``self.__init__``. Creates the simulation env, binds the
        collect model to the policy network, builds the MCTS used for collection (C++ or Python
        implementation, depending on ``cfg.mcts_ctree``) and resets the collect temperature to 1.
    """
    self._get_simulation_env()
    self._collect_model = self._model

    mcts_cfg = self._cfg.mcts
    if not self._cfg.mcts_ctree:
        # Python tree: the sampled variant is chosen by ``cfg.sampled_algo``.
        if self._cfg.sampled_algo:
            from lzero.mcts.ptree.ptree_az_sampled import MCTS
        else:
            from lzero.mcts.ptree.ptree_az import MCTS
        self._collect_mcts = MCTS(mcts_cfg, self.simulate_env)
    else:
        # C++ tree: the compiled extension is looked up in the build directory appended to sys.path.
        import sys
        sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')
        import mcts_alphazero
        self._collect_mcts = mcts_alphazero.MCTS(
            mcts_cfg.max_moves,
            mcts_cfg.num_simulations,
            mcts_cfg.pb_c_base,
            mcts_cfg.pb_c_init,
            mcts_cfg.root_dirichlet_alpha,
            mcts_cfg.root_noise_weight,
            self.simulate_env,
        )

    self.collect_mcts_temperature = 1
@torch.no_grad()
def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch.Tensor]:
    """
    Overview:
        The forward function for collecting data in collect mode. Use real env to execute MCTS search.
    Arguments:
        - obs (:obj:`Dict`): The dict of obs, the key is env_id and the value is the \
            corresponding obs in this timestep.
        - temperature (:obj:`float`): The temperature for MCTS search.
    Returns:
        - output (:obj:`Dict[str, torch.Tensor]`): The dict of output, the key is env_id and the value is the \
            the corresponding policy output in this timestep, including ``action``, ``probs`` and \
            ``root_sampled_actions``.
    """
    self.collect_mcts_temperature = temperature
    ready_env_id = list(obs.keys())
    init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
    try:
        katago_game_state = {env_id: obs[env_id]['katago_game_state'] for env_id in ready_env_id}
    except (LookupError, TypeError):
        # Observations without a KataGo state: fall back to None for every env (best effort).
        katago_game_state = {env_id: None for env_id in ready_env_id}
    start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}

    # The collect model is the same network object that learn mode switches to train mode,
    # so put it into eval mode before it is used for inference inside the search.
    self._collect_model.eval()
    # ``_policy_value_func`` evaluates whatever model is bound here.
    self._policy_model = self._collect_model

    output = {}
    for env_id in ready_env_id:
        state_config_for_env_reset = EasyDict(
            dict(
                start_player_index=start_player_index[env_id],
                init_state=init_state[env_id],
                katago_policy_init=True,
                katago_game_state=katago_game_state[env_id],
            )
        )
        action, mcts_visit_count_probs = self._collect_mcts.get_next_action(
            state_config_for_env_reset,
            self._policy_value_func,
            self.collect_mcts_temperature,
            True,
        )
        output[env_id] = {
            'action': action,
            'probs': mcts_visit_count_probs,
            'root_sampled_actions': self._collect_mcts.get_sampled_actions(),
        }
    return output
def _init_eval(self) -> None:
    """
    Overview:
        Evaluate mode init method. Called by ``self.__init__``. Creates the simulation env, builds the
        MCTS used for evaluation with ``min(800, 4 * num_simulations)`` simulations, and binds the
        eval model to the policy network.
    """
    self._get_simulation_env()

    if not self._cfg.mcts_ctree:
        # Python tree: the sampled variant is chosen by ``cfg.sampled_algo``.
        if self._cfg.sampled_algo:
            from lzero.mcts.ptree.ptree_az_sampled import MCTS
        else:
            from lzero.mcts.ptree.ptree_az import MCTS
        # Work on a copy so the collect-time MCTS config keeps its own simulation budget.
        mcts_eval_config = copy.deepcopy(self._cfg.mcts)
        # TODO(pu): how to set proper num_simulations for evaluation
        mcts_eval_config.num_simulations = min(800, mcts_eval_config.num_simulations * 4)
        self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env)
    else:
        # C++ tree: the compiled extension is looked up in the build directory appended to sys.path.
        import sys
        sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')
        import mcts_alphazero
        mcts_cfg = self._cfg.mcts
        # TODO(pu): how to set proper num_simulations for evaluation
        self._eval_mcts = mcts_alphazero.MCTS(
            mcts_cfg.max_moves,
            min(800, mcts_cfg.num_simulations * 4),
            mcts_cfg.pb_c_base,
            mcts_cfg.pb_c_init,
            mcts_cfg.root_dirichlet_alpha,
            mcts_cfg.root_noise_weight,
            self.simulate_env,
        )

    self._eval_model = self._model
@torch.no_grad()
def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
    """
    Overview:
        The forward function for evaluating the current policy in eval mode, similar to ``self._forward_collect``.
        The search is run with temperature 1.0 and with the last ``get_next_action`` flag set to False.
    Arguments:
        - obs (:obj:`Dict`): The dict of obs, the key is env_id and the value is the \
            corresponding obs in this timestep.
    Returns:
        - output (:obj:`Dict[str, torch.Tensor]`): The dict of output, the key is env_id and the value is the \
            the corresponding policy output in this timestep, including ``action`` and ``probs``.
    """
    ready_env_id = list(obs.keys())
    init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
    try:
        katago_game_state = {env_id: obs[env_id]['katago_game_state'] for env_id in ready_env_id}
    except (LookupError, TypeError):
        # Observations without a KataGo state: fall back to None for every env (best effort).
        katago_game_state = {env_id: None for env_id in ready_env_id}
    start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}

    # The eval model is the same network object that learn mode switches to train mode,
    # so put it into eval mode before it is used for inference inside the search.
    self._eval_model.eval()
    # ``_policy_value_func`` evaluates whatever model is bound here.
    self._policy_model = self._eval_model

    output = {}
    for env_id in ready_env_id:
        state_config_for_env_reset = EasyDict(
            dict(
                start_player_index=start_player_index[env_id],
                init_state=init_state[env_id],
                katago_policy_init=False,
                katago_game_state=katago_game_state[env_id],
            )
        )
        action, mcts_visit_count_probs = self._eval_mcts.get_next_action(
            state_config_for_env_reset,
            self._policy_value_func,
            1.0,
            False,
        )
        output[env_id] = {
            'action': action,
            'probs': mcts_visit_count_probs,
        }
    return output
def _get_simulation_env(self):
    """
    Overview:
        Build the simulation environment used by MCTS and store it in ``self.simulate_env``.
        The env modules listed in ``cfg.create_cfg.env.import_names`` are imported first so that they
        register themselves, then the class stored in ``ENV_REGISTRY`` under ``cfg.simulation_env_id``
        is instantiated with ``cfg.full_cfg.env``.
    """
    from ding.utils import import_module, ENV_REGISTRY

    # Importing the env modules triggers their registration in ENV_REGISTRY.
    env_module_names = self._cfg.create_cfg.env.get('import_names', [])
    import_module(env_module_names)

    # Look the env class up by its id and build it from the env section of the full config.
    simulation_env_cls = ENV_REGISTRY.get(self._cfg.simulation_env_id)
    self.simulate_env = simulation_env_cls(self._cfg.full_cfg.env)
@torch.no_grad()
def _policy_value_func(self, environment: 'Environment') -> Tuple[Dict[int, np.ndarray], float]:
    """
    Overview:
        Evaluate the currently bound policy model (``self._policy_model``) on the state of ``environment``
        and return the prior over its legal actions together with the state value.
    Arguments:
        - environment (:obj:`Environment`): The env to evaluate; it provides ``legal_actions`` and \
            ``current_state()``, whose second return value is fed to the model.
    Returns:
        - legal_action_probabilities (:obj:`Dict`): The mapping from each legal action to its probability.
        - state_value (:obj:`float`): The value predicted for the state.
    """
    legal_actions = environment.legal_actions
    # Only the second element (the scaled state) of ``current_state()`` is consumed by the model.
    _, state_scale = environment.current_state()

    # numpy -> float tensor on the policy device, with a leading batch dimension of size 1.
    model_input = torch.from_numpy(state_scale)
    model_input = model_input.to(device=self._device, dtype=torch.float)
    model_input = model_input.unsqueeze(0)

    # Gradients are already disabled by the decorator.
    action_probabilities, state_value = self._policy_model.compute_policy_value(model_input)

    # Drop the batch dimension and keep only the entries of the legal actions.
    legal_probs = action_probabilities.squeeze(0)[legal_actions]
    legal_probs = legal_probs.detach().cpu().numpy()
    legal_action_probabilities = dict(zip(legal_actions, legal_probs))

    return legal_action_probabilities, state_value.item()
def _monitor_vars_learn(self) -> List[str]:
    """
    Overview:
        Register the variables to be monitored in learn mode. The names returned by the parent class
        are extended with the keys of the dict returned by ``_forward_learn``; the registered
        variables will be logged in tensorboard.
    """
    learn_info_keys = [
        'cur_lr',
        'total_loss',
        'policy_loss',
        'value_loss',
        'policy_entropy_loss',
        'total_grad_norm_before_clip',
        'collect_mcts_temperature',
    ]
    return super()._monitor_vars_learn() + learn_info_keys
def _process_transition(self, obs: Dict, model_output: Dict[str, torch.Tensor], timestep: namedtuple) -> Dict:
    """
    Overview:
        Generate the dict type transition (one timestep) data from policy learning.
        ``katago_game_state`` is removed from ``obs`` in place (if present) before it is stored.
    Arguments:
        - obs (:obj:`Dict`): The observation the action was computed from.
        - model_output (:obj:`Dict[str, torch.Tensor]`): The policy output; ``action``, \
            ``root_sampled_actions`` and ``probs`` are read.
        - timestep (:obj:`namedtuple`): The env step result; ``obs``, ``reward`` and ``done`` are read.
    Returns:
        - transition (:obj:`Dict`): The transition with the keys ``obs``, ``next_obs``, ``action``, \
            ``root_sampled_actions``, ``probs``, ``reward`` and ``done``.
    """
    if 'katago_game_state' in obs:
        del obs['katago_game_state']
    # Note: used in _foward_collect in alphazero_collector now
    transition = dict(obs=obs, next_obs=timestep.obs)
    for key in ('action', 'root_sampled_actions', 'probs'):
        transition[key] = model_output[key]
    transition['reward'] = timestep.reward
    transition['done'] = timestep.done
    return transition
def _get_train_sample(self, data):
    """
    Overview:
        No-op that only exists to be compatible with the DI-engine ``Policy`` class.
        ``data`` is ignored and ``None`` is returned.
    """
    return None