Shortcuts

Source code for ding.worker.collector.comm.flask_fs_collector

import os
import time
from typing import Union, Dict, Callable
from queue import Queue
from threading import Thread

from ding.utils import read_file, save_file, COMM_COLLECTOR_REGISTRY
from ding.utils.file_helper import save_to_di_store
from ding.interaction import Slave, TaskFail
from .base_comm_collector import BaseCommCollector


[docs]class CollectorSlave(Slave): """ Overview: A slave, whose master is coordinator. Used to pass message between comm collector and coordinator. Interfaces: __init__, _process_task """ # override
[docs] def __init__(self, *args, callback_fn: Dict[str, Callable], **kwargs) -> None: """ Overview: Init callback functions additionally. Callback functions are methods in comm collector. """ super().__init__(*args, **kwargs) self._callback_fn = callback_fn self._current_task_info = None
[docs] def _process_task(self, task: dict) -> Union[dict, TaskFail]: """ Overview: Process a task according to input task info dict, which is passed in by master coordinator. For each type of task, you can refer to corresponding callback function in comm collector for details. Arguments: - cfg (:obj:`EasyDict`): Task dict. Must contain key "name". Returns: - result (:obj:`Union[dict, TaskFail]`): Task result dict, or task fail exception. """ task_name = task['name'] if task_name == 'resource': return self._callback_fn['deal_with_resource']() elif task_name == 'collector_start_task': self._current_task_info = task['task_info'] self._callback_fn['deal_with_collector_start'](self._current_task_info) return {'message': 'collector task has started'} elif task_name == 'collector_data_task': data = self._callback_fn['deal_with_collector_data']() data['buffer_id'] = self._current_task_info['buffer_id'] data['task_id'] = self._current_task_info['task_id'] return data elif task_name == 'collector_close_task': data = self._callback_fn['deal_with_collector_close']() data['task_id'] = self._current_task_info['task_id'] return data else: raise TaskFail( result={'message': 'task name error'}, message='illegal collector task <{}>'.format(task_name) )
[docs]@COMM_COLLECTOR_REGISTRY.register('flask_fs') class FlaskFileSystemCollector(BaseCommCollector): """ Overview: An implementation of CommLearner, using flask and the file system. Interfaces: __init__, deal_with_resource, deal_with_collector_start, deal_with_collector_data, deal_with_collector_close,\ get_policy_update_info, send_stepdata, send_metadata, start, close """ # override
[docs] def __init__(self, cfg: dict) -> None: """ Overview: Initialization method. Arguments: - cfg (:obj:`EasyDict`): Config dict """ BaseCommCollector.__init__(self, cfg) host, port = cfg.host, cfg.port self._callback_fn = { 'deal_with_resource': self.deal_with_resource, 'deal_with_collector_start': self.deal_with_collector_start, 'deal_with_collector_data': self.deal_with_collector_data, 'deal_with_collector_close': self.deal_with_collector_close, } self._slave = CollectorSlave(host, port, callback_fn=self._callback_fn) self._path_policy = cfg.path_policy self._path_data = cfg.path_data if not os.path.exists(self._path_data): try: os.mkdir(self._path_data) except Exception as e: pass self._metadata_queue = Queue(8) self._collector_close_flag = False self._collector = None
[docs] def deal_with_resource(self) -> dict: """ Overview: Callback function in ``CollectorSlave``. Return how many resources are needed to start current collector. Returns: - resource (:obj:`dict`): Resource info dict, including ['gpu', 'cpu']. """ return {'gpu': 1, 'cpu': 20}
[docs] def deal_with_collector_start(self, task_info: dict) -> None: """ Overview: Callback function in ``CollectorSlave``. Create a collector and start a collector thread of the created one. Arguments: - task_info (:obj:`dict`): Task info dict. Note: In ``_create_collector`` method in base class ``BaseCommCollector``, 4 methods 'send_metadata', 'send_stepdata', 'get_policy_update_info', and policy are set. You can refer to it for details. """ self._collector_close_flag = False self._collector = self._create_collector(task_info) self._collector_thread = Thread(target=self._collector.start, args=(), daemon=True, name='collector_start') self._collector_thread.start()
[docs] def deal_with_collector_data(self) -> dict: """ Overview: Callback function in ``CollectorSlave``. Get data sample dict from ``_metadata_queue``, which will be sent to coordinator afterwards. Returns: - data (:obj:`Any`): Data sample dict. """ while True: if not self._metadata_queue.empty(): data = self._metadata_queue.get() break else: time.sleep(0.1) return data
def deal_with_collector_close(self) -> dict: self._collector_close_flag = True finish_info = self._collector.get_finish_info() self._collector.close() self._collector_thread.join() del self._collector_thread self._collector = None return finish_info # override
[docs] def get_policy_update_info(self, path: str) -> dict: """ Overview: Get policy information in corresponding path. Arguments: - path (:obj:`str`): path to policy update information. """ if self._collector_close_flag: return if self._path_policy not in path: path = os.path.join(self._path_policy, path) return read_file(path, use_lock=True)
# override
[docs] def send_stepdata(self, path: str, stepdata: list) -> None: """ Overview: Save collector's step data in corresponding path. Arguments: - path (:obj:`str`): Path to save data. - stepdata (:obj:`Any`): Data of one step. """ if save_to_di_store: if self._collector_close_flag: return b'0' * 20 # return an object reference that doesn't exist object_ref = save_to_di_store(stepdata) # print('send_stepdata:', path, 'object ref:', object_ref, 'len:', len(stepdata)) return object_ref if self._collector_close_flag: return name = os.path.join(self._path_data, path) save_file(name, stepdata, use_lock=False)
# override
[docs] def send_metadata(self, metadata: dict) -> None: """ Overview: Store learn info dict in queue, which will be retrieved by callback function "deal_with_collector_learn" in collector slave, then will be sent to coordinator. Arguments: - metadata (:obj:`Any`): meta data. """ if self._collector_close_flag: return necessary_metadata_keys = set(['data_id', 'policy_iter']) necessary_info_keys = set(['collector_done', 'cur_episode', 'cur_sample', 'cur_step']) assert necessary_metadata_keys.issubset(set(metadata.keys()) ) or necessary_info_keys.issubset(set(metadata.keys())) while True: if not self._metadata_queue.full(): self._metadata_queue.put(metadata) break else: time.sleep(0.1)
[docs] def start(self) -> None: """ Overview: Start comm collector itself and the collector slave. """ BaseCommCollector.start(self) self._slave.start()
[docs] def close(self) -> None: """ Overview: Close comm collector itself and the collector slave. """ if self._end_flag: return total_sleep_count = 0 while self._collector is not None and total_sleep_count < 10: self._collector.info("please first close collector") time.sleep(1) total_sleep_count += 1 self._slave.close() BaseCommCollector.close(self)
def __del__(self) -> None: self.close()