123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207 |
- import random
- import threading
- from contextlib import contextmanager
- import etcd3
- from easyfl.registry import mock_etcd
- from easyfl.registry.vclient import VirtualClient
class EtcdClient(object):
    """Etcd client to connect and communicate with etcd service.

    Etcd serves as the registry for remote training. Clients register
    themselves in etcd and the server queries etcd to get client addresses.

    Connections are kept in a class-level pool keyed by ``name`` that is
    shared by every instance; see :meth:`closing`.

    Args:
        name (str): The name of etcd; also the key of the connection pool.
        addrs (str): Etcd addresses, format: "<ip>:<port>,<ip>:<port>".
        base_dir (str): The prefix of all etcd requests, default to "backends".
        use_mock_etcd (bool): Whether use mocked etcd for testing.
    """

    # Pool of idle connections, shared by all instances: {name: [client, ...]}.
    ETCD_CLIENT_POOL_LOCK = threading.Lock()
    ETCD_CLIENT_POOL = {}
    # Set by destory_client_pool(); afterwards released connections are closed
    # instead of being returned to the pool.
    ETCD_CLIENT_POOL_DESTROY = False

    class Event(object):
        """Wrap a watch event so that ``key`` is relative to ``base_dir``.

        Every other attribute is forwarded to the wrapped event.
        """

        def __init__(self, event, base_dir):
            self._event = event
            self._base_dir = base_dir

        def __getattr__(self, attr):
            return getattr(self._event, attr)

        @property
        def key(self):
            return EtcdClient.normalize_output_key(self._event.key, self._base_dir)

    def __init__(self, name, addrs, base_dir="backends", use_mock_etcd=False):
        self._name = name
        # Stored as "/<base_dir>" with any leading '.' / '/' characters removed.
        self._base_dir = '/' + EtcdClient._normalize_input_key(base_dir)
        self._addrs = self._normalize_addr(addrs)
        if not self._addrs:
            raise ValueError('Empty hosts EtcdClient')
        # NOTE(review): not read anywhere in this class (_get_next_addr picks a
        # random address per call); kept in case external code relies on it.
        self._cur_addr_idx = random.randint(0, len(self._addrs) - 1)
        self._use_mock_etcd = use_mock_etcd

    def _connection(self):
        """Borrow a (pooled) connection to a randomly chosen etcd address."""
        return EtcdClient.closing(self._name, self._get_next_addr(), self._use_mock_etcd)

    def get_data(self, key):
        """Return the value stored at ``key`` under base_dir.

        This is the first element of the client's ``get`` result (presumably
        ``None`` when the key is absent -- confirm against etcd3).
        """
        with self._connection() as clnt:
            return clnt.get(self._generate_path(key))[0]

    def set_data(self, key, data):
        """Store ``data`` at ``key`` under base_dir."""
        with self._connection() as clnt:
            clnt.put(self._generate_path(key), data)

    def delete(self, key):
        """Delete ``key`` under base_dir and return the client's result."""
        with self._connection() as clnt:
            return clnt.delete(self._generate_path(key))

    def delete_prefix(self, key):
        """Delete every key starting with ``key`` under base_dir."""
        with self._connection() as clnt:
            return clnt.delete_prefix(self._generate_path(key))

    def cas(self, key, old_data, new_data):
        """Compare-and-swap ``key`` from ``old_data`` to ``new_data``.

        When ``old_data`` is None the write only happens if the key does not
        exist yet. Returns the client's result (presumably a success flag).
        """
        with self._connection() as clnt:
            etcd_path = self._generate_path(key)
            if old_data is None:
                return clnt.put_if_not_exists(etcd_path, new_data)
            return clnt.replace(etcd_path, old_data, new_data)

    def watch_key(self, key):
        """Watch ``key`` under base_dir.

        Returns:
            tuple: ``(events, cancel)``. ``events`` is a generator of
            :class:`EtcdClient.Event` whose ``key`` is relative to base_dir;
            ``cancel`` is the cancel callable returned by the client.
        """
        with self._connection() as clnt:
            notifier, cancel = clnt.watch(self._generate_path(key))

            def prefix_extractor(notifier, base_dir):
                while True:
                    try:
                        event = next(notifier)
                    except StopIteration:
                        return
                    yield EtcdClient.Event(event, base_dir)

            # NOTE(review): the connection goes back to the pool when this
            # block exits while the watch is still live; confirm the client
            # tolerates that shared use.
            return prefix_extractor(notifier, self._base_dir), cancel

    def get_prefix_kvs(self, prefix, ignore_prefix=False):
        """List ``(key, value)`` pairs whose key starts with ``prefix``.

        Keys are returned relative to base_dir, in ascending order
        (``sort_order='ascend'``).

        Args:
            prefix (str): Key prefix under base_dir.
            ignore_prefix (bool): Skip the entry whose key equals the prefix.
        """
        path = self._generate_path(prefix)
        encoded_path = path.encode()
        kvs = []
        with self._connection() as clnt:
            for data, meta in clnt.get_prefix(path, sort_order='ascend'):
                if ignore_prefix and meta.key == encoded_path:
                    continue
                nkey = EtcdClient.normalize_output_key(meta.key, self._base_dir)
                kvs.append((nkey, data))
        return kvs

    def get_clients(self, prefix):
        """Retrieve client addresses from etcd using prefix.

        Keys are expected to look like ``<prefix>/<ip>:<port>``. Entries whose
        second path segment is not an address are skipped.

        Args:
            prefix (str): the prefix of clients addresses; default is the docker image name "easyfl-client"
        Returns:
            list[:obj:`VirtualClient`]: A list of clients, indexed from 0.
        """
        clients = []
        for key_bytes, value_bytes in self.get_prefix_kvs(prefix):
            key, value = key_bytes.decode("utf-8"), value_bytes.decode("utf-8")
            parts = key.split("/")
            if len(parts) <= 1:
                continue
            addr = parts[1]
            if not self._is_addr(addr):
                continue
            clients.append(VirtualClient(value, addr, len(clients)))
        return clients

    def _is_addr(self, address):
        return ":" in address

    def _generate_path(self, key):
        return '/'.join([self._base_dir, self._normalize_input_key(key)])

    def _get_next_addr(self):
        return random.choice(self._addrs)

    @staticmethod
    def _normalize_addr(addrs):
        """Parse ``"<ip>:<port>,<ip>:<port>"`` into ``[(host, port), ...]``.

        Surrounding whitespace and empty entries (e.g. a trailing comma) are
        ignored.

        Raises:
            ValueError: if an entry is not ``host:port`` or the port is not an
                integer in [0, 65535].
        """
        naddrs = []
        for raw_addr in addrs.split(','):
            raw_addr = raw_addr.strip()
            if not raw_addr:
                continue
            parts = raw_addr.split(':')
            if len(parts) != 2:
                raise ValueError(
                    '{} is not a valid address, expected "<ip>:<port>"'.format(raw_addr))
            host, port_str = parts
            try:
                port = int(port_str)
            except ValueError:
                raise ValueError('{} is not a valid port'.format(port_str)) from None
            if not 0 <= port <= 65535:
                raise ValueError('port {} is out of range'.format(port))
            naddrs.append((host, port))
        return naddrs

    @staticmethod
    def _normalize_input_key(key):
        """Strip any leading '.' and '/' characters (safe for empty keys)."""
        return key.lstrip('./')

    @staticmethod
    def normalize_output_key(key, base_dir):
        """Return bytes ``key`` with the ``base_dir`` prefix and separator removed.

        Raises:
            ValueError: if ``key`` does not start with ``base_dir``.
        """
        if isinstance(base_dir, str):
            # Measure the prefix in bytes: ``key`` is bytes, so a non-ASCII
            # base_dir must not be measured by its character length.
            base_dir = base_dir.encode()
        if not key.startswith(base_dir):
            raise ValueError('key {!r} is not under {!r}'.format(key, base_dir))
        return key[len(base_dir) + 1:]

    @classmethod
    @contextmanager
    def closing(cls, name, addr, use_mock_etcd):
        """Borrow a client for ``name`` from the pool, or create one.

        On normal exit the client goes back to the pool, or is closed if the
        pool was destroyed. If the body raises, the client is closed and the
        exception propagates.

        Args:
            name (str): Pool key.
            addr (tuple): ``(host, port)`` used when a new client is created.
            use_mock_etcd (bool): Create a mock client instead of a real one.
        """
        clnt = None
        with cls.ETCD_CLIENT_POOL_LOCK:
            pooled = cls.ETCD_CLIENT_POOL.get(name)
            if pooled:
                clnt = pooled.pop(0)
        if clnt is None:
            # If construction fails there is nothing to close; let the
            # original error propagate.
            if use_mock_etcd:
                clnt = mock_etcd.MockEtcdClient(addr[0], addr[1])
            else:
                clnt = etcd3.client(host=addr[0], port=addr[1])
        try:
            yield clnt
        except BaseException:
            # Any failure (including interrupts) may leave the connection in
            # an unknown state: never return it to the pool.
            clnt.close()
            raise
        else:
            with cls.ETCD_CLIENT_POOL_LOCK:
                if cls.ETCD_CLIENT_POOL_DESTROY:
                    clnt.close()
                else:
                    cls.ETCD_CLIENT_POOL.setdefault(name, []).append(clnt)

    @classmethod
    def destory_client_pool(cls):
        """Close every pooled client and stop pooling from now on.

        The pool is emptied so that later ``closing()`` calls never hand out
        an already-closed client.
        """
        with cls.ETCD_CLIENT_POOL_LOCK:
            cls.ETCD_CLIENT_POOL_DESTROY = True
            for clnts in cls.ETCD_CLIENT_POOL.values():
                for clnt in clnts:
                    clnt.close()
            cls.ETCD_CLIENT_POOL.clear()

    # Correctly spelled alias; the original name is kept for existing callers.
    destroy_client_pool = destory_client_pool
|