#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import hashlib
import itertools
import pickle as c_pickle
import shutil
import time
import typing
import uuid
from collections.abc import Iterable
from concurrent.futures import ProcessPoolExecutor as Executor
from contextlib import ExitStack
from functools import partial
from heapq import heapify, heappop, heapreplace
from operator import is_not
from pathlib import Path

import cloudpickle as f_pickle
import lmdb
import numpy as np

from fate_arch.common import Party, file_utils
from fate_arch.common.log import getLogger
from fate_arch.federation import FederationDataType

LOGGER = getLogger()

serialize = c_pickle.dumps
deserialize = c_pickle.loads

# default message max size in bytes = 1MB
DEFAULT_MESSAGE_MAX_SIZE = 1048576
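

# `Table` is a partitioned key-value table. Each partition lives in its own
# LMDB environment under the data directory, keys are pickled and routed to a
# partition by hash, and most operators below fan out one worker process per
# partition.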
# noinspection PyPep8Naming
class Table(object):
    def __init__(
        self,
        session: "Session",
        namespace: str,
        name: str,
        partitions,
        need_cleanup=True,
    ):
        self._need_cleanup = need_cleanup
        self._namespace = namespace
        self._name = name
        self._partitions = partitions
        self._session = session

    @property
    def partitions(self):
        return self._partitions

    @property
    def name(self):
        return self._name

    @property
    def namespace(self):
        return self._namespace

    def __del__(self):
        if self._need_cleanup:
            self.destroy()

    def __str__(self):
        return f"<Table {self._namespace}|{self._name}|{self._partitions}|{self._need_cleanup}>"

    def __repr__(self):
        return self.__str__()

    def destroy(self):
        for p in range(self._partitions):
            with self._get_env_for_partition(p, write=True) as env:
                db = env.open_db()
                with env.begin(write=True) as txn:
                    txn.drop(db)
        table_key = f"{self._namespace}.{self._name}"
        _get_meta_table().delete(table_key)
        path = _get_storage_dir(self._namespace, self._name)
        shutil.rmtree(path, ignore_errors=True)

    def take(self, n, **kwargs):
        if n <= 0:
            raise ValueError(f"{n} <= 0")
        return list(itertools.islice(self.collect(**kwargs), n))

    def count(self):
        cnt = 0
        for p in range(self._partitions):
            with self._get_env_for_partition(p) as env:
                cnt += env.stat()["entries"]
        return cnt
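
    # `collect` streams the union of all partitions in serialized-key order:
    # it opens a read cursor on every partition and performs a heap-based
    # k-way merge (each LMDB partition already iterates in sorted byte order).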
    # noinspection PyUnusedLocal
    def collect(self, **kwargs):
        iterators = []
        with ExitStack() as s:
            for p in range(self._partitions):
                env = s.enter_context(self._get_env_for_partition(p))
                txn = s.enter_context(env.begin())
                iterators.append(s.enter_context(txn.cursor()))

            # Merge sorted
            entries = []
            for _id, it in enumerate(iterators):
                if it.next():
                    key, value = it.item()
                    entries.append([key, value, _id, it])
            heapify(entries)
            while entries:
                key, value, _, it = entry = entries[0]
                yield deserialize(key), deserialize(value)
                if it.next():
                    entry[0], entry[1] = it.item()
                    heapreplace(entries, entry)
                else:
                    _, _, _, it = heappop(entries)

    def reduce(self, func):
        # noinspection PyProtectedMember
        rs = self._session._submit_unary(
            func, _do_reduce, self._partitions, self._name, self._namespace
        )
        rs = [r for r in filter(partial(is_not, None), rs)]
        if len(rs) <= 0:
            return None
        rtn = rs[0]
        for r in rs[1:]:
            rtn = func(rtn, r)
        return rtn

    def map(self, func):
        return self._unary(func, _do_map)

    def mapValues(self, func):
        return self._unary(func, _do_map_values)

    def flatMap(self, func):
        _flat_mapped = self._unary(func, _do_flat_map)
        return _flat_mapped.save_as(
            name=str(uuid.uuid1()),
            namespace=_flat_mapped.namespace,
            partition=self._partitions,
            need_cleanup=True,
        )

    def applyPartitions(self, func):
        return self._unary(func, _do_apply_partitions)

    def mapPartitions(self, func, preserves_partitioning=False):
        un_shuffled = self._unary(func, _do_map_partitions)
        if preserves_partitioning:
            return un_shuffled
        return un_shuffled.save_as(
            name=str(uuid.uuid1()),
            namespace=un_shuffled.namespace,
            partition=self._partitions,
            need_cleanup=True,
        )

    def mapPartitionsWithIndex(self, func, preserves_partitioning=False):
        un_shuffled = self._unary(func, _do_map_partitions_with_index)
        if preserves_partitioning:
            return un_shuffled
        return un_shuffled.save_as(
            name=str(uuid.uuid1()),
            namespace=un_shuffled.namespace,
            partition=self._partitions,
            need_cleanup=True,
        )

    def mapReducePartitions(self, mapper, reducer):
        dup = _create_table(
            self._session,
            str(uuid.uuid1()),
            self.namespace,
            self._partitions,
            need_cleanup=True,
        )

        def _dict_reduce(a: dict, b: dict):
            for k, v in b.items():
                if k not in a:
                    a[k] = v
                else:
                    a[k] = reducer(a[k], v)
            return a

        def _local_map_reduce(it):
            ret = {}
            for _k, _v in mapper(it):
                if _k not in ret:
                    ret[_k] = _v
                else:
                    ret[_k] = reducer(ret[_k], _v)
            return ret

        dup.put_all(
            self.applyPartitions(_local_map_reduce).reduce(_dict_reduce).items()
        )
        return dup

    def glom(self):
        return self._unary(None, _do_glom)

    def sample(self, fraction, seed=None):
        return self._unary((fraction, seed), _do_sample)

    def filter(self, func):
        return self._unary(func, _do_filter)

    def join(self, other: "Table", func):
        return self._binary(other, func, _do_join)

    def subtractByKey(self, other: "Table"):
        func = f"{self._namespace}.{self._name}-{other._namespace}.{other._name}"
        return self._binary(other, func, _do_subtract_by_key)

    def union(self, other: "Table", func=lambda v1, v2: v1):
        return self._binary(other, func, _do_union)

    # noinspection PyProtectedMember
    def _map_reduce(self, mapper, reducer):
        results = self._session._submit_map_reduce_in_partition(
            mapper, reducer, self._partitions, self._name, self._namespace
        )
        result = results[0]
        # noinspection PyProtectedMember
        return _create_table(
            session=self._session,
            name=result.name,
            namespace=result.namespace,
            partitions=self._partitions,
        )

    def _unary(self, func, do_func):
        # noinspection PyProtectedMember
        results = self._session._submit_unary(
            func, do_func, self._partitions, self._name, self._namespace
        )
        result = results[0]
        # noinspection PyProtectedMember
        return _create_table(
            session=self._session,
            name=result.name,
            namespace=result.namespace,
            partitions=self._partitions,
        )

    def _binary(self, other: "Table", func, do_func):
        session_id = self._session.session_id
        left, right = self, other
        if left._partitions != right._partitions:
            if other.count() > self.count():
                left = left.save_as(
                    str(uuid.uuid1()), session_id, partition=right._partitions
                )
            else:
                right = other.save_as(
                    str(uuid.uuid1()), session_id, partition=left._partitions
                )
        # noinspection PyProtectedMember
        results = self._session._submit_binary(
            func,
            do_func,
            left._partitions,
            left._name,
            left._namespace,
            right._name,
            right._namespace,
        )
        result: _Operand = results[0]
        # noinspection PyProtectedMember
        return _create_table(
            session=self._session,
            name=result.name,
            namespace=result.namespace,
            partitions=left._partitions,
        )

    def save_as(self, name, namespace, partition=None, need_cleanup=True):
        if partition is None:
            partition = self._partitions
        # noinspection PyProtectedMember
        dup = _create_table(self._session, name, namespace, partition, need_cleanup)
        dup.put_all(self.collect())
        return dup

    def _get_env_for_partition(self, p: int, write=False):
        return _get_env(self._namespace, self._name, str(p), write=write)

    def put(self, k, v):
        k_bytes, v_bytes = _kv_to_bytes(k=k, v=v)
        p = _hash_key_to_partition(k_bytes, self._partitions)
        with self._get_env_for_partition(p, write=True) as env:
            with env.begin(write=True) as txn:
                return txn.put(k_bytes, v_bytes)

    def put_all(self, kv_list: Iterable):
        txn_map = {}
        is_success = True
        with ExitStack() as s:
            for p in range(self._partitions):
                env = s.enter_context(self._get_env_for_partition(p, write=True))
                txn_map[p] = env, env.begin(write=True)
            for k, v in kv_list:
                try:
                    k_bytes, v_bytes = _kv_to_bytes(k=k, v=v)
                    p = _hash_key_to_partition(k_bytes, self._partitions)
                    is_success = is_success and txn_map[p][1].put(k_bytes, v_bytes)
                except Exception as e:
                    is_success = False
                    LOGGER.exception(f"put_all for k={k} v={v} failed. exception: {e}")
                    break
            for p, (env, txn) in txn_map.items():
                txn.commit() if is_success else txn.abort()

    def get(self, k):
        k_bytes = _k_to_bytes(k=k)
        p = _hash_key_to_partition(k_bytes, self._partitions)
        with self._get_env_for_partition(p) as env:
            with env.begin(write=True) as txn:
                old_value_bytes = txn.get(k_bytes)
                return (
                    None if old_value_bytes is None else deserialize(old_value_bytes)
                )

    def delete(self, k):
        k_bytes = _k_to_bytes(k=k)
        p = _hash_key_to_partition(k_bytes, self._partitions)
        with self._get_env_for_partition(p, write=True) as env:
            with env.begin(write=True) as txn:
                old_value_bytes = txn.get(k_bytes)
                if txn.delete(k_bytes):
                    return (
                        None
                        if old_value_bytes is None
                        else deserialize(old_value_bytes)
                    )
                return None
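

# `Session` owns a process pool and fans each operator out as one task per
# partition; the module-level `_do_*` worker functions below are submitted
# together with cloudpickle-serialized user functions.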
# noinspection PyMethodMayBeStatic
class Session(object):
    def __init__(self, session_id, max_workers=None):
        self.session_id = session_id
        self._pool = Executor(max_workers=max_workers)

    def __getstate__(self):
        # session won't be pickled
        pass

    def load(self, name, namespace):
        return _load_table(session=self, name=name, namespace=namespace)

    def create_table(self, name, namespace, partitions, need_cleanup, error_if_exist):
        return _create_table(
            session=self,
            name=name,
            namespace=namespace,
            partitions=partitions,
            need_cleanup=need_cleanup,
            error_if_exist=error_if_exist,
        )

    # noinspection PyUnusedLocal
    def parallelize(
        self, data: Iterable, partition: int, include_key: bool = False, **kwargs
    ):
        if not include_key:
            data = enumerate(data)
        table = _create_table(
            session=self,
            name=str(uuid.uuid1()),
            namespace=self.session_id,
            partitions=partition,
        )
        table.put_all(data)
        return table

    def cleanup(self, name, namespace):
        data_path = _get_data_dir()
        if not data_path.is_dir():
            LOGGER.error(f"illegal data dir: {data_path}")
            return
        namespace_dir = data_path.joinpath(namespace)
        if not namespace_dir.is_dir():
            return
        if name == "*":
            shutil.rmtree(namespace_dir, True)
            return
        for table in namespace_dir.glob(name):
            shutil.rmtree(table, True)

    def stop(self):
        self.cleanup(name="*", namespace=self.session_id)
        self._pool.shutdown()

    def kill(self):
        self.cleanup(name="*", namespace=self.session_id)
        self._pool.shutdown()

    def _submit_unary(self, func, _do_func, partitions, name, namespace):
        task_info = _TaskInfo(
            self.session_id,
            function_id=str(uuid.uuid1()),
            function_bytes=f_pickle.dumps(func),
        )
        futures = []
        for p in range(partitions):
            futures.append(
                self._pool.submit(
                    _do_func, _UnaryProcess(task_info, _Operand(namespace, name, p))
                )
            )
        results = [r.result() for r in futures]
        return results

    def _submit_map_reduce_in_partition(
        self, mapper, reducer, partitions, name, namespace
    ):
        task_info = _MapReduceTaskInfo(
            self.session_id,
            function_id=str(uuid.uuid1()),
            map_function_bytes=f_pickle.dumps(mapper),
            reduce_function_bytes=f_pickle.dumps(reducer),
        )
        futures = []
        for p in range(partitions):
            futures.append(
                self._pool.submit(
                    _do_map_reduce_in_partitions,
                    _MapReduceProcess(task_info, _Operand(namespace, name, p)),
                )
            )
        results = [r.result() for r in futures]
        return results

    def _submit_binary(
        self, func, do_func, partitions, name, namespace, other_name, other_namespace
    ):
        task_info = _TaskInfo(
            self.session_id,
            function_id=str(uuid.uuid1()),
            function_bytes=f_pickle.dumps(func),
        )
        futures = []
        for p in range(partitions):
            left = _Operand(namespace, name, p)
            right = _Operand(other_namespace, other_name, p)
            futures.append(
                self._pool.submit(do_func, _BinaryProcess(task_info, left, right))
            )
        results = [r.result() for r in futures]
        return results
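

# Objects larger than `max_message_size` cannot travel as a single message, so
# `_get_splits` pickles them and slices the bytes into numbered chunks; the
# receiving side reassembles and deserializes them in `Federation.get`.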
def _get_splits(obj, max_message_size):
    obj_bytes = serialize(obj, protocol=4)
    byte_size = len(obj_bytes)
    num_slice = (byte_size - 1) // max_message_size + 1
    if num_slice <= 1:
        return obj, num_slice
    else:
        _max_size = max_message_size
        kv = [
            (i, obj_bytes[slice(i * _max_size, (i + 1) * _max_size)])
            for i in range(num_slice)
        ]
        return kv, num_slice
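

# Standalone-mode "federation" is simulated through shared tables: each party
# owns a status table and an object table, `remote` writes directly into the
# receiver's tables, and `get` polls the local status table until the awaited
# key shows up.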
class Federation(object):
    def _federation_object_key(self, name, tag, s_party, d_party):
        return f"{self._session_id}-{name}-{tag}-{s_party.role}-{s_party.party_id}-{d_party.role}-{d_party.party_id}"

    def __init__(self, session: Session, session_id, party: Party):
        self._session_id = session_id
        self._party: Party = party
        self._session = session
        self._max_message_size = DEFAULT_MESSAGE_MAX_SIZE
        self._other_status_tables = {}
        self._other_object_tables = {}
        self._event_loop = None
        self._federation_status_table_cache = None
        self._federation_object_table_cache = None

    def destroy(self):
        self._session.cleanup(namespace=self._session_id, name="*")

    @property
    def _federation_status_table(self):
        if self._federation_status_table_cache is None:
            self._federation_status_table_cache = _create_table(
                session=self._session,
                name=self._get_status_table_name(self._party),
                namespace=self._session_id,
                partitions=1,
                need_cleanup=True,
                error_if_exist=False,
            )
        return self._federation_status_table_cache

    @property
    def _federation_object_table(self):
        if self._federation_object_table_cache is None:
            self._federation_object_table_cache = _create_table(
                session=self._session,
                name=self._get_object_table_name(self._party),
                namespace=self._session_id,
                partitions=1,
                need_cleanup=True,
                error_if_exist=False,
            )
        return self._federation_object_table_cache

    @property
    def _loop(self):
        if self._event_loop is None:
            self._event_loop = asyncio.get_event_loop()
        return self._event_loop

    @staticmethod
    def _get_status_table_name(party):
        return f"__federation_status__.{party.role}_{party.party_id}"

    @staticmethod
    def _get_object_table_name(party):
        return f"__federation_object__.{party.role}_{party.party_id}"

    def _get_other_status_table(self, party):
        if party in self._other_status_tables:
            return self._other_status_tables[party]
        table = _create_table(
            self._session,
            name=self._get_status_table_name(party),
            namespace=self._session_id,
            partitions=1,
            need_cleanup=False,
            error_if_exist=False,
        )
        self._other_status_tables[party] = table
        return table

    def _get_other_object_table(self, party):
        if party in self._other_object_tables:
            return self._other_object_tables[party]
        table = _create_table(
            self._session,
            name=self._get_object_table_name(party),
            namespace=self._session_id,
            partitions=1,
            need_cleanup=False,
            error_if_exist=False,
        )
        self._other_object_tables[party] = table
        return table

    # noinspection PyProtectedMember
    def _put_status(self, party, _tagged_key, value):
        self._get_other_status_table(party).put(_tagged_key, value)

    # noinspection PyProtectedMember
    def _put_object(self, party, _tagged_key, value):
        self._get_other_object_table(party).put(_tagged_key, value)

    # noinspection PyProtectedMember
    def _get_object(self, _tagged_key):
        return self._federation_object_table.get(_tagged_key)

    # noinspection PyProtectedMember
    def _get_status(self, _tagged_key):
        return self._federation_status_table.get(_tagged_key)

    # noinspection PyUnusedLocal
    def remote(self, v, name: str, tag: str, parties: typing.List[Party]):
        log_str = f"federation.standalone.remote.{name}.{tag}"
        if v is None:
            raise ValueError(f"[{log_str}]remote `None` to {parties}")
        LOGGER.debug(f"[{log_str}]remote data, type={type(v)}")
        if isinstance(v, Table):
            dtype = FederationDataType.TABLE
            LOGGER.debug(
                f"[{log_str}]remote "
                f"Table(namespace={v.namespace}, name={v.name}, partitions={v.partitions}), dtype={dtype}"
            )
        else:
            v_splits, num_slice = _get_splits(v, self._max_message_size)
            if num_slice > 1:
                v = _create_table(
                    session=self._session,
                    name=str(uuid.uuid1()),
                    namespace=self._session_id,
                    partitions=1,
                    need_cleanup=True,
                    error_if_exist=False,
                )
                v.put_all(kv_list=v_splits)
                dtype = FederationDataType.SPLIT_OBJECT
                LOGGER.debug(
                    f"[{log_str}]remote "
                    f"Table(namespace={v.namespace}, name={v.name}, partitions={v.partitions}), dtype={dtype}"
                )
            else:
                LOGGER.debug(f"[{log_str}]remote object with type: {type(v)}")
                dtype = FederationDataType.OBJECT
        for party in parties:
            _tagged_key = self._federation_object_key(name, tag, self._party, party)
            if isinstance(v, Table):
                saved_name = str(uuid.uuid1())
                LOGGER.debug(
                    f"[{log_str}]save Table(namespace={v.namespace}, name={v.name}, partitions={v.partitions}) as "
                    f"Table(namespace={v.namespace}, name={saved_name}, partitions={v.partitions})"
                )
                _v = v.save_as(
                    name=saved_name, namespace=v.namespace, need_cleanup=False
                )
                self._put_status(party, _tagged_key, (_v.name, _v.namespace, dtype))
            else:
                self._put_object(party, _tagged_key, v)
                self._put_status(party, _tagged_key, _tagged_key)

    # noinspection PyProtectedMember
    def get(self, name: str, tag: str, parties: typing.List[Party]) -> typing.List:
        log_str = f"federation.standalone.get.{name}.{tag}"
        LOGGER.debug(f"[{log_str}]")
        tasks = []
        for party in parties:
            _tagged_key = self._federation_object_key(name, tag, party, self._party)
            tasks.append(_check_status_and_get_value(self._get_status, _tagged_key))
        results = self._loop.run_until_complete(asyncio.gather(*tasks))
        rtn = []
        for r in results:
            if isinstance(r, tuple):
                # noinspection PyTypeChecker
                table: Table = _load_table(
                    session=self._session, name=r[0], namespace=r[1], need_cleanup=True
                )
                dtype = r[2]
                LOGGER.debug(
                    f"[{log_str}] got "
                    f"Table(namespace={table.namespace}, name={table.name}, partitions={table.partitions}), dtype={dtype}"
                )
                if dtype == FederationDataType.SPLIT_OBJECT:
                    obj_bytes = b"".join(
                        map(lambda t: t[1], sorted(table.collect(), key=lambda x: x[0]))
                    )
                    obj = deserialize(obj_bytes)
                    rtn.append(obj)
                else:
                    rtn.append(table)
            else:
                obj = self._get_object(r)
                if obj is None:
                    raise EnvironmentError(
                        f"federation got None from {parties} with name {name}, tag {tag}"
                    )
                rtn.append(obj)
                self._federation_object_table.delete(k=r)
                LOGGER.debug(f"[{log_str}] got object with type: {type(obj)}")
            self._federation_status_table.delete(r)
        return rtn
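

# A process-wide meta table records "namespace.name" -> number of partitions
# for every table ever created, so that tables can be reopened later with the
# correct partition count.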
_meta_table: typing.Optional[Table] = None

_SESSION = Session(uuid.uuid1().hex)


def _get_meta_table():
    global _meta_table
    if _meta_table is None:
        _meta_table = Table(
            _SESSION,
            namespace="__META__",
            name="fragments",
            partitions=10,
            need_cleanup=False,
        )
    return _meta_table


# noinspection PyProtectedMember
def _get_from_meta_table(key):
    return _get_meta_table().get(key)


# noinspection PyProtectedMember
def _put_to_meta_table(key, value):
    _get_meta_table().put(key, value)


_data_dir = Path(file_utils.get_project_base_directory()).joinpath("data").absolute()


def _get_data_dir():
    return _data_dir


def _get_storage_dir(*args):
    return _data_dir.joinpath(*args)
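

# Polls `get_func` every 100 ms until the key appears; used by
# `Federation.get` to await values published by the sending party.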
async def _check_status_and_get_value(get_func, key):
    value = get_func(key)
    while value is None:
        await asyncio.sleep(0.1)
        value = get_func(key)
    LOGGER.debug(
        "[GET] Got {} type {}".format(
            key, "Table" if isinstance(value, tuple) else "Object"
        )
    )
    return value

def _create_table(
    session: "Session",
    name: str,
    namespace: str,
    partitions: int,
    need_cleanup=True,
    error_if_exist=False,
):
    if isinstance(namespace, int):
        raise ValueError(
            f"namespace should be a str, got int: namespace={namespace}, name={name}"
        )
    _table_key = ".".join([namespace, name])
    if _get_from_meta_table(_table_key) is not None:
        if error_if_exist:
            raise RuntimeError(
                f"table already exists: name={name}, namespace={namespace}"
            )
        else:
            partitions = _get_from_meta_table(_table_key)
    else:
        _put_to_meta_table(_table_key, partitions)
    return Table(
        session=session,
        namespace=namespace,
        name=name,
        partitions=partitions,
        need_cleanup=need_cleanup,
    )


def _exist(name: str, namespace: str):
    _table_key = ".".join([namespace, name])
    return _get_from_meta_table(_table_key) is not None


def _load_table(session, name, namespace, need_cleanup=False):
    _table_key = ".".join([namespace, name])
    partitions = _get_from_meta_table(_table_key)
    if partitions is None:
        raise RuntimeError(f"table does not exist: name={name}, namespace={namespace}")
    return Table(
        session=session,
        namespace=namespace,
        name=name,
        partitions=partitions,
        need_cleanup=need_cleanup,
    )

class _TaskInfo:
    def __init__(self, task_id, function_id, function_bytes):
        self.task_id = task_id
        self.function_id = function_id
        self.function_bytes = function_bytes
        self._function_deserialized = None

    def get_func(self):
        if self._function_deserialized is None:
            self._function_deserialized = f_pickle.loads(self.function_bytes)
        return self._function_deserialized


class _MapReduceTaskInfo:
    def __init__(self, task_id, function_id, map_function_bytes, reduce_function_bytes):
        self.task_id = task_id
        self.function_id = function_id
        self.map_function_bytes = map_function_bytes
        self.reduce_function_bytes = reduce_function_bytes
        self._reduce_function_deserialized = None
        self._mapper_function_deserialized = None

    def get_mapper(self):
        if self._mapper_function_deserialized is None:
            self._mapper_function_deserialized = f_pickle.loads(self.map_function_bytes)
        return self._mapper_function_deserialized

    def get_reducer(self):
        if self._reduce_function_deserialized is None:
            self._reduce_function_deserialized = f_pickle.loads(
                self.reduce_function_bytes
            )
        return self._reduce_function_deserialized


class _Operand:
    def __init__(self, namespace, name, partition):
        self.namespace = namespace
        self.name = name
        self.partition = partition

    def as_env(self, write=False):
        return _get_env(self.namespace, self.name, str(self.partition), write=write)


class _UnaryProcess:
    def __init__(self, task_info: _TaskInfo, operand: _Operand):
        self.info = task_info
        self.operand = operand

    def output_operand(self):
        return _Operand(
            self.info.task_id, self.info.function_id, self.operand.partition
        )

    def get_func(self):
        return self.info.get_func()


class _MapReduceProcess:
    def __init__(self, task_info: _MapReduceTaskInfo, operand: _Operand):
        self.info = task_info
        self.operand = operand

    def output_operand(self):
        return _Operand(
            self.info.task_id, self.info.function_id, self.operand.partition
        )

    def get_mapper(self):
        return self.info.get_mapper()

    def get_reducer(self):
        return self.info.get_reducer()


class _BinaryProcess:
    def __init__(self, task_info: _TaskInfo, left: _Operand, right: _Operand):
        self.info = task_info
        self.left = left
        self.right = right

    def output_operand(self):
        return _Operand(self.info.task_id, self.info.function_id, self.left.partition)

    def get_func(self):
        return self.info.get_func()

def _get_env(*args, write=False):
    _path = _get_storage_dir(*args)
    return _open_env(_path, write=write)


def _open_env(path, write=False):
    path.mkdir(parents=True, exist_ok=True)
    t = 0
    while t < 100:
        try:
            env = lmdb.open(
                path.as_posix(),
                create=True,
                max_dbs=1,
                max_readers=1024,
                lock=write,
                sync=True,
                map_size=10_737_418_240,
            )
            return env
        except lmdb.Error as e:
            if "No such file or directory" in e.args[0]:
                time.sleep(0.01)
                t += 1
            else:
                raise e
    raise lmdb.Error(f"No such file or directory: {path}, after {t} retries")
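

# Jump consistent hash (Lamping & Veach, 2014) over the SHA-1 of the
# serialized key: keys spread evenly across partitions, and growing the
# partition count relocates only about 1/partitions of the keys.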
def _hash_key_to_partition(key, partitions):
    _key = hashlib.sha1(key).digest()
    if isinstance(_key, bytes):
        _key = int.from_bytes(_key, byteorder="little", signed=False)
    if partitions < 1:
        raise ValueError("partitions must be a positive number")
    b, j = -1, 0
    while j < partitions:
        b = int(j)
        _key = ((_key * 2862933555777941757) + 1) & 0xFFFFFFFFFFFFFFFF
        j = float(b + 1) * (float(1 << 31) / float((_key >> 33) + 1))
    return int(b)
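

# Each `_do_*` function below runs inside a pool worker: it opens the source
# partition read-only, opens the output operand (keyed by task id and
# function id) for writing, and applies the deserialized user function.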
def _do_map(p: _UnaryProcess):
    rtn = p.output_operand()
    with ExitStack() as s:
        source_env = s.enter_context(p.operand.as_env())
        partitions = _get_from_meta_table(f"{p.operand.namespace}.{p.operand.name}")
        txn_map = {}
        for partition in range(partitions):
            env = s.enter_context(
                _get_env(rtn.namespace, rtn.name, str(partition), write=True)
            )
            txn_map[partition] = s.enter_context(env.begin(write=True))
        source_txn = s.enter_context(source_env.begin())
        cursor = s.enter_context(source_txn.cursor())
        for k_bytes, v_bytes in cursor:
            k, v = deserialize(k_bytes), deserialize(v_bytes)
            k1, v1 = p.get_func()(k, v)
            k1_bytes, v1_bytes = serialize(k1), serialize(v1)
            partition = _hash_key_to_partition(k1_bytes, partitions)
            txn_map[partition].put(k1_bytes, v1_bytes)
    return rtn


def _generator_from_cursor(cursor):
    for k, v in cursor:
        yield deserialize(k), deserialize(v)


def _do_apply_partitions(p: _UnaryProcess):
    with ExitStack() as s:
        rtn = p.output_operand()
        source_env = s.enter_context(p.operand.as_env())
        dst_env = s.enter_context(rtn.as_env(write=True))
        source_txn = s.enter_context(source_env.begin())
        dst_txn = s.enter_context(dst_env.begin(write=True))
        cursor = s.enter_context(source_txn.cursor())
        v = p.get_func()(_generator_from_cursor(cursor))
        if cursor.last():
            k_bytes = cursor.key()
            dst_txn.put(k_bytes, serialize(v))
        return rtn


def _do_map_partitions(p: _UnaryProcess):
    with ExitStack() as s:
        rtn = p.output_operand()
        source_env = s.enter_context(p.operand.as_env())
        dst_env = s.enter_context(rtn.as_env(write=True))
        source_txn = s.enter_context(source_env.begin())
        dst_txn = s.enter_context(dst_env.begin(write=True))
        cursor = s.enter_context(source_txn.cursor())
        v = p.get_func()(_generator_from_cursor(cursor))
        if isinstance(v, Iterable):
            for k1, v1 in v:
                dst_txn.put(serialize(k1), serialize(v1))
        else:
            k_bytes = cursor.key()
            dst_txn.put(k_bytes, serialize(v))
        return rtn


def _do_map_partitions_with_index(p: _UnaryProcess):
    with ExitStack() as s:
        rtn = p.output_operand()
        source_env = s.enter_context(p.operand.as_env())
        dst_env = s.enter_context(rtn.as_env(write=True))
        source_txn = s.enter_context(source_env.begin())
        dst_txn = s.enter_context(dst_env.begin(write=True))
        cursor = s.enter_context(source_txn.cursor())
        v = p.get_func()(p.operand.partition, _generator_from_cursor(cursor))
        if isinstance(v, Iterable):
            for k1, v1 in v:
                dst_txn.put(serialize(k1), serialize(v1))
        else:
            k_bytes = cursor.key()
            dst_txn.put(k_bytes, serialize(v))
        return rtn
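

# Combines within the partition while shuffling: mapped pairs are re-hashed to
# their destination partition and merged with any value already written there,
# so each output partition holds at most one value per key from this worker.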
def _do_map_reduce_in_partitions(p: _MapReduceProcess):
    rtn = p.output_operand()
    with ExitStack() as s:
        source_env = s.enter_context(p.operand.as_env())
        partitions = _get_from_meta_table(f"{p.operand.namespace}.{p.operand.name}")
        txn_map = {}
        for partition in range(partitions):
            env = s.enter_context(
                _get_env(rtn.namespace, rtn.name, str(partition), write=True)
            )
            txn_map[partition] = s.enter_context(env.begin(write=True))
        source_txn = s.enter_context(source_env.begin())
        cursor = s.enter_context(source_txn.cursor())
        mapped = p.get_mapper()(_generator_from_cursor(cursor))
        if not isinstance(mapped, Iterable):
            raise ValueError("mapper function should return an iterable of pairs")
        reducer = p.get_reducer()
        for k, v in mapped:
            k_bytes = serialize(k)
            partition = _hash_key_to_partition(k_bytes, partitions)
            # todo: not atomic, fix me
            pre_v = txn_map[partition].get(k_bytes, None)
            if pre_v is None:
                txn_map[partition].put(k_bytes, serialize(v))
            else:
                txn_map[partition].put(
                    k_bytes, serialize(reducer(deserialize(pre_v), v))
                )
    return rtn

def _do_map_values(p: _UnaryProcess):
    rtn = p.output_operand()
    with ExitStack() as s:
        source_env = s.enter_context(p.operand.as_env())
        dst_env = s.enter_context(rtn.as_env(write=True))
        source_txn = s.enter_context(source_env.begin())
        dst_txn = s.enter_context(dst_env.begin(write=True))
        cursor = s.enter_context(source_txn.cursor())
        for k_bytes, v_bytes in cursor:
            v = deserialize(v_bytes)
            v1 = p.get_func()(v)
            dst_txn.put(k_bytes, serialize(v1))
        return rtn


def _do_flat_map(p: _UnaryProcess):
    rtn = p.output_operand()
    with ExitStack() as s:
        source_env = s.enter_context(p.operand.as_env())
        dst_env = s.enter_context(rtn.as_env(write=True))
        source_txn = s.enter_context(source_env.begin())
        dst_txn = s.enter_context(dst_env.begin(write=True))
        cursor = s.enter_context(source_txn.cursor())
        for k_bytes, v_bytes in cursor:
            k = deserialize(k_bytes)
            v = deserialize(v_bytes)
            map_result = p.get_func()(k, v)
            for result_k, result_v in map_result:
                dst_txn.put(serialize(result_k), serialize(result_v))
        return rtn


def _do_reduce(p: _UnaryProcess):
    value = None
    with ExitStack() as s:
        source_env = s.enter_context(p.operand.as_env())
        source_txn = s.enter_context(source_env.begin())
        cursor = s.enter_context(source_txn.cursor())
        for k_bytes, v_bytes in cursor:
            v = deserialize(v_bytes)
            if value is None:
                value = v
            else:
                value = p.get_func()(value, v)
    return value


def _do_glom(p: _UnaryProcess):
    rtn = p.output_operand()
    with ExitStack() as s:
        source_env = s.enter_context(p.operand.as_env())
        dst_env = s.enter_context(rtn.as_env(write=True))
        source_txn = s.enter_context(source_env.begin())
        dest_txn = s.enter_context(dst_env.begin(write=True))
        cursor = s.enter_context(source_txn.cursor())
        v_list = []
        k_bytes = None
        for k, v in cursor:
            v_list.append((deserialize(k), deserialize(v)))
            k_bytes = k
        if k_bytes is not None:
            dest_txn.put(k_bytes, serialize(v_list))
    return rtn


def _do_sample(p: _UnaryProcess):
    rtn = p.output_operand()
    fraction, seed = deserialize(p.info.function_bytes)
    with ExitStack() as s:
        source_env = s.enter_context(p.operand.as_env())
        dst_env = s.enter_context(rtn.as_env(write=True))
        source_txn = s.enter_context(source_env.begin())
        dst_txn = s.enter_context(dst_env.begin(write=True))
        cursor = s.enter_context(source_txn.cursor())
        cursor.first()
        random_state = np.random.RandomState(seed)
        for k, v in cursor:
            # noinspection PyArgumentList
            if random_state.rand() < fraction:
                dst_txn.put(k, v)
    return rtn


def _do_filter(p: _UnaryProcess):
    rtn = p.output_operand()
    with ExitStack() as s:
        source_env = s.enter_context(p.operand.as_env())
        dst_env = s.enter_context(rtn.as_env(write=True))
        source_txn = s.enter_context(source_env.begin())
        dst_txn = s.enter_context(dst_env.begin(write=True))
        cursor = s.enter_context(source_txn.cursor())
        for k_bytes, v_bytes in cursor:
            k = deserialize(k_bytes)
            v = deserialize(v_bytes)
            if p.get_func()(k, v):
                dst_txn.put(k_bytes, v_bytes)
        return rtn


def _do_subtract_by_key(p: _BinaryProcess):
    rtn = p.output_operand()
    with ExitStack() as s:
        left_op = p.left
        right_op = p.right
        right_env = s.enter_context(right_op.as_env())
        left_env = s.enter_context(left_op.as_env())
        dst_env = s.enter_context(rtn.as_env(write=True))
        left_txn = s.enter_context(left_env.begin())
        right_txn = s.enter_context(right_env.begin())
        dst_txn = s.enter_context(dst_env.begin(write=True))
        cursor = s.enter_context(left_txn.cursor())
        for k_bytes, left_v_bytes in cursor:
            right_v_bytes = right_txn.get(k_bytes)
            if right_v_bytes is None:
                dst_txn.put(k_bytes, left_v_bytes)
    return rtn


def _do_join(p: _BinaryProcess):
    rtn = p.output_operand()
    with ExitStack() as s:
        right_env = s.enter_context(p.right.as_env())
        left_env = s.enter_context(p.left.as_env())
        dst_env = s.enter_context(rtn.as_env(write=True))
        left_txn = s.enter_context(left_env.begin())
        right_txn = s.enter_context(right_env.begin())
        dst_txn = s.enter_context(dst_env.begin(write=True))
        cursor = s.enter_context(left_txn.cursor())
        for k_bytes, v1_bytes in cursor:
            v2_bytes = right_txn.get(k_bytes)
            if v2_bytes is None:
                continue
            v1 = deserialize(v1_bytes)
            v2 = deserialize(v2_bytes)
            v3 = p.get_func()(v1, v2)
            dst_txn.put(k_bytes, serialize(v3))
    return rtn


def _do_union(p: _BinaryProcess):
    rtn = p.output_operand()
    with ExitStack() as s:
        left_env = s.enter_context(p.left.as_env())
        right_env = s.enter_context(p.right.as_env())
        dst_env = s.enter_context(rtn.as_env(write=True))
        left_txn = s.enter_context(left_env.begin())
        right_txn = s.enter_context(right_env.begin())
        dst_txn = s.enter_context(dst_env.begin(write=True))

        # process left op
        with left_txn.cursor() as left_cursor:
            for k_bytes, left_v_bytes in left_cursor:
                right_v_bytes = right_txn.get(k_bytes)
                if right_v_bytes is None:
                    dst_txn.put(k_bytes, left_v_bytes)
                else:
                    left_v = deserialize(left_v_bytes)
                    right_v = deserialize(right_v_bytes)
                    final_v = p.get_func()(left_v, right_v)
                    dst_txn.put(k_bytes, serialize(final_v))

        # process right op
        with right_txn.cursor() as right_cursor:
            for k_bytes, right_v_bytes in right_cursor:
                final_v_bytes = dst_txn.get(k_bytes)
                if final_v_bytes is None:
                    dst_txn.put(k_bytes, right_v_bytes)
    return rtn


def _kv_to_bytes(k, v):
    return serialize(k), serialize(v)


def _k_to_bytes(k):
    return serialize(k)
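

# A minimal usage sketch (illustrative only; everything here uses the public
# API above, but the session id and data are arbitrary):
#
#     session = Session(session_id=uuid.uuid1().hex)
#     table = session.parallelize(range(10), partition=2)
#     doubled = table.mapValues(lambda v: v * 2)
#     print(doubled.take(3))   # e.g. [(0, 0), (1, 2), (2, 4)]
#     session.stop()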