Source code for ding.worker.coordinator.comm_coordinator
import traceback
import time
import sys
import requests
from typing import Dict, Callable
from threading import Thread
from ding.utils import LockContext, LockContextType, get_operator_server_kwargs
from ding.interaction import Master
from ding.interaction.master.task import TaskStatus
from .resource_manager import NaiveResourceManager
from .operator_server import OperatorServer
class CommCoordinator(object):
r"""
Overview:
        The communication part of the coordinator (the coordinator interactor).
    Interface:
        __init__, start, close, __del__, send_collector_task, send_learner_task
"""
    def __init__(self, cfg: dict, callback_fn: Dict[str, Callable], logger: 'logging.Logger') -> None:  # noqa
r"""
Overview:
            Initialize the communication coordinator.
        Arguments:
            - cfg (:obj:`dict`): The config of the communication coordinator.
            - callback_fn (:obj:`Dict[str, Callable]`): The callback functions given by the coordinator.
- logger (:obj:`logging.Logger`): The text logger.
"""
self._cfg = cfg
self._callback_fn = callback_fn
self._logger = logger
self._max_retry_second = 120
self._end_flag = True
self._connection_collector = {}
self._connection_learner = {}
self._resource_manager = NaiveResourceManager()
self._remain_task_lock = LockContext(LockContextType.THREAD_LOCK)
self._remain_collector_task = set()
self._remain_learner_task = set()
if self._cfg.operator_server:
server_kwargs = get_operator_server_kwargs(self._cfg.operator_server)
self._operator_server = OperatorServer(**server_kwargs)
self._operator_server.set_worker_type('coordinator')
self._collector_target_num = self._cfg.operator_server.collector_target_num
self._learner_target_num = self._cfg.operator_server.learner_target_num
else:
self._operator_server = None
        # lock for updating resource info
self._resource_lock = LockContext(LockContextType.THREAD_LOCK)
        # failed connections (reported to the operator server during the periodic sync)
self._failed_learner_conn = set()
self._failed_collector_conn = set()
    def start(self) -> None:
r"""
Overview:
            Start the coordinator interactor: build connections to collectors/learners and manage resources.
"""
self._end_flag = False
self._master = Master(self._cfg.host, self._cfg.port)
self._master.start()
self._master.ping()
# new connection from config
for _, (learner_id, learner_host, learner_port) in self._cfg.learner.items():
self._new_connection_learner(learner_id, learner_host, learner_port)
for _, (collector_id, collector_host, collector_port) in self._cfg.collector.items():
self._new_connection_collector(collector_id, collector_host, collector_port)
if self._operator_server:
            # post the initial learner/collector replica demand
start_time, init_flag = time.time(), False
while time.time() - start_time <= self._max_retry_second and not self._end_flag:
success, _, message, _ = self._operator_server.post_replicas(
self._cfg.operator_server.init_replicas_request
)
if success:
self._logger.info("Post replicas demand to server successfully")
init_flag = True
break
else:
self._logger.info("Failed to post replicas request to server, message: {}".format(message))
time.sleep(2)
if not init_flag:
                self._logger.info('Exit since replicas cannot be requested from the operator-server...')
self.close()
sys.exit(1)
# create sync learner/collector thread
self._period_sync_with_server_thread = Thread(
target=self._period_sync_with_server, name="period_sync", daemon=True
)
self._period_sync_with_server_thread.start()
# wait for enough collector/learner
start_time = time.time()
enough_flag = False
while time.time() - start_time <= self._max_retry_second:
                if len(self._connection_collector) < self._collector_target_num or \
                        len(self._connection_learner) < self._learner_target_num:
self._logger.info(
"Only can connect {} collectors, {} learners.".format(
len(self._connection_collector), len(self._connection_learner)
)
)
time.sleep(2)
else:
self._logger.info(
"Have connected {} collectors, {} learners, match limit requests.".format(
len(self._connection_collector), len(self._connection_learner)
)
)
self._logger.info("Total DI-engine pipeline start...")
enough_flag = True
break
if not enough_flag:
self._logger.error(
"Exit since only can connect {} collectors, {} learners.".format(
len(self._connection_collector), len(self._connection_learner)
)
)
self.close()
sys.exit(1)
if self._end_flag:
self._logger.error("connection max retries failed")
sys.exit(1)
def _new_connection_collector(
self,
collector_id: str,
collector_host: str,
collector_port: int,
increase_task_space: bool = False,
) -> None:
start_time = time.time()
conn = None
while time.time() - start_time <= self._max_retry_second and not self._end_flag:
try:
if conn is None or not conn.is_connected:
conn = self._master.new_connection(collector_id, collector_host, collector_port)
conn.connect()
assert conn.is_connected
resource_task = self._get_resource(conn)
if resource_task.status != TaskStatus.COMPLETED:
self._logger.error("can't acquire resource for collector({})".format(collector_id))
continue
else:
with self._resource_lock:
self._resource_manager.update('collector', collector_id, resource_task.result)
self._connection_collector[collector_id] = conn
if increase_task_space:
self._callback_fn['deal_with_increase_collector']()
break
except Exception as e:
self._logger.error(
f"Collector({collector_id}) connection start error:\n" +
''.join(traceback.format_tb(e.__traceback__)) + repr(e) + '\nAuto Retry...'
)
time.sleep(2)
if collector_id in self._connection_collector:
self._logger.info(f"Succeed to connect to collector({collector_id})")
else:
self._logger.info(f"Fail to connect to collector({collector_id})")
self._failed_collector_conn.add(collector_id)
def _new_connection_learner(self, learner_id: str, learner_host: str, learner_port: int) -> None:
start_time = time.time()
conn = None
while time.time() - start_time <= self._max_retry_second and not self._end_flag:
try:
if conn is None or not conn.is_connected:
conn = self._master.new_connection(learner_id, learner_host, learner_port)
conn.connect()
assert conn.is_connected
resource_task = self._get_resource(conn)
if resource_task.status != TaskStatus.COMPLETED:
self._logger.error("can't acquire resource for learner({})".format(learner_id))
continue
else:
with self._resource_lock:
self._resource_manager.update('learner', learner_id, resource_task.result)
self._connection_learner[learner_id] = conn
break
except Exception as e:
self._logger.error(
f"learner({learner_id}) connection start error:\n" + ''.join(traceback.format_tb(e.__traceback__)) +
repr(e) + '\nAuto Retry...'
)
time.sleep(2)
if learner_id in self._connection_learner:
self._logger.info(f"Succeed to connect to learner({learner_id})")
else:
self._logger.info(f"Fail to connect to learner({learner_id})")
self._failed_learner_conn.add(learner_id)
    def close(self) -> None:
r"""
Overview:
            Close the coordinator interactor, wait for remaining tasks, and disconnect all collectors and learners.
"""
if self._end_flag:
return
self._end_flag = True
        # wait for executing task threads to finish
start_time = time.time()
# TODO
if self._operator_server:
self._period_sync_with_server_thread.join()
            # wait for all slaves to receive DELETE
time.sleep(5)
while time.time() - start_time <= 60:
if len(self._remain_learner_task) == 0 and len(self._remain_collector_task) == 0:
break
else:
time.sleep(1)
for collector_id, conn in self._connection_collector.items():
conn.disconnect()
assert not conn.is_connected
for learner_id, conn in self._connection_learner.items():
conn.disconnect()
assert not conn.is_connected
self._master.close()
    def __del__(self) -> None:
r"""
Overview:
__del__ method will close the coordinator interactor
"""
self.close()
def _get_resource(self, conn: 'Connection') -> 'TaskResult': # noqa
r"""
Overview:
            Get the resource info of the remote worker through the given connection.
        Arguments:
            - conn (:obj:`Connection`): The connection used to launch the resource task.
        Returns:
            - resource_task (:obj:`TaskResult`): The finished resource task, whose ``result`` carries the resource info.
        """
resource_task = conn.new_task({'name': 'resource'})
resource_task.start().join()
return resource_task
    def send_collector_task(self, collector_task: dict) -> bool:
r"""
Overview:
            Assign the collector task to an available collector and execute it in a background thread.
        Arguments:
            - collector_task (:obj:`dict`): The collector task to send.
        Returns:
            - success (:obj:`bool`): Whether the task is successfully assigned and started.
        """
# assert not self._end_flag, "please start interaction first"
task_id = collector_task['task_id']
# according to resource info, assign task to a specific collector and adapt task
assigned_collector = self._resource_manager.assign_collector(collector_task)
if assigned_collector is None:
self._logger.error("collector task({}) doesn't have enough collector to execute".format(task_id))
return False
collector_task.update(assigned_collector)
collector_id = collector_task['collector_id']
start_task = self._connection_collector[collector_id].new_task(
{
'name': 'collector_start_task',
'task_info': collector_task
}
)
start_task.start().join()
if start_task.status != TaskStatus.COMPLETED:
self._resource_manager.update(
'collector', assigned_collector['collector_id'], assigned_collector['resource_info']
)
self._logger.error('collector_task({}) start failed: {}'.format(task_id, start_task.result))
return False
else:
self._logger.info('collector task({}) is assigned to collector({})'.format(task_id, collector_id))
with self._remain_task_lock:
self._remain_collector_task.add(task_id)
collector_task_thread = Thread(
target=self._execute_collector_task, args=(collector_task, ), name='coordinator_collector_task'
)
collector_task_thread.start()
return True
def _execute_collector_task(self, collector_task: dict) -> None:
r"""
Overview:
execute the collector task
Arguments:
- collector_task (:obj:`dict`): the collector task to execute
"""
close_flag = False
collector_id = collector_task['collector_id']
while not self._end_flag:
try:
# data task
data_task = self._connection_collector[collector_id].new_task({'name': 'collector_data_task'})
self._logger.info('collector data task begin')
data_task.start().join()
self._logger.info('collector data task end')
if data_task.status != TaskStatus.COMPLETED:
# TODO(deal with fail task)
                    self._logger.error('collector data task failed')
continue
result = data_task.result
task_id = result.get('task_id', None)
# data result
if 'data_id' in result:
buffer_id = result.get('buffer_id', None)
data_id = result.get('data_id', None)
self._callback_fn['deal_with_collector_send_data'](task_id, buffer_id, data_id, result)
# info result
else:
is_finished = self._callback_fn['deal_with_collector_judge_finish'](task_id, result)
if not is_finished:
continue
# close task
self._logger.error('close_task: {}\n{}'.format(task_id, result))
close_task = self._connection_collector[collector_id].new_task({'name': 'collector_close_task'})
close_task.start().join()
if close_task.status != TaskStatus.COMPLETED:
# TODO(deal with fail task)
                    self._logger.error('collector close_task failed')
break
result = close_task.result
task_id = result.get('task_id', None)
self._callback_fn['deal_with_collector_finish_task'](task_id, result)
resource_task = self._get_resource(self._connection_collector[collector_id])
if resource_task.status == TaskStatus.COMPLETED:
self._resource_manager.update('collector', collector_id, resource_task.result)
close_flag = True
break
except requests.exceptions.HTTPError as e:
if self._end_flag:
break
else:
raise e
if not close_flag:
close_task = self._connection_collector[collector_id].new_task({'name': 'collector_close_task'})
close_task.start().join()
with self._remain_task_lock:
self._remain_collector_task.remove(task_id)
    def send_learner_task(self, learner_task: dict) -> bool:
r"""
Overview:
            Assign the learner task to an available learner and execute it in a background thread.
        Arguments:
            - learner_task (:obj:`dict`): The learner task to send.
        Returns:
            - success (:obj:`bool`): Whether the task is successfully assigned and started.
        """
# assert not self._end_flag, "please start interaction first"
task_id = learner_task['task_id']
assigned_learner = self._resource_manager.assign_learner(learner_task)
if assigned_learner is None:
self._logger.error("learner task({}) doesn't have enough learner to execute".format(task_id))
return False
learner_task.update(assigned_learner)
learner_id = learner_task['learner_id']
start_task = self._connection_learner[learner_id].new_task(
{
'name': 'learner_start_task',
'task_info': learner_task
}
)
start_task.start().join()
if start_task.status != TaskStatus.COMPLETED:
self._resource_manager.update('learner', assigned_learner['learner_id'], assigned_learner['resource_info'])
            self._logger.error('learner_task({}) start failed: {}'.format(task_id, start_task.result))
return False
else:
self._logger.info('learner task({}) is assigned to learner({})'.format(task_id, learner_id))
with self._remain_task_lock:
self._remain_learner_task.add(task_id)
learner_task_thread = Thread(
target=self._execute_learner_task, args=(learner_task, ), name='coordinator_learner_task'
)
learner_task_thread.start()
return True
def _execute_learner_task(self, learner_task: dict) -> None:
r"""
Overview:
execute the learner task
Arguments:
- learner_task (:obj:`dict`): the learner task to execute
"""
close_flag = False
learner_id = learner_task['learner_id']
while not self._end_flag:
try:
# get data
get_data_task = self._connection_learner[learner_id].new_task({'name': 'learner_get_data_task'})
get_data_task.start().join()
if get_data_task.status != TaskStatus.COMPLETED:
# TODO(deal with fail task)
self._logger.error('learner get_data_task failed: {}'.format(get_data_task.result))
continue
result = get_data_task.result
task_id, buffer_id, batch_size = result['task_id'], result['buffer_id'], result['batch_size']
cur_learner_iter = result['cur_learner_iter']
sleep_count = 1
while True:
data = self._callback_fn['deal_with_learner_get_data'](
task_id, buffer_id, batch_size, cur_learner_iter
)
if self._end_flag or data is not None:
self._logger.info('sample result is ok')
break
else:
self._logger.info('sample result is None')
time.sleep(sleep_count)
sleep_count += 2
if self._end_flag:
break
# learn task
learn_task = self._connection_learner[learner_id].new_task({'name': 'learner_learn_task', 'data': data})
learn_task.start().join()
if learn_task.status != TaskStatus.COMPLETED:
# TODO(deal with fail task)
self._logger.error('learner learn_task failed: {}'.format(learn_task.result))
continue
result = learn_task.result
task_id, info = result['task_id'], result['info']
is_finished = self._callback_fn['deal_with_learner_judge_finish'](task_id, info)
if is_finished:
# close task and update resource
close_task = self._connection_learner[learner_id].new_task({'name': 'learner_close_task'})
close_task.start().join()
if close_task.status != TaskStatus.COMPLETED:
self._logger.error('learner close_task failed: {}'.format(close_task.result))
break
result = close_task.result
task_id = result.get('task_id', None)
self._callback_fn['deal_with_learner_finish_task'](task_id, result)
resource_task = self._get_resource(self._connection_learner[learner_id])
if resource_task.status == TaskStatus.COMPLETED:
self._resource_manager.update('learner', learner_id, resource_task.result)
close_flag = True
break
else:
# update info
buffer_id = result['buffer_id']
self._callback_fn['deal_with_learner_send_info'](task_id, buffer_id, info)
except requests.exceptions.HTTPError as e:
if self._end_flag:
break
else:
raise e
if not close_flag:
close_task = self._connection_learner[learner_id].new_task({'name': 'learner_close_task'})
close_task.start().join()
with self._remain_task_lock:
self._remain_learner_task.remove(task_id)
def _period_sync_with_server(self) -> None:
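        # Roughly once per second: report failed replica connections to the operator server
        # (post_replicas_failed) and pull the current replica list (get_replicas) to add or
        # remove collector/learner connections accordingly.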
while not self._end_flag:
            # First, send the failed list to notify the DI-engine server which replicas have failed,
            # so that the server can terminate those replicas.
# self._logger.info("failed list:", list(self._failed_collector_conn), list(self._failed_learner_conn))
if len(self._failed_learner_conn) > 0 or len(self._failed_collector_conn) > 0:
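                # Convert each failed replica address into a pod name: drop the ``:port`` suffix and the
                # last DNS component (the exact address format depends on the cluster deployment; something
                # like "<pod>.<subdomain>.<namespace>:<port>" is assumed here).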
collector_conn = []
for replica_conn in self._failed_collector_conn:
dns_name = replica_conn.split(":")[0]
pod_name_list = dns_name.split(".")[:-1]
pod_name = ".".join(pod_name_list)
collector_conn.append(pod_name)
learner_conn = []
for replica_conn in self._failed_learner_conn:
dns_name = replica_conn.split(":")[0]
pod_name_list = dns_name.split(".")[:-1]
pod_name = ".".join(pod_name_list)
learner_conn.append(pod_name)
success, _, message, _ = self._operator_server.post_replicas_failed(
learners=list(learner_conn), collectors=list(collector_conn)
)
if success:
# do not update collector or learner instantly, update at /GET replicas
self._failed_collector_conn.clear()
self._failed_learner_conn.clear()
else:
self._logger.error("Failed to send failed list to server, message: {}".format(message))
# get list from server
success, _, message, data = self._operator_server.get_replicas()
if success:
cur_collectors = data["collectors"]
cur_learners = data["learners"]
# self._logger.info("current list:", cur_collectors, cur_learners)
self._update_connection_collector(cur_collectors)
self._update_connection_learner(cur_learners)
else:
self._logger.error("Failed to sync with server, message: {}".format(message))
time.sleep(1)
def _update_connection_collector(self, cur_collectors: list) -> None:
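        # Diff the currently connected collectors against the list reported by the operator server:
        # connect to new replicas, and drop/disconnect replicas that the server no longer reports
        # (unless they are still pending in the failed-connection set).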
conn_collectors = list(self._connection_collector.keys())
new_c = set(cur_collectors) - set(conn_collectors)
del_c = set(conn_collectors) - (set(cur_collectors) | self._failed_collector_conn)
        # clear up connections that have already been terminated on the server side
self._failed_collector_conn = self._failed_collector_conn & set(cur_collectors)
# connect to each new collector
for collector_id in new_c:
collector_host, collector_port = collector_id.split(':')
self._new_connection_collector(collector_id, collector_host, int(collector_port), True)
for collector_id in del_c:
if collector_id in conn_collectors:
# TODO(nyz) whether to need to close task first
with self._resource_lock:
if not self._resource_manager.have_assigned('collector', collector_id):
self._resource_manager.delete("collector", collector_id)
if self._connection_collector[collector_id].is_connected:
conn = self._connection_collector.pop(collector_id)
conn.disconnect()
assert not conn.is_connected
self._callback_fn['deal_with_decrease_collector']()
else:
                # skip the explicit disconnect, since the pod will be terminated by the server;
                # just drop the connection
self._connection_collector.pop(collector_id)
def _update_connection_learner(self, cur_learners) -> None:
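        # Same diff-based sync as ``_update_connection_collector``, but for learner replicas.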
conn_learners = list(self._connection_learner.keys())
new_c = set(cur_learners) - set(conn_learners)
del_c = set(conn_learners) - (set(cur_learners) | self._failed_learner_conn)
        # clear up connections that have already been terminated on the server side
self._failed_learner_conn = self._failed_learner_conn & set(cur_learners)
# connect to each new learner
for learner_id in new_c:
learner_host, learner_port = learner_id.split(':')
self._new_connection_learner(learner_id, learner_host, int(learner_port))
for learner_id in del_c:
if learner_id in conn_learners:
# TODO(nyz) whether to need to close task first
with self._resource_lock:
if not self._resource_manager.have_assigned('learner', learner_id):
self._resource_manager.delete("learner", learner_id)
if self._connection_learner[learner_id].is_connected:
conn = self._connection_learner.pop(learner_id)
conn.disconnect()
assert not conn.is_connected
else:
                # skip the explicit disconnect, since the pod will be terminated by the server;
                # just drop the connection
self._connection_learner.pop(learner_id)