etcd_client.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. import random
  2. import threading
  3. from contextlib import contextmanager
  4. import etcd3
  5. from easyfl.registry import mock_etcd
  6. from easyfl.registry.vclient import VirtualClient
  7. class EtcdClient(object):
  8. """Etcd client to connect and communicate with etcd service.
  9. Etcd is the serves as the registry for remote training.
  10. Clients register themselves in etcd and server queries etcd to get client addresses.
  11. Args:
  12. name (str): The name of etcd.
  13. addrs (str): Etcd addresses, format: "<ip>:<port>,<ip>:<port>".
  14. base_dir (str): The prefix of all etcd requests, default to "backends".
  15. use_mock_etcd (bool): Whether use mocked etcd for testing.
  16. """
  17. ETCD_CLIENT_POOL_LOCK = threading.Lock()
  18. ETCD_CLIENT_POOL = {}
  19. ETCD_CLIENT_POOL_DESTROY = False
  20. class Event(object):
  21. def __init__(self, event, base_dir):
  22. self._event = event
  23. self._base_dir = base_dir
  24. def __getattr__(self, attr):
  25. return getattr(self._event, attr)
  26. @property
  27. def key(self):
  28. return EtcdClient.normalize_output_key(self._event.key, self._base_dir)
  29. def __init__(self, name, addrs, base_dir, use_mock_etcd=False):
  30. self._name = name
  31. self._base_dir = '/' + EtcdClient._normalize_input_key(base_dir)
  32. self._addrs = self._normalize_addr(addrs)
  33. if len(self._addrs) == 0:
  34. raise ValueError('Empty hosts EtcdClient')
  35. self._cur_addr_idx = random.randint(0, len(self._addrs) - 1)
  36. self._use_mock_etcd = use_mock_etcd
  37. def get_data(self, key):
  38. addr = self._get_next_addr()
  39. with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt:
  40. return clnt.get(self._generate_path(key))[0]
  41. def set_data(self, key, data):
  42. addr = self._get_next_addr()
  43. with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt:
  44. clnt.put(self._generate_path(key), data)
  45. def delete(self, key):
  46. addr = self._get_next_addr()
  47. with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt:
  48. return clnt.delete(self._generate_path(key))
  49. def delete_prefix(self, key):
  50. addr = self._get_next_addr()
  51. with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt:
  52. return clnt.delete_prefix(self._generate_path(key))
  53. def cas(self, key, old_data, new_data):
  54. addr = self._get_next_addr()
  55. with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt:
  56. etcd_path = self._generate_path(key)
  57. if old_data is None:
  58. return clnt.put_if_not_exists(etcd_path, new_data)
  59. return clnt.replace(etcd_path, old_data, new_data)
  60. def watch_key(self, key):
  61. addr = self._get_next_addr()
  62. with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt:
  63. notifier, cancel = clnt.watch(self._generate_path(key))
  64. def prefix_extractor(notifier, base_dir):
  65. while True:
  66. try:
  67. yield EtcdClient.Event(next(notifier), base_dir)
  68. except StopIteration:
  69. break
  70. return prefix_extractor(notifier, self._base_dir), cancel
  71. def get_prefix_kvs(self, prefix, ignore_prefix=False):
  72. addr = self._get_next_addr()
  73. kvs = []
  74. path = self._generate_path(prefix)
  75. with EtcdClient.closing(self._name, addr, self._use_mock_etcd) as clnt:
  76. for (data, key) in clnt.get_prefix(path, sort_order='ascend'):
  77. if ignore_prefix and key.key == path.encode():
  78. continue
  79. nkey = EtcdClient.normalize_output_key(key.key, self._base_dir)
  80. kvs.append((nkey, data))
  81. return kvs
  82. def get_clients(self, prefix):
  83. """Retrieve client addresses from etcd using prefix.
  84. Args:
  85. prefix (str): the prefix of clients addresses; default is the docker image name "easyfl-client"
  86. Returns:
  87. list[:obj:`VirtualClient`]: A list of clients.
  88. """
  89. key_value_tuples = self.get_prefix_kvs(prefix)
  90. clients = []
  91. index = 0
  92. for (key_byte, value_byte) in key_value_tuples:
  93. key, value = key_byte.decode("utf-8"), value_byte.decode("utf-8")
  94. parts = key.split("/")
  95. if len(parts) <= 1:
  96. continue
  97. addr = parts[1]
  98. if not self._is_addr(addr):
  99. continue
  100. clients.append(VirtualClient(value, addr, index))
  101. index += 1
  102. return clients
  103. def _is_addr(self, address):
  104. return len(address.split(":")) > 1
  105. def _generate_path(self, key):
  106. return '/'.join([self._base_dir, self._normalize_input_key(key)])
  107. def _get_next_addr(self):
  108. return self._addrs[random.randint(0, len(self._addrs) - 1)]
  109. @staticmethod
  110. def _normalize_addr(addrs):
  111. naddrs = []
  112. for raw_addr in addrs.split(','):
  113. (host, port_str) = raw_addr.split(':')
  114. try:
  115. port = int(port_str)
  116. if port < 0 or port > 65535:
  117. raise ValueError('port {} is out of range')
  118. except ValueError:
  119. raise ValueError('{} is not a valid port'.format(port_str))
  120. naddrs.append((host, port))
  121. return naddrs
  122. @staticmethod
  123. def _normalize_input_key(key):
  124. skip_cnt = 0
  125. while key[skip_cnt] == '.' or key[skip_cnt] == '/':
  126. skip_cnt += 1
  127. if skip_cnt > 0:
  128. return key[skip_cnt:]
  129. return key
  130. @staticmethod
  131. def normalize_output_key(key, base_dir):
  132. if isinstance(base_dir, str):
  133. assert key.startswith(base_dir.encode())
  134. else:
  135. assert key.startswith(base_dir)
  136. return key[len(base_dir) + 1:]
  137. @classmethod
  138. @contextmanager
  139. def closing(cls, name, addr, use_mock_etcd):
  140. clnt = None
  141. with cls.ETCD_CLIENT_POOL_LOCK:
  142. if (name in cls.ETCD_CLIENT_POOL and
  143. len(cls.ETCD_CLIENT_POOL[name]) > 0):
  144. clnt = cls.ETCD_CLIENT_POOL[name][0]
  145. cls.ETCD_CLIENT_POOL[name] = cls.ETCD_CLIENT_POOL[name][1:]
  146. if clnt is None:
  147. try:
  148. if use_mock_etcd:
  149. clnt = mock_etcd.MockEtcdClient(addr[0], addr[1])
  150. else:
  151. clnt = etcd3.client(host=addr[0], port=addr[1])
  152. except Exception as e:
  153. clnt.close()
  154. raise e
  155. try:
  156. yield clnt
  157. except Exception as e:
  158. clnt.close()
  159. raise e
  160. else:
  161. with cls.ETCD_CLIENT_POOL_LOCK:
  162. if cls.ETCD_CLIENT_POOL_DESTROY:
  163. clnt.close()
  164. else:
  165. if name not in cls.ETCD_CLIENT_POOL:
  166. cls.ETCD_CLIENT_POOL[name] = [clnt]
  167. else:
  168. cls.ETCD_CLIENT_POOL[name].append(clnt)
  169. @classmethod
  170. def destory_client_pool(cls):
  171. with cls.ETCD_CLIENT_POOL_LOCK:
  172. cls.ETCD_CLIENT_POOL_DESTROY = True
  173. for _, clnts in cls.ETCD_CLIENT_POOL.items():
  174. for clnt in clnts:
  175. clnt.close()