import random import threading from contextlib import contextmanager import etcd3 from easyfl.registry import mock_etcd from easyfl.registry.vclient import VirtualClient class EtcdClient(object): """Etcd client to connect and communicate with etcd service. Etcd is the serves as the registry for remote training. Clients register themselves in etcd and server queries etcd to get client addresses. Args: name (str): The name of etcd. addrs (str): Etcd addresses, format: ":,:". base_dir (str): The prefix of all etcd requests, default to "backends". use_mock_etcd (bool): Whether use mocked etcd for testing. """ ETCD_CLIENT_POOL_LOCK = threading.Lock() ETCD_CLIENT_POOL = {} ETCD_CLIENT_POOL_DESTROY = False class Event(object): def __init__(self, event, base_dir): self._event = event self._base_dir = base_dir def __getattr__(self, attr): return getattr(self._event, attr) @property def key(self): return EtcdClient.normalize_output_key(self._event.key, self._base_dir) def __init__(self, name, addrs, base_dir, use_mock_etcd=False): self._name = name self._base_dir = '/' + EtcdClient._normalize_input_key(base_dir) self._addrs = self._normalize_addr(addrs) if len(self._addrs) == 0: raise ValueError('Empty hosts EtcdClient') self._cur_addr_idx = random.randint(0, len(self._addrs) - 1) self._use_mock_etcd = use_mock_etcd def get_data(self, key): addr = self._get_next_addr() with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt: return clnt.get(self._generate_path(key))[0] def set_data(self, key, data): addr = self._get_next_addr() with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt: clnt.put(self._generate_path(key), data) def delete(self, key): addr = self._get_next_addr() with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt: return clnt.delete(self._generate_path(key)) def delete_prefix(self, key): addr = self._get_next_addr() with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt: return clnt.delete_prefix(self._generate_path(key)) def cas(self, key, old_data, new_data): addr = self._get_next_addr() with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt: etcd_path = self._generate_path(key) if old_data is None: return clnt.put_if_not_exists(etcd_path, new_data) return clnt.replace(etcd_path, old_data, new_data) def watch_key(self, key): addr = self._get_next_addr() with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt: notifier, cancel = clnt.watch(self._generate_path(key)) def prefix_extractor(notifier, base_dir): while True: try: yield EtcdClient.Event(next(notifier), base_dir) except StopIteration: break return prefix_extractor(notifier, self._base_dir), cancel def get_prefix_kvs(self, prefix, ignore_prefix=False): addr = self._get_next_addr() kvs = [] path = self._generate_path(prefix) with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt: for (data, key) in clnt.get_prefix(path, sort_order='ascend'): if ignore_prefix and key.key == path.encode(): continue nkey = EtcdClient.normalize_output_key(key.key, self._base_dir) kvs.append((nkey, data)) return kvs def get_clients(self, prefix): """Retrieve client addresses from etcd using prefix. Args: prefix (str): the prefix of clients addresses; default is the docker image name "easyfl-client" Returns: list[:obj:`VirtualClient`]: A list of clients. """ key_value_tuples = self.get_prefix_kvs(prefix) clients = [] index = 0 for (key_byte, value_byte) in key_value_tuples: key, value = key_byte.decode("utf-8"), value_byte.decode("utf-8") parts = key.split("/") if len(parts) <= 1: continue addr = parts[1] if not self._is_addr(addr): continue clients.append(VirtualClient(value, addr, index)) index += 1 return clients def _is_addr(self, address): return len(address.split(":")) > 1 def _generate_path(self, key): return '/'.join([self._base_dir, self._normalize_input_key(key)]) def _get_next_addr(self): return self._addrs[random.randint(0, len(self._addrs) - 1)] @staticmethod def _normalize_addr(addrs): naddrs = [] for raw_addr in addrs.split(','): (host, port_str) = raw_addr.split(':') try: port = int(port_str) if port < 0 or port > 65535: raise ValueError('port {} is out of range') except ValueError: raise ValueError('{} is not a valid port'.format(port_str)) naddrs.append((host, port)) return naddrs @staticmethod def _normalize_input_key(key): skip_cnt = 0 while key[skip_cnt] == '.' or key[skip_cnt] == '/': skip_cnt += 1 if skip_cnt > 0: return key[skip_cnt:] return key @staticmethod def normalize_output_key(key, base_dir): if isinstance(base_dir, str): assert key.startswith(base_dir.encode()) else: assert key.startswith(base_dir) return key[len(base_dir) + 1:] @classmethod @contextmanager def closing(cls, name, addr, use_mock_etcd): clnt = None with cls.ETCD_CLIENT_POOL_LOCK: if (name in cls.ETCD_CLIENT_POOL and len(cls.ETCD_CLIENT_POOL[name]) > 0): clnt = cls.ETCD_CLIENT_POOL[name][0] cls.ETCD_CLIENT_POOL[name] = cls.ETCD_CLIENT_POOL[name][1:] if clnt is None: try: if use_mock_etcd: clnt = mock_etcd.MockEtcdClient(addr[0], addr[1]) else: clnt = etcd3.client(host=addr[0], port=addr[1]) except Exception as e: clnt.close() raise e try: yield clnt except Exception as e: clnt.close() raise e else: with cls.ETCD_CLIENT_POOL_LOCK: if cls.ETCD_CLIENT_POOL_DESTROY: clnt.close() else: if name not in cls.ETCD_CLIENT_POOL: cls.ETCD_CLIENT_POOL[name] = [clnt] else: cls.ETCD_CLIENT_POOL[name].append(clnt) @classmethod def destory_client_pool(cls): with cls.ETCD_CLIENT_POOL_LOCK: cls.ETCD_CLIENT_POOL_DESTROY = True for _, clnts in cls.ETCD_CLIENT_POOL.items(): for clnt in clnts: clnt.close()