Source code for lzero.worker.muzero_evaluator

import copy
import time
from collections import namedtuple
from typing import Optional, Callable, Tuple, Dict, Any

import numpy as np
import torch
from ding.envs import BaseEnvManager
from ding.torch_utils import to_ndarray, to_item, to_tensor
from ding.utils import build_logger, EasyTimer
from ding.utils import get_world_size, get_rank, broadcast_object_list
from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor
from easydict import EasyDict

from lzero.mcts.buffer.game_segment import GameSegment
from lzero.mcts.utils import prepare_observation


class MuZeroEvaluator(ISerialEvaluator):
    """
    Overview:
        The Evaluator class for MCTS+RL algorithms, such as MuZero, EfficientZero, and Sampled EfficientZero.
    Interfaces:
        __init__, reset, reset_policy, reset_env, close, should_eval, eval
    Properties:
        env, policy
    """
    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Retrieve the default configuration for the evaluator by merging evaluator-specific defaults with other
            defaults and any user-provided configuration.
        Returns:
            - cfg (:obj:`EasyDict`): The default configuration for the evaluator.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg
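    # Example (a minimal sketch): ``default_config()`` deep-copies the class-level ``config``
    # defined just below and tags it with the class name, so a call like
    #     cfg = MuZeroEvaluator.default_config()
    # yields ``cfg.eval_freq == 50`` and ``cfg.cfg_type == 'MuZeroEvaluatorDict'``.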
    config = dict(
        # Evaluate every "eval_freq" training iterations.
        eval_freq=50,
    )
    def __init__(
            self,
            eval_freq: int = 1000,
            n_evaluator_episode: int = 3,
            stop_value: int = 1e6,
            env: BaseEnvManager = None,
            policy: namedtuple = None,
            tb_logger: 'SummaryWriter' = None,  # noqa
            exp_name: Optional[str] = 'default_experiment',
            instance_name: Optional[str] = 'evaluator',
            policy_config: 'policy_config' = None,  # noqa
    ) -> None:
        """
        Overview:
            Initialize the evaluator with configuration settings for various components such as the logger helper
            and timer.
        Arguments:
            - eval_freq (:obj:`int`): Evaluation frequency in terms of training steps.
            - n_evaluator_episode (:obj:`int`): Number of episodes to evaluate in total.
            - stop_value (:obj:`float`): A reward threshold above which the training is considered converged.
            - env (:obj:`Optional[BaseEnvManager]`): An optional instance of a subclass of BaseEnvManager.
            - policy (:obj:`Optional[namedtuple]`): An optional API namedtuple defining the policy for evaluation.
            - tb_logger (:obj:`Optional[SummaryWriter]`): Optional TensorBoard logger instance.
            - exp_name (:obj:`str`): Name of the experiment, used to determine the output directory.
            - instance_name (:obj:`str`): Name of this evaluator instance.
            - policy_config (:obj:`Optional[dict]`): Optional configuration for the game policy.
        """
        self._eval_freq = eval_freq
        self._exp_name = exp_name
        self._instance_name = instance_name

        # Logger (the monitor will be initialized in the policy setter).
        # Only the rank 0 learner needs a monitor and tb_logger; other ranks only need a text logger for terminal output.
        if get_rank() == 0:
            if tb_logger is not None:
                self._logger, _ = build_logger(
                    './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
                )
                self._tb_logger = tb_logger
            else:
                self._logger, self._tb_logger = build_logger(
                    './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name
                )
        else:
            self._logger, self._tb_logger = None, None  # so that close() works elegantly on non-zero ranks

        self.reset(policy, env)

        self._timer = EasyTimer()
        self._default_n_episode = n_evaluator_episode
        self._stop_value = stop_value

        # ==============================================================
        # MCTS+RL related core code
        # ==============================================================
        self.policy_config = policy_config
    def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset the environment for the evaluator, optionally replacing it with a new one.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the evaluator with the newly passed-in environment
            and launch it.
        Arguments:
            - _env (:obj:`Optional[BaseEnvManager]`): An optional new environment instance to replace the existing one.
        """
        if _env is not None:
            self._env = _env
            self._env.launch()
            self._env_num = self._env.env_num
        else:
            self._env.reset()
    def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
        """
        Overview:
            Reset the policy for the evaluator, optionally replacing it with a new one.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the evaluator with the newly passed-in policy.
        Arguments:
            - _policy (:obj:`Optional[namedtuple]`): An optional new policy namedtuple to replace the existing one.
        """
        assert hasattr(self, '_env'), "please set env first"
        if _policy is not None:
            self._policy = _policy
        self._policy.reset()
    def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
        """
        Overview:
            Reset both the policy and the environment for the evaluator, optionally replacing them.
            If _env is None, reset the old environment.
            If _env is not None, replace the old environment in the evaluator with the newly passed-in environment
            and launch it.
            If _policy is None, reset the old policy.
            If _policy is not None, replace the old policy in the evaluator with the newly passed-in policy.
        Arguments:
            - _policy (:obj:`Optional[namedtuple]`): An optional new policy namedtuple to replace the existing one.
            - _env (:obj:`Optional[BaseEnvManager]`): An optional new environment instance to replace the existing one.
        """
        if _env is not None:
            self.reset_env(_env)
        if _policy is not None:
            self.reset_policy(_policy)
        self._max_episode_return = float("-inf")
        self._last_eval_iter = 0
        self._end_flag = False
    def close(self) -> None:
        """
        Overview:
            Close the evaluator and the environment, and flush and close the TensorBoard logger if applicable.
        """
        if self._end_flag:
            return
        self._end_flag = True
        self._env.close()
        if self._tb_logger:
            self._tb_logger.flush()
            self._tb_logger.close()
    def __del__(self):
        """
        Overview:
            Execute the close command and close the evaluator. __del__ is called automatically to destroy the
            evaluator instance when the evaluator finishes its work.
        """
        self.close()
    def should_eval(self, train_iter: int) -> bool:
        """
        Overview:
            Determine whether to initiate evaluation based on the training iteration count and the evaluation
            frequency.
        Arguments:
            - train_iter (:obj:`int`): The current count of training iterations.
        Returns:
            - (:obj:`bool`): `True` if evaluation should be initiated, otherwise `False`.
        """
        if train_iter == self._last_eval_iter:
            return False
        if (train_iter - self._last_eval_iter) < self._eval_freq and train_iter != 0:
            return False
        self._last_eval_iter = train_iter
        return True
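    # Worked example (a sketch of the gating logic above): with ``eval_freq=50`` and a freshly
    # reset evaluator (``_last_eval_iter == 0``),
    #     should_eval(0)   -> False  (train_iter equals the last evaluated iteration)
    #     should_eval(30)  -> False  (fewer than eval_freq iterations have elapsed)
    #     should_eval(50)  -> True   (and _last_eval_iter is updated to 50)
    #     should_eval(75)  -> False
    #     should_eval(100) -> True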
    def eval(
            self,
            save_ckpt_fn: Callable = None,
            train_iter: int = -1,
            envstep: int = -1,
            n_episode: Optional[int] = None,
            return_trajectory: bool = False,
    ) -> Tuple[bool, Dict[str, Any]]:
        """
        Overview:
            Evaluate the current policy, storing the best policy if it achieves the highest historical reward.
        Arguments:
            - save_ckpt_fn (:obj:`Optional[Callable]`): Optional function to save a checkpoint when a new best reward is achieved.
            - train_iter (:obj:`int`): The current training iteration count.
            - envstep (:obj:`int`): The current environment step count.
            - n_episode (:obj:`Optional[int]`): Optional number of evaluation episodes; defaults to the evaluator's setting.
            - return_trajectory (:obj:`bool`): Return the evaluated trajectory `game_segments` in `episode_info` if True.
        Returns:
            - stop_flag (:obj:`bool`): Indicates whether the training can be stopped based on the stop value.
            - episode_info (:obj:`Dict[str, Any]`): A dictionary containing information about the evaluation episodes.
        """
        # The evaluator only works on rank 0.
        episode_info = None
        stop_flag = False
        if get_rank() == 0:
            if n_episode is None:
                n_episode = self._default_n_episode
            assert n_episode is not None, "please indicate eval n_episode"
            envstep_count = 0
            eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode)
            env_nums = self._env.env_num

            self._env.reset()
            self._policy.reset()

            # initializations
            init_obs = self._env.ready_obs

            retry_waiting_time = 0.001
            while len(init_obs.keys()) != self._env_num:
                # To be compatible with subprocess env_manager, in which self._env_num is sometimes not equal to
                # len(self._env.ready_obs), especially in the tictactoe env.
                self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys()))
                self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states))
                time.sleep(retry_waiting_time)
                self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10)
                self._logger.info(
                    'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states)
                )
                init_obs = self._env.ready_obs

            action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)}
            to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)}
            dones = np.array([False for _ in range(env_nums)])

            game_segments = [
                GameSegment(
                    self._env.action_space,
                    game_segment_length=self.policy_config.game_segment_length,
                    config=self.policy_config
                ) for _ in range(env_nums)
            ]
            for i in range(env_nums):
                game_segments[i].reset(
                    [to_ndarray(init_obs[i]['observation']) for _ in range(self.policy_config.model.frame_stack_num)]
                )

            ready_env_id = set()
            remain_episode = n_episode
            eps_steps_lst = np.zeros(env_nums)

            with self._timer:
                while not eval_monitor.is_finished():
                    # Get the current ready env obs.
                    obs = self._env.ready_obs
                    new_available_env_id = set(obs.keys()).difference(ready_env_id)
                    ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
                    remain_episode -= min(len(new_available_env_id), remain_episode)

                    stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id}
                    stack_obs = list(stack_obs.values())

                    action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id}
                    to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id}
                    action_mask = [action_mask_dict[env_id] for env_id in ready_env_id]
                    to_play = [to_play_dict[env_id] for env_id in ready_env_id]

                    stack_obs = to_ndarray(stack_obs)
                    stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type)
                    stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float()

                    # ==============================================================
                    # policy forward
                    # ==============================================================
                    policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id)

                    actions_with_env_id = {k: v['action'] for k, v in policy_output.items()}
                    distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()}
                    if self.policy_config.sampled_algo:
                        root_sampled_actions_dict_with_env_id = {
                            k: v['root_sampled_actions'] for k, v in policy_output.items()
                        }
                    value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()}
                    pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()}
                    visit_entropy_dict_with_env_id = {
                        k: v['visit_count_distribution_entropy'] for k, v in policy_output.items()
                    }

                    actions = {}
                    distributions_dict = {}
                    if self.policy_config.sampled_algo:
                        root_sampled_actions_dict = {}
                    value_dict = {}
                    pred_value_dict = {}
                    visit_entropy_dict = {}

                    for index, env_id in enumerate(ready_env_id):
                        actions[env_id] = actions_with_env_id.pop(env_id)
                        distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id)
                        if self.policy_config.sampled_algo:
                            root_sampled_actions_dict[env_id] = root_sampled_actions_dict_with_env_id.pop(env_id)
                        value_dict[env_id] = value_dict_with_env_id.pop(env_id)
                        pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id)
                        visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id)

                    # ==============================================================
                    # Interact with the env.
                    # ==============================================================
                    timesteps = self._env.step(actions)
                    timesteps = to_tensor(timesteps, dtype=torch.float32)
                    for env_id, t in timesteps.items():
                        obs, reward, done, info = t.obs, t.reward, t.done, t.info

                        eps_steps_lst[env_id] += 1
                        if self._policy.get_attribute('cfg').type == 'unizero':
                            # only for UniZero now
                            self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False)

                        game_segments[env_id].append(
                            actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
                            to_play_dict[env_id]
                        )

                        # NOTE: the position of this code snippet is very important;
                        # obs['action_mask'] and obs['to_play'] correspond to the next action.
                        action_mask_dict[env_id] = to_ndarray(obs['action_mask'])
                        to_play_dict[env_id] = to_ndarray(obs['to_play'])

                        dones[env_id] = done
                        if t.done:
                            # Env reset is done by env_manager automatically.
                            self._policy.reset([env_id])
                            reward = t.info['eval_episode_return']
                            saved_info = {'eval_episode_return': t.info['eval_episode_return']}
                            if 'episode_info' in t.info:
                                saved_info.update(t.info['episode_info'])
                            eval_monitor.update_info(env_id, saved_info)
                            eval_monitor.update_reward(env_id, reward)
                            self._logger.info(
                                "[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
                                    env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
                                )
                            )

                            # Reset the finished env and re-init its game_segment.
                            if n_episode > self._env_num:
                                # Get the current ready env obs.
                                init_obs = self._env.ready_obs
                                retry_waiting_time = 0.001
                                while len(init_obs.keys()) != self._env_num:
                                    # To be compatible with subprocess env_manager, in which self._env_num is
                                    # sometimes not equal to len(self._env.ready_obs), especially in the tictactoe env.
                                    self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys()))
                                    self._logger.info(
                                        'Before sleeping, the _env_states is {}'.format(self._env._env_states)
                                    )
                                    time.sleep(retry_waiting_time)
                                    self._logger.info(
                                        '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10
                                    )
                                    self._logger.info(
                                        'After sleeping {}s, the current _env_states is {}'.format(
                                            retry_waiting_time, self._env._env_states
                                        )
                                    )
                                    init_obs = self._env.ready_obs

                                new_available_env_id = set(init_obs.keys()).difference(ready_env_id)
                                ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
                                remain_episode -= min(len(new_available_env_id), remain_episode)

                                action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask'])
                                to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play'])

                                game_segments[env_id] = GameSegment(
                                    self._env.action_space,
                                    game_segment_length=self.policy_config.game_segment_length,
                                    config=self.policy_config
                                )
                                game_segments[env_id].reset(
                                    [
                                        init_obs[env_id]['observation']
                                        for _ in range(self.policy_config.model.frame_stack_num)
                                    ]
                                )

                            eps_steps_lst[env_id] = 0

                            # Env reset is done by env_manager automatically.
                            self._policy.reset([env_id])  # NOTE: reset the policy for this env_id. Default reset_init_data=True.
                            ready_env_id.remove(env_id)

                        envstep_count += 1

            duration = self._timer.value
            episode_return = eval_monitor.get_episode_return()
            info = {
                'train_iter': train_iter,
                'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
                'episode_count': n_episode,
                'envstep_count': envstep_count,
                'avg_envstep_per_episode': envstep_count / n_episode,
                'evaluate_time': duration,
                'avg_envstep_per_sec': envstep_count / duration,
                'avg_time_per_episode': n_episode / duration,
                'reward_mean': np.mean(episode_return),
                'reward_std': np.std(episode_return),
                'reward_max': np.max(episode_return),
                'reward_min': np.min(episode_return),
                # 'each_reward': episode_return,
            }
            episode_info = eval_monitor.get_episode_info()
            if episode_info is not None:
                info.update(episode_info)
            self._logger.info(self._logger.get_tabulate_vars_hor(info))
            for k, v in info.items():
                if k in ['train_iter', 'ckpt_name', 'each_reward']:
                    continue
                if not np.isscalar(v):
                    continue
                self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
                self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)

            episode_return = np.mean(episode_return)
            if episode_return > self._max_episode_return:
                if save_ckpt_fn:
                    save_ckpt_fn('ckpt_best.pth.tar')
                self._max_episode_return = episode_return
            stop_flag = episode_return >= self._stop_value and train_iter > 0
            if stop_flag:
                self._logger.info(
                    "[LightZero serial pipeline] " +
                    "Current episode_return: {} is greater than stop_value: {}".format(episode_return, self._stop_value) +
                    ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details."
                )

        if get_world_size() > 1:
            objects = [stop_flag, episode_info]
            broadcast_object_list(objects, src=0)
            stop_flag, episode_info = objects

        episode_info = to_item(episode_info)
        if return_trajectory:
            episode_info['trajectory'] = game_segments
        return stop_flag, episode_info
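# A minimal wiring sketch (illustrative; names such as ``cfg``, ``evaluator_env``, ``policy``,
# ``tb_logger``, ``learner`` and ``collector`` are assumed to be created by the surrounding
# training pipeline, e.g. lzero.entry.train_muzero):
#
#     evaluator = MuZeroEvaluator(
#         eval_freq=cfg.policy.eval_freq,
#         n_evaluator_episode=cfg.env.n_evaluator_episode,
#         stop_value=cfg.env.stop_value,
#         env=evaluator_env,
#         policy=policy.eval_mode,
#         tb_logger=tb_logger,
#         exp_name=cfg.exp_name,
#         policy_config=cfg.policy,
#     )
#     while True:
#         if evaluator.should_eval(learner.train_iter):
#             stop, eval_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
#             if stop:
#                 break
#         # ... collect data and train ...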