- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import asyncio
- import hashlib
- import itertools
- import pickle as c_pickle
- import shutil
- import time
- import typing
- import uuid
- from collections.abc import Iterable
- from concurrent.futures import ProcessPoolExecutor as Executor
- from contextlib import ExitStack
- from functools import partial
- from heapq import heapify, heappop, heapreplace
- from operator import is_not
- from pathlib import Path
- import cloudpickle as f_pickle
- import lmdb
- import numpy as np
- from fate_arch.common import Party, file_utils
- from fate_arch.common.log import getLogger
- from fate_arch.federation import FederationDataType
- LOGGER = getLogger()
- serialize = c_pickle.dumps
- deserialize = c_pickle.loads
- # default max message size in bytes (1 MiB)
- DEFAULT_MESSAGE_MAX_SIZE = 1048576
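- # e.g. a 2.5 MiB object serializes into ceil(2.5 MiB / 1 MiB) = 3 slices in
- # _get_splits below; an object that fits in a single slice is sent unsplit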
- # noinspection PyPep8Naming
- class Table(object):
- def __init__(
- self,
- session: "Session",
- namespace: str,
- name: str,
- partitions,
- need_cleanup=True,
- ):
- self._need_cleanup = need_cleanup
- self._namespace = namespace
- self._name = name
- self._partitions = partitions
- self._session = session
- @property
- def partitions(self):
- return self._partitions
- @property
- def name(self):
- return self._name
- @property
- def namespace(self):
- return self._namespace
- def __del__(self):
- if self._need_cleanup:
- self.destroy()
- def __str__(self):
- return f"<Table {self._namespace}|{self._name}|{self._partitions}|{self._need_cleanup}>"
- def __repr__(self):
- return self.__str__()
- def destroy(self):
- for p in range(self._partitions):
- with self._get_env_for_partition(p, write=True) as env:
- db = env.open_db()
- with env.begin(write=True) as txn:
- txn.drop(db)
- table_key = f"{self._namespace}.{self._name}"
- _get_meta_table().delete(table_key)
- path = _get_storage_dir(self._namespace, self._name)
- shutil.rmtree(path, ignore_errors=True)
- def take(self, n, **kwargs):
- if n <= 0:
- raise ValueError(f"take() expects a positive n, got {n}")
- return list(itertools.islice(self.collect(**kwargs), n))
- def count(self):
- cnt = 0
- for p in range(self._partitions):
- with self._get_env_for_partition(p) as env:
- cnt += env.stat()["entries"]
- return cnt
- # noinspection PyUnusedLocal
- def collect(self, **kwargs):
- iterators = []
- with ExitStack() as s:
- for p in range(self._partitions):
- env = s.enter_context(self._get_env_for_partition(p))
- txn = s.enter_context(env.begin())
- iterators.append(s.enter_context(txn.cursor()))
- # k-way merge: each LMDB cursor already yields keys in sorted order,
- # so a min-heap over the partition cursors streams globally sorted pairs
- entries = []
- for _id, it in enumerate(iterators):
- if it.next():
- key, value = it.item()
- entries.append([key, value, _id, it])
- heapify(entries)
- while entries:
- key, value, _, it = entry = entries[0]
- yield deserialize(key), deserialize(value)
- if it.next():
- entry[0], entry[1] = it.item()
- heapreplace(entries, entry)
- else:
- _, _, _, it = heappop(entries)
- def reduce(self, func):
- # noinspection PyProtectedMember
- rs = self._session._submit_unary(
- func, _do_reduce, self._partitions, self._name, self._namespace
- )
- rs = [r for r in filter(partial(is_not, None), rs)]
- if not rs:
- return None
- rtn = rs[0]
- for r in rs[1:]:
- rtn = func(rtn, r)
- return rtn
- def map(self, func):
- return self._unary(func, _do_map)
- def mapValues(self, func):
- return self._unary(func, _do_map_values)
- def flatMap(self, func):
- _flat_mapped = self._unary(func, _do_flat_map)
- return _flat_mapped.save_as(
- name=str(uuid.uuid1()),
- namespace=_flat_mapped.namespace,
- partition=self._partitions,
- need_cleanup=True,
- )
- def applyPartitions(self, func):
- return self._unary(func, _do_apply_partitions)
- def mapPartitions(self, func, preserves_partitioning=False):
- un_shuffled = self._unary(func, _do_map_partitions)
- if preserves_partitioning:
- return un_shuffled
- return un_shuffled.save_as(
- name=str(uuid.uuid1()),
- namespace=un_shuffled.namespace,
- partition=self._partitions,
- need_cleanup=True,
- )
- def mapPartitionsWithIndex(self, func, preserves_partitioning=False):
- un_shuffled = self._unary(func, _do_map_partitions_with_index)
- if preserves_partitioning:
- return un_shuffled
- return un_shuffled.save_as(
- name=str(uuid.uuid1()),
- namespace=un_shuffled.namespace,
- partition=self._partitions,
- need_cleanup=True,
- )
- def mapReducePartitions(self, mapper, reducer):
- dup = _create_table(
- self._session,
- str(uuid.uuid1()),
- self.namespace,
- self._partitions,
- need_cleanup=True,
- )
- def _dict_reduce(a: dict, b: dict):
- for k, v in b.items():
- if k not in a:
- a[k] = v
- else:
- a[k] = reducer(a[k], v)
- return a
- def _local_map_reduce(it):
- ret = {}
- for _k, _v in mapper(it):
- if _k not in ret:
- ret[_k] = _v
- else:
- ret[_k] = reducer(ret[_k], _v)
- return ret
- dup.put_all(
- self.applyPartitions(_local_map_reduce).reduce(_dict_reduce).items()
- )
- return dup
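- # Illustrative only: a word-count style use of mapReducePartitions, assuming
- # a table of (doc_id, text) pairs; `mapper` receives the partition's
- # key/value iterator and `reducer` merges values that share a key:
- #
- #   counts = table.mapReducePartitions(
- #       mapper=lambda it: ((w, 1) for _, text in it for w in text.split()),
- #       reducer=lambda a, b: a + b,
- #   )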
- def glom(self):
- return self._unary(None, _do_glom)
- def sample(self, fraction, seed=None):
- return self._unary((fraction, seed), _do_sample)
- def filter(self, func):
- return self._unary(func, _do_filter)
- def join(self, other: "Table", func):
- return self._binary(other, func, _do_join)
- def subtractByKey(self, other: "Table"):
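- # note: _do_subtract_by_key never calls the submitted function, so the
- # string below serves only as a descriptive tag for the task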
- func = f"{self._namespace}.{self._name}-{other._namespace}.{other._name}"
- return self._binary(other, func, _do_subtract_by_key)
- def union(self, other: "Table", func=lambda v1, v2: v1):
- return self._binary(other, func, _do_union)
- # noinspection PyProtectedMember
- def _map_reduce(self, mapper, reducer):
- results = self._session._submit_map_reduce_in_partition(
- mapper, reducer, self._partitions, self._name, self._namespace
- )
- result = results[0]
- # noinspection PyProtectedMember
- return _create_table(
- session=self._session,
- name=result.name,
- namespace=result.namespace,
- partitions=self._partitions,
- )
- def _unary(self, func, do_func):
- # noinspection PyProtectedMember
- results = self._session._submit_unary(
- func, do_func, self._partitions, self._name, self._namespace
- )
- result = results[0]
- # noinspection PyProtectedMember
- return _create_table(
- session=self._session,
- name=result.name,
- namespace=result.namespace,
- partitions=self._partitions,
- )
- def _binary(self, other: "Table", func, do_func):
- session_id = self._session.session_id
- left, right = self, other
- if left._partitions != right._partitions:
- if other.count() > self.count():
- left = left.save_as(
- str(uuid.uuid1()), session_id, partition=right._partitions
- )
- else:
- right = other.save_as(
- str(uuid.uuid1()), session_id, partition=left._partitions
- )
- # noinspection PyProtectedMember
- results = self._session._submit_binary(
- func,
- do_func,
- left._partitions,
- left._name,
- left._namespace,
- right._name,
- right._namespace,
- )
- result: _Operand = results[0]
- # noinspection PyProtectedMember
- return _create_table(
- session=self._session,
- name=result.name,
- namespace=result.namespace,
- partitions=left._partitions,
- )
- def save_as(self, name, namespace, partition=None, need_cleanup=True):
- if partition is None:
- partition = self._partitions
- # noinspection PyProtectedMember
- dup = _create_table(self._session, name, namespace, partition, need_cleanup)
- dup.put_all(self.collect())
- return dup
- def _get_env_for_partition(self, p: int, write=False):
- return _get_env(self._namespace, self._name, str(p), write=write)
- def put(self, k, v):
- k_bytes, v_bytes = _kv_to_bytes(k=k, v=v)
- p = _hash_key_to_partition(k_bytes, self._partitions)
- with self._get_env_for_partition(p, write=True) as env:
- with env.begin(write=True) as txn:
- return txn.put(k_bytes, v_bytes)
- def put_all(self, kv_list: Iterable):
- txn_map = {}
- is_success = True
- with ExitStack() as s:
- for p in range(self._partitions):
- env = s.enter_context(self._get_env_for_partition(p, write=True))
- txn_map[p] = env, env.begin(write=True)
- for k, v in kv_list:
- try:
- k_bytes, v_bytes = _kv_to_bytes(k=k, v=v)
- p = _hash_key_to_partition(k_bytes, self._partitions)
- is_success = is_success and txn_map[p][1].put(k_bytes, v_bytes)
- except Exception as e:
- is_success = False
- LOGGER.exception(f"put_all for k={k} v={v} fail. exception: {e}")
- break
- for p, (env, txn) in txn_map.items():
- txn.commit() if is_success else txn.abort()
- def get(self, k):
- k_bytes = _k_to_bytes(k=k)
- p = _hash_key_to_partition(k_bytes, self._partitions)
- with self._get_env_for_partition(p) as env:
- with env.begin() as txn:
- old_value_bytes = txn.get(k_bytes)
- return (
- None if old_value_bytes is None else deserialize(old_value_bytes)
- )
- def delete(self, k):
- k_bytes = _k_to_bytes(k=k)
- p = _hash_key_to_partition(k_bytes, self._partitions)
- with self._get_env_for_partition(p, write=True) as env:
- with env.begin(write=True) as txn:
- old_value_bytes = txn.get(k_bytes)
- if txn.delete(k_bytes):
- return (
- None
- if old_value_bytes is None
- else deserialize(old_value_bytes)
- )
- return None
- # noinspection PyMethodMayBeStatic
- class Session(object):
- def __init__(self, session_id, max_workers=None):
- self.session_id = session_id
- self._pool = Executor(max_workers=max_workers)
- def __getstate__(self):
- # session won't be pickled
- pass
- def load(self, name, namespace):
- return _load_table(session=self, name=name, namespace=namespace)
- def create_table(self, name, namespace, partitions, need_cleanup, error_if_exist):
- return _create_table(
- session=self,
- name=name,
- namespace=namespace,
- partitions=partitions,
- need_cleanup=need_cleanup,
- error_if_exist=error_if_exist,
- )
- # noinspection PyUnusedLocal
- def parallelize(
- self, data: Iterable, partition: int, include_key: bool = False, **kwargs
- ):
- if not include_key:
- data = enumerate(data)
- table = _create_table(
- session=self,
- name=str(uuid.uuid1()),
- namespace=self.session_id,
- partitions=partition,
- )
- table.put_all(data)
- return table
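- # A minimal hedged sketch of parallelize (names illustrative):
- #
- #   session = Session(str(uuid.uuid1()))
- #   t = session.parallelize(range(10), partition=2)  # keys auto-assigned
- #   assert t.count() == 10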
- def cleanup(self, name, namespace):
- data_path = _get_data_dir()
- if not data_path.is_dir():
- LOGGER.error(f"illegal data dir: {data_path}")
- return
- namespace_dir = data_path.joinpath(namespace)
- if not namespace_dir.is_dir():
- return
- if name == "*":
- shutil.rmtree(namespace_dir, True)
- return
- for table in namespace_dir.glob(name):
- shutil.rmtree(table, True)
- def stop(self):
- self.cleanup(name="*", namespace=self.session_id)
- self._pool.shutdown()
- def kill(self):
- self.cleanup(name="*", namespace=self.session_id)
- self._pool.shutdown()
- def _submit_unary(self, func, _do_func, partitions, name, namespace):
- task_info = _TaskInfo(
- self.session_id,
- function_id=str(uuid.uuid1()),
- function_bytes=f_pickle.dumps(func),
- )
- futures = []
- for p in range(partitions):
- futures.append(
- self._pool.submit(
- _do_func, _UnaryProcess(task_info, _Operand(namespace, name, p))
- )
- )
- results = [r.result() for r in futures]
- return results
- def _submit_map_reduce_in_partition(
- self, mapper, reducer, partitions, name, namespace
- ):
- task_info = _MapReduceTaskInfo(
- self.session_id,
- function_id=str(uuid.uuid1()),
- map_function_bytes=f_pickle.dumps(mapper),
- reduce_function_bytes=f_pickle.dumps(reducer),
- )
- futures = []
- for p in range(partitions):
- futures.append(
- self._pool.submit(
- _do_map_reduce_in_partitions,
- _MapReduceProcess(task_info, _Operand(namespace, name, p)),
- )
- )
- results = [r.result() for r in futures]
- return results
- def _submit_binary(
- self, func, do_func, partitions, name, namespace, other_name, other_namespace
- ):
- task_info = _TaskInfo(
- self.session_id,
- function_id=str(uuid.uuid1()),
- function_bytes=f_pickle.dumps(func),
- )
- futures = []
- for p in range(partitions):
- left = _Operand(namespace, name, p)
- right = _Operand(other_namespace, other_name, p)
- futures.append(
- self._pool.submit(do_func, _BinaryProcess(task_info, left, right))
- )
- results = [r.result() for r in futures]
- return results
- def _get_splits(obj, max_message_size):
- obj_bytes = serialize(obj, protocol=4)
- byte_size = len(obj_bytes)
- num_slice = (byte_size - 1) // max_message_size + 1
- if num_slice <= 1:
- return obj, num_slice
- else:
- _max_size = max_message_size
- kv = [(i, obj_bytes[slice(i * _max_size, (i + 1) * _max_size)]) for i in range(num_slice)]
- return kv, num_slice
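- # Receiver-side reassembly (mirrors the SPLIT_OBJECT branch in
- # Federation.get below); `splits` is assumed to be the (index, bytes)
- # pairs produced above:
- #
- #   obj_bytes = b"".join(b for _, b in sorted(splits, key=lambda t: t[0]))
- #   obj = deserialize(obj_bytes)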
- class Federation(object):
- def _federation_object_key(self, name, tag, s_party, d_party):
- return f"{self._session_id}-{name}-{tag}-{s_party.role}-{s_party.party_id}-{d_party.role}-{d_party.party_id}"
- def __init__(self, session: Session, session_id, party: Party):
- self._session_id = session_id
- self._party: Party = party
- self._session = session
- self._max_message_size = DEFAULT_MESSAGE_MAX_SIZE
- self._other_status_tables = {}
- self._other_object_tables = {}
- self._event_loop = None
- self._federation_status_table_cache = None
- self._federation_object_table_cache = None
- def destroy(self):
- self._session.cleanup(namespace=self._session_id, name="*")
- @property
- def _federation_status_table(self):
- if self._federation_status_table_cache is None:
- self._federation_status_table_cache = _create_table(
- session=self._session,
- name=self._get_status_table_name(self._party),
- namespace=self._session_id,
- partitions=1,
- need_cleanup=True,
- error_if_exist=False,
- )
- return self._federation_status_table_cache
- @property
- def _federation_object_table(self):
- if self._federation_object_table_cache is None:
- self._federation_object_table_cache = _create_table(
- session=self._session,
- name=self._get_object_table_name(self._party),
- namespace=self._session_id,
- partitions=1,
- need_cleanup=True,
- error_if_exist=False,
- )
- return self._federation_object_table_cache
- @property
- def _loop(self):
- if self._event_loop is None:
- self._event_loop = asyncio.get_event_loop()
- return self._event_loop
- @staticmethod
- def _get_status_table_name(party):
- return f"__federation_status__.{party.role}_{party.party_id}"
- @staticmethod
- def _get_object_table_name(party):
- return f"__federation_object__.{party.role}_{party.party_id}"
- def _get_other_status_table(self, party):
- if party in self._other_status_tables:
- return self._other_status_tables[party]
- table = _create_table(
- self._session,
- name=self._get_status_table_name(party),
- namespace=self._session_id,
- partitions=1,
- need_cleanup=False,
- error_if_exist=False,
- )
- self._other_status_tables[party] = table
- return table
- def _get_other_object_table(self, party):
- if party in self._other_object_tables:
- return self._other_object_tables[party]
- table = _create_table(
- self._session,
- name=self._get_object_table_name(party),
- namespace=self._session_id,
- partitions=1,
- need_cleanup=False,
- error_if_exist=False,
- )
- self._other_object_tables[party] = table
- return table
- # noinspection PyProtectedMember
- def _put_status(self, party, _tagged_key, value):
- self._get_other_status_table(party).put(_tagged_key, value)
- # noinspection PyProtectedMember
- def _put_object(self, party, _tagged_key, value):
- self._get_other_object_table(party).put(_tagged_key, value)
- # noinspection PyProtectedMember
- def _get_object(self, _tagged_key):
- return self._federation_object_table.get(_tagged_key)
- # noinspection PyProtectedMember
- def _get_status(self, _tagged_key):
- return self._federation_status_table.get(_tagged_key)
- # noinspection PyUnusedLocal
- def remote(self, v, name: str, tag: str, parties: typing.List[Party]):
- log_str = f"federation.standalone.remote.{name}.{tag}"
- if v is None:
- raise ValueError(f"[{log_str}]remote `None` to {parties}")
- LOGGER.debug(f"[{log_str}]remote data, type={type(v)}")
- if isinstance(v, Table):
- dtype = FederationDataType.TABLE
- LOGGER.debug(
- f"[{log_str}]remote "
- f"Table(namespace={v.namespace}, name={v.name}, partitions={v.partitions}), dtype={dtype}"
- )
- else:
- v_splits, num_slice = _get_splits(v, self._max_message_size)
- if num_slice > 1:
- v = _create_table(
- session=self._session,
- name=str(uuid.uuid1()),
- namespace=self._session_id,
- partitions=1,
- need_cleanup=True,
- error_if_exist=False,
- )
- v.put_all(kv_list=v_splits)
- dtype = FederationDataType.SPLIT_OBJECT
- LOGGER.debug(
- f"[{log_str}]remote "
- f"Table(namespace={v.namespace}, name={v.name}, partitions={v.partitions}), dtype={dtype}"
- )
- else:
- LOGGER.debug(f"[{log_str}]remote object with type: {type(v)}")
- dtype = FederationDataType.OBJECT
- for party in parties:
- _tagged_key = self._federation_object_key(name, tag, self._party, party)
- if isinstance(v, Table):
- saved_name = str(uuid.uuid1())
- LOGGER.debug(
- f"[{log_str}]save Table(namespace={v.namespace}, name={v.name}, partitions={v.partitions}) as "
- f"Table(namespace={v.namespace}, name={saved_name}, partitions={v.partitions})"
- )
- _v = v.save_as(
- name=saved_name, namespace=v.namespace, need_cleanup=False
- )
- self._put_status(party, _tagged_key, (_v.name, _v.namespace, dtype))
- else:
- self._put_object(party, _tagged_key, v)
- self._put_status(party, _tagged_key, _tagged_key)
- # noinspection PyProtectedMember
- def get(self, name: str, tag: str, parties: typing.List[Party]) -> typing.List:
- log_str = f"federation.standalone.get.{name}.{tag}"
- LOGGER.debug(f"[{log_str}]")
- tasks = []
- for party in parties:
- _tagged_key = self._federation_object_key(name, tag, party, self._party)
- tasks.append(_check_status_and_get_value(self._get_status, _tagged_key))
- results = self._loop.run_until_complete(asyncio.gather(*tasks))
- rtn = []
- for r in results:
- if isinstance(r, tuple):
- # noinspection PyTypeChecker
- table: Table = _load_table(
- session=self._session, name=r[0], namespace=r[1], need_cleanup=True
- )
- dtype = r[2]
- LOGGER.debug(
- f"[{log_str}] got "
- f"Table(namespace={table.namespace}, name={table.name}, partitions={table.partitions}), dtype={dtype}")
- if dtype == FederationDataType.SPLIT_OBJECT:
- obj_bytes = b''.join(map(lambda t: t[1], sorted(table.collect(), key=lambda x: x[0])))
- obj = deserialize(obj_bytes)
- rtn.append(obj)
- else:
- rtn.append(table)
- else:
- obj = self._get_object(r)
- if obj is None:
- raise EnvironmentError(
- f"federation get None from {parties} with name {name}, tag {tag}"
- )
- rtn.append(obj)
- self._federation_object_table.delete(k=r)
- LOGGER.debug(f"[{log_str}] got object with type: {type(obj)}")
- self._federation_status_table.delete(r)
- return rtn
- _meta_table: typing.Optional[Table] = None
- _SESSION = Session(uuid.uuid1().hex)
- def _get_meta_table():
- global _meta_table
- if _meta_table is None:
- _meta_table = Table(
- _SESSION,
- namespace="__META__",
- name="fragments",
- partitions=10,
- need_cleanup=False,
- )
- return _meta_table
- # noinspection PyProtectedMember
- def _get_from_meta_table(key):
- return _get_meta_table().get(key)
- # noinspection PyProtectedMember
- def _put_to_meta_table(key, value):
- _get_meta_table().put(key, value)
- _data_dir = Path(file_utils.get_project_base_directory()).joinpath("data").absolute()
- def _get_data_dir():
- return _data_dir
- def _get_storage_dir(*args):
- return _data_dir.joinpath(*args)
- async def _check_status_and_get_value(get_func, key):
- value = get_func(key)
- while value is None:
- await asyncio.sleep(0.1)
- value = get_func(key)
- LOGGER.debug(
- "[GET] Got {} type {}".format(
- key, "Table" if isinstance(value, tuple) else "Object"
- )
- )
- return value
- def _create_table(
- session: "Session",
- name: str,
- namespace: str,
- partitions: int,
- need_cleanup=True,
- error_if_exist=False,
- ):
- if isinstance(namespace, int):
- raise ValueError(f"namespace must be a str, got int: namespace={namespace}, name={name}")
- _table_key = ".".join([namespace, name])
- if _get_from_meta_table(_table_key) is not None:
- if error_if_exist:
- raise RuntimeError(
- f"table already exist: name={name}, namespace={namespace}"
- )
- else:
- partitions = _get_from_meta_table(_table_key)
- else:
- _put_to_meta_table(_table_key, partitions)
- return Table(
- session=session,
- namespace=namespace,
- name=name,
- partitions=partitions,
- need_cleanup=need_cleanup,
- )
- def _exist(name: str, namespace: str):
- _table_key = ".".join([namespace, name])
- return _get_from_meta_table(_table_key) is not None
- def _load_table(session, name, namespace, need_cleanup=False):
- _table_key = ".".join([namespace, name])
- partitions = _get_from_meta_table(_table_key)
- if partitions is None:
- raise RuntimeError(f"table not exist: name={name}, namespace={namespace}")
- return Table(
- session=session,
- namespace=namespace,
- name=name,
- partitions=partitions,
- need_cleanup=need_cleanup,
- )
- class _TaskInfo:
- def __init__(self, task_id, function_id, function_bytes):
- self.task_id = task_id
- self.function_id = function_id
- self.function_bytes = function_bytes
- self._function_deserialized = None
- def get_func(self):
- if self._function_deserialized is None:
- self._function_deserialized = f_pickle.loads(self.function_bytes)
- return self._function_deserialized
- class _MapReduceTaskInfo:
- def __init__(self, task_id, function_id, map_function_bytes, reduce_function_bytes):
- self.task_id = task_id
- self.function_id = function_id
- self.map_function_bytes = map_function_bytes
- self.reduce_function_bytes = reduce_function_bytes
- self._reduce_function_deserialized = None
- self._mapper_function_deserialized = None
- def get_mapper(self):
- if self._mapper_function_deserialized is None:
- self._mapper_function_deserialized = f_pickle.loads(self.map_function_bytes)
- return self._mapper_function_deserialized
- def get_reducer(self):
- if self._reduce_function_deserialized is None:
- self._reduce_function_deserialized = f_pickle.loads(
- self.reduce_function_bytes
- )
- return self._reduce_function_deserialized
- class _Operand:
- def __init__(self, namespace, name, partition):
- self.namespace = namespace
- self.name = name
- self.partition = partition
- def as_env(self, write=False):
- return _get_env(self.namespace, self.name, str(self.partition), write=write)
- class _UnaryProcess:
- def __init__(self, task_info: _TaskInfo, operand: _Operand):
- self.info = task_info
- self.operand = operand
- def output_operand(self):
- return _Operand(
- self.info.task_id, self.info.function_id, self.operand.partition
- )
- def get_func(self):
- return self.info.get_func()
- class _MapReduceProcess:
- def __init__(self, task_info: _MapReduceTaskInfo, operand: _Operand):
- self.info = task_info
- self.operand = operand
- def output_operand(self):
- return _Operand(
- self.info.task_id, self.info.function_id, self.operand.partition
- )
- def get_mapper(self):
- return self.info.get_mapper()
- def get_reducer(self):
- return self.info.get_reducer()
- class _BinaryProcess:
- def __init__(self, task_info: _TaskInfo, left: _Operand, right: _Operand):
- self.info = task_info
- self.left = left
- self.right = right
- def output_operand(self):
- return _Operand(self.info.task_id, self.info.function_id, self.left.partition)
- def get_func(self):
- return self.info.get_func()
- def _get_env(*args, write=False):
- _path = _get_storage_dir(*args)
- return _open_env(_path, write=write)
- def _open_env(path, write=False):
- path.mkdir(parents=True, exist_ok=True)
- t = 0
- while t < 100:
- try:
- env = lmdb.open(
- path.as_posix(),
- create=True,
- max_dbs=1,
- max_readers=1024,
- lock=write,
- sync=True,
- map_size=10_737_418_240,
- )
- return env
- except lmdb.Error as e:
- if "No such file or directory" in e.args[0]:
- time.sleep(0.01)
- t += 1
- else:
- raise e
- raise lmdb.Error(f"No such file or directory: {path}, after {t} retries")
- def _hash_key_to_partition(key, partitions):
- _key = hashlib.sha1(key).digest()
- if isinstance(_key, bytes):
- _key = int.from_bytes(_key, byteorder="little", signed=False)
- if partitions < 1:
- raise ValueError("partitions must be a positive number")
- b, j = -1, 0
- while j < partitions:
- b = int(j)
- _key = ((_key * 2862933555777941757) + 1) & 0xFFFFFFFFFFFFFFFF
- j = float(b + 1) * (float(1 << 31) / float((_key >> 33) + 1))
- return int(b)
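- # The loop above is the jump consistent hash construction (Lamping & Veach,
- # 2014): it maps the 64-bit key digest to a bucket in [0, partitions) and
- # keeps most keys in place when the partition count grows. Illustrative
- # sanity checks (not part of the library):
- #
- #   p = _hash_key_to_partition(serialize("k"), 8)
- #   assert 0 <= p < 8
- #   assert p == _hash_key_to_partition(serialize("k"), 8)  # deterministic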
- def _do_map(p: _UnaryProcess):
- rtn = p.output_operand()
- with ExitStack() as s:
- source_env = s.enter_context(p.operand.as_env())
- partitions = _get_from_meta_table(f"{p.operand.namespace}.{p.operand.name}")
- txn_map = {}
- for partition in range(partitions):
- env = s.enter_context(
- _get_env(rtn.namespace, rtn.name, str(partition), write=True)
- )
- txn_map[partition] = s.enter_context(env.begin(write=True))
- source_txn = s.enter_context(source_env.begin())
- cursor = s.enter_context(source_txn.cursor())
- for k_bytes, v_bytes in cursor:
- k, v = deserialize(k_bytes), deserialize(v_bytes)
- k1, v1 = p.get_func()(k, v)
- k1_bytes, v1_bytes = serialize(k1), serialize(v1)
- partition = _hash_key_to_partition(k1_bytes, partitions)
- txn_map[partition].put(k1_bytes, v1_bytes)
- return rtn
- def _generator_from_cursor(cursor):
- for k, v in cursor:
- yield deserialize(k), deserialize(v)
- def _do_apply_partitions(p: _UnaryProcess):
- with ExitStack() as s:
- rtn = p.output_operand()
- source_env = s.enter_context(p.operand.as_env())
- dst_env = s.enter_context(rtn.as_env(write=True))
- source_txn = s.enter_context(source_env.begin())
- dst_txn = s.enter_context(dst_env.begin(write=True))
- cursor = s.enter_context(source_txn.cursor())
- v = p.get_func()(_generator_from_cursor(cursor))
- if cursor.last():
- k_bytes = cursor.key()
- dst_txn.put(k_bytes, serialize(v))
- return rtn
- def _do_map_partitions(p: _UnaryProcess):
- with ExitStack() as s:
- rtn = p.output_operand()
- source_env = s.enter_context(p.operand.as_env())
- dst_env = s.enter_context(rtn.as_env(write=True))
- source_txn = s.enter_context(source_env.begin())
- dst_txn = s.enter_context(dst_env.begin(write=True))
- cursor = s.enter_context(source_txn.cursor())
- v = p.get_func()(_generator_from_cursor(cursor))
- if isinstance(v, Iterable):
- for k1, v1 in v:
- dst_txn.put(serialize(k1), serialize(v1))
- else:
- k_bytes = cursor.key()
- dst_txn.put(k_bytes, serialize(v))
- return rtn
- def _do_map_partitions_with_index(p: _UnaryProcess):
- with ExitStack() as s:
- rtn = p.output_operand()
- source_env = s.enter_context(p.operand.as_env())
- dst_env = s.enter_context(rtn.as_env(write=True))
- source_txn = s.enter_context(source_env.begin())
- dst_txn = s.enter_context(dst_env.begin(write=True))
- cursor = s.enter_context(source_txn.cursor())
- v = p.get_func()(p.operand.partition, _generator_from_cursor(cursor))
- if isinstance(v, Iterable):
- for k1, v1 in v:
- dst_txn.put(serialize(k1), serialize(v1))
- else:
- k_bytes = cursor.key()
- dst_txn.put(k_bytes, serialize(v))
- return rtn
- def _do_map_reduce_in_partitions(p: _MapReduceProcess):
- rtn = p.output_operand()
- with ExitStack() as s:
- source_env = s.enter_context(p.operand.as_env())
- partitions = _get_from_meta_table(f"{p.operand.namespace}.{p.operand.name}")
- txn_map = {}
- for partition in range(partitions):
- env = s.enter_context(
- _get_env(rtn.namespace, rtn.name, str(partition), write=True)
- )
- txn_map[partition] = s.enter_context(env.begin(write=True))
- source_txn = s.enter_context(source_env.begin())
- cursor = s.enter_context(source_txn.cursor())
- mapped = p.get_mapper()(_generator_from_cursor(cursor))
- if not isinstance(mapped, Iterable):
- raise ValueError("mapper function must return an iterable of key/value pairs")
- reducer = p.get_reducer()
- for k, v in mapped:
- k_bytes = serialize(k)
- partition = _hash_key_to_partition(k_bytes, partitions)
- # todo: not atomic, fix me
- pre_v = txn_map[partition].get(k_bytes, None)
- if pre_v is None:
- txn_map[partition].put(k_bytes, serialize(v))
- else:
- txn_map[partition].put(
- k_bytes, serialize(reducer(deserialize(pre_v), v))
- )
- return rtn
- def _do_map_values(p: _UnaryProcess):
- rtn = p.output_operand()
- with ExitStack() as s:
- source_env = s.enter_context(p.operand.as_env())
- dst_env = s.enter_context(rtn.as_env(write=True))
- source_txn = s.enter_context(source_env.begin())
- dst_txn = s.enter_context(dst_env.begin(write=True))
- cursor = s.enter_context(source_txn.cursor())
- for k_bytes, v_bytes in cursor:
- v = deserialize(v_bytes)
- v1 = p.get_func()(v)
- dst_txn.put(k_bytes, serialize(v1))
- return rtn
- def _do_flat_map(p: _UnaryProcess):
- rtn = p.output_operand()
- with ExitStack() as s:
- source_env = s.enter_context(p.operand.as_env())
- dst_env = s.enter_context(rtn.as_env(write=True))
- source_txn = s.enter_context(source_env.begin())
- dst_txn = s.enter_context(dst_env.begin(write=True))
- cursor = s.enter_context(source_txn.cursor())
- for k_bytes, v_bytes in cursor:
- k = deserialize(k_bytes)
- v = deserialize(v_bytes)
- map_result = p.get_func()(k, v)
- for result_k, result_v in map_result:
- dst_txn.put(serialize(result_k), serialize(result_v))
- return rtn
- def _do_reduce(p: _UnaryProcess):
- value = None
- with ExitStack() as s:
- source_env = s.enter_context(p.operand.as_env())
- source_txn = s.enter_context(source_env.begin())
- cursor = s.enter_context(source_txn.cursor())
- for k_bytes, v_bytes in cursor:
- v = deserialize(v_bytes)
- if value is None:
- value = v
- else:
- value = p.get_func()(value, v)
- return value
- def _do_glom(p: _UnaryProcess):
- rtn = p.output_operand()
- with ExitStack() as s:
- source_env = s.enter_context(p.operand.as_env())
- dst_env = s.enter_context(rtn.as_env(write=True))
- source_txn = s.enter_context(source_env.begin())
- dest_txn = s.enter_context(dst_env.begin(write=True))
- cursor = s.enter_context(source_txn.cursor())
- v_list = []
- k_bytes = None
- for k, v in cursor:
- v_list.append((deserialize(k), deserialize(v)))
- k_bytes = k
- if k_bytes is not None:
- dest_txn.put(k_bytes, serialize(v_list))
- return rtn
- def _do_sample(p: _UnaryProcess):
- rtn = p.output_operand()
- fraction, seed = deserialize(p.info.function_bytes)
- with ExitStack() as s:
- source_env = s.enter_context(p.operand.as_env())
- dst_env = s.enter_context(rtn.as_env(write=True))
- source_txn = s.enter_context(source_env.begin())
- dst_txn = s.enter_context(dst_env.begin(write=True))
- cursor = s.enter_context(source_txn.cursor())
- cursor.first()
- random_state = np.random.RandomState(seed)
- for k, v in cursor:
- # noinspection PyArgumentList
- if random_state.rand() < fraction:
- dst_txn.put(k, v)
- return rtn
- def _do_filter(p: _UnaryProcess):
- rtn = p.output_operand()
- with ExitStack() as s:
- source_env = s.enter_context(p.operand.as_env())
- dst_env = s.enter_context(rtn.as_env(write=True))
- source_txn = s.enter_context(source_env.begin())
- dst_txn = s.enter_context(dst_env.begin(write=True))
- cursor = s.enter_context(source_txn.cursor())
- for k_bytes, v_bytes in cursor:
- k = deserialize(k_bytes)
- v = deserialize(v_bytes)
- if p.get_func()(k, v):
- dst_txn.put(k_bytes, v_bytes)
- return rtn
- def _do_subtract_by_key(p: _BinaryProcess):
- rtn = p.output_operand()
- with ExitStack() as s:
- left_op = p.left
- right_op = p.right
- right_env = s.enter_context(right_op.as_env())
- left_env = s.enter_context(left_op.as_env())
- dst_env = s.enter_context(rtn.as_env(write=True))
- left_txn = s.enter_context(left_env.begin())
- right_txn = s.enter_context(right_env.begin())
- dst_txn = s.enter_context(dst_env.begin(write=True))
- cursor = s.enter_context(left_txn.cursor())
- for k_bytes, left_v_bytes in cursor:
- right_v_bytes = right_txn.get(k_bytes)
- if right_v_bytes is None:
- dst_txn.put(k_bytes, left_v_bytes)
- return rtn
- def _do_join(p: _BinaryProcess):
- rtn = p.output_operand()
- with ExitStack() as s:
- right_env = s.enter_context(p.right.as_env())
- left_env = s.enter_context(p.left.as_env())
- dst_env = s.enter_context(rtn.as_env(write=True))
- left_txn = s.enter_context(left_env.begin())
- right_txn = s.enter_context(right_env.begin())
- dst_txn = s.enter_context(dst_env.begin(write=True))
- cursor = s.enter_context(left_txn.cursor())
- for k_bytes, v1_bytes in cursor:
- v2_bytes = right_txn.get(k_bytes)
- if v2_bytes is None:
- continue
- v1 = deserialize(v1_bytes)
- v2 = deserialize(v2_bytes)
- v3 = p.get_func()(v1, v2)
- dst_txn.put(k_bytes, serialize(v3))
- return rtn
- def _do_union(p: _BinaryProcess):
- rtn = p.output_operand()
- with ExitStack() as s:
- left_env = s.enter_context(p.left.as_env())
- right_env = s.enter_context(p.right.as_env())
- dst_env = s.enter_context(rtn.as_env(write=True))
- left_txn = s.enter_context(left_env.begin())
- right_txn = s.enter_context(right_env.begin())
- dst_txn = s.enter_context(dst_env.begin(write=True))
- # process left op
- with left_txn.cursor() as left_cursor:
- for k_bytes, left_v_bytes in left_cursor:
- right_v_bytes = right_txn.get(k_bytes)
- if right_v_bytes is None:
- dst_txn.put(k_bytes, left_v_bytes)
- else:
- left_v = deserialize(left_v_bytes)
- right_v = deserialize(right_v_bytes)
- final_v = p.get_func()(left_v, right_v)
- dst_txn.put(k_bytes, serialize(final_v))
- # process right op
- with right_txn.cursor() as right_cursor:
- for k_bytes, right_v_bytes in right_cursor:
- final_v_bytes = dst_txn.get(k_bytes)
- if final_v_bytes is None:
- dst_txn.put(k_bytes, right_v_bytes)
- return rtn
- def _kv_to_bytes(k, v):
- return serialize(k), serialize(v)
- def _k_to_bytes(k):
- return serialize(k)
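- if __name__ == "__main__":
-     # Minimal smoke-test sketch of the standalone API above; illustrative
-     # only. It writes LMDB files under the project data directory, so run
-     # it in a disposable environment.
-     _demo_session = Session(str(uuid.uuid1()))
-     try:
-         _left = _demo_session.parallelize(
-             ((i, i) for i in range(10)), partition=2, include_key=True
-         )
-         _right = _demo_session.parallelize(
-             ((i, i * i) for i in range(10)), partition=2, include_key=True
-         )
-         print("count:", _left.count())  # expect 10
-         print("sum:", _left.mapValues(lambda v: v + 1).reduce(lambda a, b: a + b))  # expect 55
-         print("joined:", sorted(_left.join(_right, lambda v1, v2: (v1, v2)).collect()))
-     finally:
-         _demo_session.stop()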