#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import hashlib
import itertools
import pickle as c_pickle
import shutil
import time
import typing
import uuid
from collections.abc import Iterable
from concurrent.futures import ProcessPoolExecutor as Executor
from contextlib import ExitStack
from functools import partial
from heapq import heapify, heappop, heapreplace
from operator import is_not
from pathlib import Path
import cloudpickle as f_pickle
import lmdb
import numpy as np
from fate_arch.common import Party, file_utils
from fate_arch.common.log import getLogger
from fate_arch.federation import FederationDataType
LOGGER = getLogger()
serialize = c_pickle.dumps
deserialize = c_pickle.loads
# default message max size in bytes = 1MB
DEFAULT_MESSAGE_MAX_SIZE = 1048576
# noinspection PyPep8Naming
class Table(object):
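    """
    A partitioned key-value table backed by one LMDB environment per partition.

    Keys and values are pickled; a key is routed to its partition by hashing
    the serialized key bytes, so reads and writes for the same key always hit
    the same LMDB environment.
    """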
def __init__(
self,
session: "Session",
namespace: str,
name: str,
partitions,
need_cleanup=True,
):
self._need_cleanup = need_cleanup
self._namespace = namespace
self._name = name
self._partitions = partitions
self._session = session
@property
def partitions(self):
return self._partitions
@property
def name(self):
return self._name
@property
def namespace(self):
return self._namespace
def __del__(self):
if self._need_cleanup:
self.destroy()
    def __str__(self):
        return f"<Table namespace={self._namespace}, name={self._name}, partitions={self._partitions}>"
def __repr__(self):
return self.__str__()
def destroy(self):
for p in range(self._partitions):
with self._get_env_for_partition(p, write=True) as env:
db = env.open_db()
with env.begin(write=True) as txn:
txn.drop(db)
table_key = f"{self._namespace}.{self._name}"
_get_meta_table().delete(table_key)
path = _get_storage_dir(self._namespace, self._name)
shutil.rmtree(path, ignore_errors=True)
def take(self, n, **kwargs):
if n <= 0:
raise ValueError(f"{n} <= 0")
return list(itertools.islice(self.collect(**kwargs), n))
def count(self):
cnt = 0
for p in range(self._partitions):
with self._get_env_for_partition(p) as env:
cnt += env.stat()["entries"]
return cnt
# noinspection PyUnusedLocal
def collect(self, **kwargs):
iterators = []
with ExitStack() as s:
for p in range(self._partitions):
env = s.enter_context(self._get_env_for_partition(p))
txn = s.enter_context(env.begin())
iterators.append(s.enter_context(txn.cursor()))
# Merge sorted
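            # classic k-way merge: seed a min-heap with the first entry of each
            # partition cursor, then repeatedly yield the smallest entry and
            # advance the cursor it came from. Ordering is over the serialized
            # key bytes, matching LMDB's byte-ordered cursors.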
entries = []
for _id, it in enumerate(iterators):
if it.next():
key, value = it.item()
entries.append([key, value, _id, it])
heapify(entries)
while entries:
key, value, _, it = entry = entries[0]
yield deserialize(key), deserialize(value)
if it.next():
entry[0], entry[1] = it.item()
heapreplace(entries, entry)
else:
                    heappop(entries)
def reduce(self, func):
# noinspection PyProtectedMember
rs = self._session._submit_unary(
func, _do_reduce, self._partitions, self._name, self._namespace
)
rs = [r for r in filter(partial(is_not, None), rs)]
if len(rs) <= 0:
return None
rtn = rs[0]
for r in rs[1:]:
rtn = func(rtn, r)
return rtn
def map(self, func):
return self._unary(func, _do_map)
def mapValues(self, func):
return self._unary(func, _do_map_values)
def flatMap(self, func):
_flat_mapped = self._unary(func, _do_flat_map)
return _flat_mapped.save_as(
name=str(uuid.uuid1()),
namespace=_flat_mapped.namespace,
partition=self._partitions,
need_cleanup=True,
)
def applyPartitions(self, func):
return self._unary(func, _do_apply_partitions)
def mapPartitions(self, func, preserves_partitioning=False):
un_shuffled = self._unary(func, _do_map_partitions)
if preserves_partitioning:
return un_shuffled
return un_shuffled.save_as(
name=str(uuid.uuid1()),
namespace=un_shuffled.namespace,
partition=self._partitions,
need_cleanup=True,
)
def mapPartitionsWithIndex(self, func, preserves_partitioning=False):
un_shuffled = self._unary(func, _do_map_partitions_with_index)
if preserves_partitioning:
return un_shuffled
return un_shuffled.save_as(
name=str(uuid.uuid1()),
namespace=un_shuffled.namespace,
partition=self._partitions,
need_cleanup=True,
)
def mapReducePartitions(self, mapper, reducer):
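        # strategy: map and combine within each partition (applyPartitions),
        # merge the per-partition dictionaries with a global reduce, then
        # redistribute the combined result into a fresh table via put_all.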
dup = _create_table(
self._session,
str(uuid.uuid1()),
self.namespace,
self._partitions,
need_cleanup=True,
)
def _dict_reduce(a: dict, b: dict):
for k, v in b.items():
if k not in a:
a[k] = v
else:
a[k] = reducer(a[k], v)
return a
def _local_map_reduce(it):
ret = {}
for _k, _v in mapper(it):
if _k not in ret:
ret[_k] = _v
else:
ret[_k] = reducer(ret[_k], _v)
return ret
dup.put_all(
self.applyPartitions(_local_map_reduce).reduce(_dict_reduce).items()
)
return dup
def glom(self):
return self._unary(None, _do_glom)
def sample(self, fraction, seed=None):
return self._unary((fraction, seed), _do_sample)
def filter(self, func):
return self._unary(func, _do_filter)
def join(self, other: "Table", func):
return self._binary(other, func, _do_join)
def subtractByKey(self, other: "Table"):
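        # _do_subtract_by_key never calls the function, so pass a descriptive
        # label instead; it only needs to be picklable for task submission.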
func = f"{self._namespace}.{self._name}-{other._namespace}.{other._name}"
return self._binary(other, func, _do_subtract_by_key)
def union(self, other: "Table", func=lambda v1, v2: v1):
return self._binary(other, func, _do_union)
# noinspection PyProtectedMember
def _map_reduce(self, mapper, reducer):
        results = self._session._submit_map_reduce_in_partition(
mapper, reducer, self._partitions, self._name, self._namespace
)
result = results[0]
# noinspection PyProtectedMember
return _create_table(
session=self._session,
name=result.name,
namespace=result.namespace,
partitions=self._partitions,
)
def _unary(self, func, do_func):
# noinspection PyProtectedMember
results = self._session._submit_unary(
func, do_func, self._partitions, self._name, self._namespace
)
result = results[0]
# noinspection PyProtectedMember
return _create_table(
session=self._session,
name=result.name,
namespace=result.namespace,
partitions=self._partitions,
)
def _binary(self, other: "Table", func, do_func):
session_id = self._session.session_id
left, right = self, other
if left._partitions != right._partitions:
if other.count() > self.count():
left = left.save_as(
str(uuid.uuid1()), session_id, partition=right._partitions
)
else:
right = other.save_as(
str(uuid.uuid1()), session_id, partition=left._partitions
)
# noinspection PyProtectedMember
results = self._session._submit_binary(
func,
do_func,
left._partitions,
left._name,
left._namespace,
right._name,
right._namespace,
)
result: _Operand = results[0]
# noinspection PyProtectedMember
return _create_table(
session=self._session,
name=result.name,
namespace=result.namespace,
partitions=left._partitions,
)
def save_as(self, name, namespace, partition=None, need_cleanup=True):
if partition is None:
partition = self._partitions
# noinspection PyProtectedMember
dup = _create_table(self._session, name, namespace, partition, need_cleanup)
dup.put_all(self.collect())
return dup
def _get_env_for_partition(self, p: int, write=False):
return _get_env(self._namespace, self._name, str(p), write=write)
def put(self, k, v):
k_bytes, v_bytes = _kv_to_bytes(k=k, v=v)
p = _hash_key_to_partition(k_bytes, self._partitions)
with self._get_env_for_partition(p, write=True) as env:
with env.begin(write=True) as txn:
return txn.put(k_bytes, v_bytes)
def put_all(self, kv_list: Iterable):
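        """
        Write a stream of key/value pairs using one write transaction per
        partition; all transactions commit only if every put succeeds,
        otherwise every partition's transaction is aborted.
        """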
txn_map = {}
is_success = True
with ExitStack() as s:
for p in range(self._partitions):
env = s.enter_context(self._get_env_for_partition(p, write=True))
txn_map[p] = env, env.begin(write=True)
for k, v in kv_list:
try:
k_bytes, v_bytes = _kv_to_bytes(k=k, v=v)
p = _hash_key_to_partition(k_bytes, self._partitions)
is_success = is_success and txn_map[p][1].put(k_bytes, v_bytes)
except Exception as e:
is_success = False
LOGGER.exception(f"put_all for k={k} v={v} fail. exception: {e}")
break
for p, (env, txn) in txn_map.items():
                if is_success:
                    txn.commit()
                else:
                    txn.abort()
def get(self, k):
k_bytes = _k_to_bytes(k=k)
p = _hash_key_to_partition(k_bytes, self._partitions)
with self._get_env_for_partition(p) as env:
            # read-only transaction is sufficient for a lookup
            with env.begin() as txn:
                value_bytes = txn.get(k_bytes)
                return None if value_bytes is None else deserialize(value_bytes)
def delete(self, k):
k_bytes = _k_to_bytes(k=k)
p = _hash_key_to_partition(k_bytes, self._partitions)
with self._get_env_for_partition(p, write=True) as env:
with env.begin(write=True) as txn:
old_value_bytes = txn.get(k_bytes)
if txn.delete(k_bytes):
return (
None
if old_value_bytes is None
else deserialize(old_value_bytes)
)
return None
# noinspection PyMethodMayBeStatic
class Session(object):
def __init__(self, session_id, max_workers=None):
self.session_id = session_id
self._pool = Executor(max_workers=max_workers)
def __getstate__(self):
# session won't be pickled
pass
def load(self, name, namespace):
return _load_table(session=self, name=name, namespace=namespace)
def create_table(self, name, namespace, partitions, need_cleanup, error_if_exist):
return _create_table(
session=self,
name=name,
namespace=namespace,
partitions=partitions,
need_cleanup=need_cleanup,
error_if_exist=error_if_exist,
)
# noinspection PyUnusedLocal
def parallelize(
self, data: Iterable, partition: int, include_key: bool = False, **kwargs
):
if not include_key:
data = enumerate(data)
table = _create_table(
session=self,
name=str(uuid.uuid1()),
namespace=self.session_id,
partitions=partition,
)
table.put_all(data)
return table
def cleanup(self, name, namespace):
data_path = _get_data_dir()
if not data_path.is_dir():
LOGGER.error(f"illegal data dir: {data_path}")
return
namespace_dir = data_path.joinpath(namespace)
if not namespace_dir.is_dir():
return
if name == "*":
shutil.rmtree(namespace_dir, True)
return
for table in namespace_dir.glob(name):
shutil.rmtree(table, True)
def stop(self):
self.cleanup(name="*", namespace=self.session_id)
self._pool.shutdown()
def kill(self):
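        # currently identical to stop(); kept as a separate entry point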
self.cleanup(name="*", namespace=self.session_id)
self._pool.shutdown()
def _submit_unary(self, func, _do_func, partitions, name, namespace):
task_info = _TaskInfo(
self.session_id,
function_id=str(uuid.uuid1()),
function_bytes=f_pickle.dumps(func),
)
futures = []
for p in range(partitions):
futures.append(
self._pool.submit(
_do_func, _UnaryProcess(task_info, _Operand(namespace, name, p))
)
)
results = [r.result() for r in futures]
return results
def _submit_map_reduce_in_partition(
self, mapper, reducer, partitions, name, namespace
):
task_info = _MapReduceTaskInfo(
self.session_id,
function_id=str(uuid.uuid1()),
map_function_bytes=f_pickle.dumps(mapper),
reduce_function_bytes=f_pickle.dumps(reducer),
)
futures = []
for p in range(partitions):
futures.append(
self._pool.submit(
_do_map_reduce_in_partitions,
_MapReduceProcess(task_info, _Operand(namespace, name, p)),
)
)
results = [r.result() for r in futures]
return results
def _submit_binary(
self, func, do_func, partitions, name, namespace, other_name, other_namespace
):
task_info = _TaskInfo(
self.session_id,
function_id=str(uuid.uuid1()),
function_bytes=f_pickle.dumps(func),
)
futures = []
for p in range(partitions):
left = _Operand(namespace, name, p)
right = _Operand(other_namespace, other_name, p)
futures.append(
self._pool.submit(do_func, _BinaryProcess(task_info, left, right))
)
results = [r.result() for r in futures]
return results
def _get_splits(obj, max_message_size):
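    """
    Slice a pickled object into chunks of at most max_message_size bytes.

    Returns (obj, 1) if the payload fits in one message; otherwise returns a
    list of (index, chunk) pairs together with the number of slices.
    """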
obj_bytes = serialize(obj, protocol=4)
byte_size = len(obj_bytes)
num_slice = (byte_size - 1) // max_message_size + 1
if num_slice <= 1:
return obj, num_slice
else:
_max_size = max_message_size
kv = [(i, obj_bytes[slice(i * _max_size, (i + 1) * _max_size)]) for i in range(num_slice)]
return kv, num_slice
class Federation(object):
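    """
    Standalone (single-host) federation: parties exchange data through shared
    LMDB tables rather than a network. `remote` writes the payload and a
    status record into tables owned by the receiving party; `get` polls the
    local status table until the sender's record shows up.
    """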
def _federation_object_key(self, name, tag, s_party, d_party):
return f"{self._session_id}-{name}-{tag}-{s_party.role}-{s_party.party_id}-{d_party.role}-{d_party.party_id}"
def __init__(self, session: Session, session_id, party: Party):
self._session_id = session_id
self._party: Party = party
self._session = session
self._max_message_size = DEFAULT_MESSAGE_MAX_SIZE
self._other_status_tables = {}
self._other_object_tables = {}
        self._event_loop = None
self._federation_status_table_cache = None
self._federation_object_table_cache = None
def destroy(self):
self._session.cleanup(namespace=self._session_id, name="*")
@property
def _federation_status_table(self):
if self._federation_status_table_cache is None:
self._federation_status_table_cache = _create_table(
session=self._session,
name=self._get_status_table_name(self._party),
namespace=self._session_id,
partitions=1,
need_cleanup=True,
error_if_exist=False,
)
return self._federation_status_table_cache
@property
def _federation_object_table(self):
if self._federation_object_table_cache is None:
self._federation_object_table_cache = _create_table(
session=self._session,
name=self._get_object_table_name(self._party),
namespace=self._session_id,
partitions=1,
need_cleanup=True,
error_if_exist=False,
)
return self._federation_object_table_cache
@property
def _loop(self):
        if self._event_loop is None:
            self._event_loop = asyncio.get_event_loop()
        return self._event_loop
@staticmethod
def _get_status_table_name(party):
return f"__federation_status__.{party.role}_{party.party_id}"
@staticmethod
def _get_object_table_name(party):
return f"__federation_object__.{party.role}_{party.party_id}"
def _get_other_status_table(self, party):
if party in self._other_status_tables:
return self._other_status_tables[party]
table = _create_table(
self._session,
name=self._get_status_table_name(party),
namespace=self._session_id,
partitions=1,
need_cleanup=False,
error_if_exist=False,
)
self._other_status_tables[party] = table
return table
def _get_other_object_table(self, party):
if party in self._other_object_tables:
return self._other_object_tables[party]
table = _create_table(
self._session,
name=self._get_object_table_name(party),
namespace=self._session_id,
partitions=1,
need_cleanup=False,
error_if_exist=False,
)
self._other_object_tables[party] = table
return table
# noinspection PyProtectedMember
def _put_status(self, party, _tagged_key, value):
self._get_other_status_table(party).put(_tagged_key, value)
# noinspection PyProtectedMember
def _put_object(self, party, _tagged_key, value):
self._get_other_object_table(party).put(_tagged_key, value)
# noinspection PyProtectedMember
def _get_object(self, _tagged_key):
return self._federation_object_table.get(_tagged_key)
# noinspection PyProtectedMember
def _get_status(self, _tagged_key):
return self._federation_status_table.get(_tagged_key)
# noinspection PyUnusedLocal
def remote(self, v, name: str, tag: str, parties: typing.List[Party]):
log_str = f"federation.standalone.remote.{name}.{tag}"
if v is None:
raise ValueError(f"[{log_str}]remote `None` to {parties}")
LOGGER.debug(f"[{log_str}]remote data, type={type(v)}")
if isinstance(v, Table):
dtype = FederationDataType.TABLE
LOGGER.debug(
f"[{log_str}]remote "
f"Table(namespace={v.namespace}, name={v.name}, partitions={v.partitions}), dtype={dtype}"
)
else:
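            # plain objects larger than the message limit are rewritten as a
            # temporary single-partition table of (index, chunk) pairs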
v_splits, num_slice = _get_splits(v, self._max_message_size)
if num_slice > 1:
v = _create_table(
session=self._session,
name=str(uuid.uuid1()),
namespace=self._session_id,
partitions=1,
need_cleanup=True,
error_if_exist=False,
)
v.put_all(kv_list=v_splits)
dtype = FederationDataType.SPLIT_OBJECT
LOGGER.debug(
f"[{log_str}]remote "
f"Table(namespace={v.namespace}, name={v.name}, partitions={v.partitions}), dtype={dtype}"
)
else:
LOGGER.debug(f"[{log_str}]remote object with type: {type(v)}")
dtype = FederationDataType.OBJECT
for party in parties:
_tagged_key = self._federation_object_key(name, tag, self._party, party)
if isinstance(v, Table):
saved_name = str(uuid.uuid1())
LOGGER.debug(
f"[{log_str}]save Table(namespace={v.namespace}, name={v.name}, partitions={v.partitions}) as "
f"Table(namespace={v.namespace}, name={saved_name}, partitions={v.partitions})"
)
_v = v.save_as(
name=saved_name, namespace=v.namespace, need_cleanup=False
)
self._put_status(party, _tagged_key, (_v.name, _v.namespace, dtype))
else:
self._put_object(party, _tagged_key, v)
self._put_status(party, _tagged_key, _tagged_key)
# noinspection PyProtectedMember
def get(self, name: str, tag: str, parties: typing.List[Party]) -> typing.List:
log_str = f"federation.standalone.get.{name}.{tag}"
LOGGER.debug(f"[{log_str}]")
tasks = []
for party in parties:
_tagged_key = self._federation_object_key(name, tag, party, self._party)
tasks.append(_check_status_and_get_value(self._get_status, _tagged_key))
results = self._loop.run_until_complete(asyncio.gather(*tasks))
rtn = []
for r in results:
if isinstance(r, tuple):
# noinspection PyTypeChecker
table: Table = _load_table(
session=self._session, name=r[0], namespace=r[1], need_cleanup=True
)
dtype = r[2]
LOGGER.debug(
f"[{log_str}] got "
f"Table(namespace={table.namespace}, name={table.name}, partitions={table.partitions}), dtype={dtype}")
if dtype == FederationDataType.SPLIT_OBJECT:
obj_bytes = b''.join(map(lambda t: t[1], sorted(table.collect(), key=lambda x: x[0])))
obj = deserialize(obj_bytes)
rtn.append(obj)
else:
rtn.append(table)
else:
obj = self._get_object(r)
if obj is None:
raise EnvironmentError(
f"federation get None from {parties} with name {name}, tag {tag}"
)
rtn.append(obj)
self._federation_object_table.delete(k=r)
LOGGER.debug(f"[{log_str}] got object with type: {type(obj)}")
self._federation_status_table.delete(r)
return rtn
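# module-level registry: the meta table maps "namespace.name" to a table's
# partition count, kept in a dedicated session that lives for the process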
_meta_table: typing.Optional[Table] = None
_SESSION = Session(uuid.uuid1().hex)
def _get_meta_table():
global _meta_table
if _meta_table is None:
_meta_table = Table(
_SESSION,
namespace="__META__",
name="fragments",
partitions=10,
need_cleanup=False,
)
return _meta_table
# noinspection PyProtectedMember
def _get_from_meta_table(key):
return _get_meta_table().get(key)
# noinspection PyProtectedMember
def _put_to_meta_table(key, value):
_get_meta_table().put(key, value)
_data_dir = Path(file_utils.get_project_base_directory()).joinpath("data").absolute()
def _get_data_dir():
return _data_dir
def _get_storage_dir(*args):
return _data_dir.joinpath(*args)
async def _check_status_and_get_value(get_func, key):
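    # poll the status table every 100 ms until the remote side writes the key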
value = get_func(key)
while value is None:
await asyncio.sleep(0.1)
value = get_func(key)
LOGGER.debug(
"[GET] Got {} type {}".format(
key, "Table" if isinstance(value, tuple) else "Object"
)
)
return value
def _create_table(
session: "Session",
name: str,
namespace: str,
partitions: int,
need_cleanup=True,
error_if_exist=False,
):
if isinstance(namespace, int):
raise ValueError(f"{namespace} {name}")
_table_key = ".".join([namespace, name])
if _get_from_meta_table(_table_key) is not None:
if error_if_exist:
raise RuntimeError(
f"table already exist: name={name}, namespace={namespace}"
)
else:
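            # table already exists: reuse its recorded partition count so the
            # requested value never disagrees with the data on disk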
partitions = _get_from_meta_table(_table_key)
else:
_put_to_meta_table(_table_key, partitions)
return Table(
session=session,
namespace=namespace,
name=name,
partitions=partitions,
need_cleanup=need_cleanup,
)
def _exist(name: str, namespace: str):
_table_key = ".".join([namespace, name])
return _get_from_meta_table(_table_key) is not None
def _load_table(session, name, namespace, need_cleanup=False):
_table_key = ".".join([namespace, name])
partitions = _get_from_meta_table(_table_key)
if partitions is None:
raise RuntimeError(f"table not exist: name={name}, namespace={namespace}")
return Table(
session=session,
namespace=namespace,
name=name,
partitions=partitions,
need_cleanup=need_cleanup,
)
class _TaskInfo:
def __init__(self, task_id, function_id, function_bytes):
self.task_id = task_id
self.function_id = function_id
self.function_bytes = function_bytes
self._function_deserialized = None
def get_func(self):
if self._function_deserialized is None:
self._function_deserialized = f_pickle.loads(self.function_bytes)
return self._function_deserialized
class _MapReduceTaskInfo:
def __init__(self, task_id, function_id, map_function_bytes, reduce_function_bytes):
self.task_id = task_id
self.function_id = function_id
self.map_function_bytes = map_function_bytes
self.reduce_function_bytes = reduce_function_bytes
self._reduce_function_deserialized = None
self._mapper_function_deserialized = None
def get_mapper(self):
if self._mapper_function_deserialized is None:
self._mapper_function_deserialized = f_pickle.loads(self.map_function_bytes)
return self._mapper_function_deserialized
def get_reducer(self):
if self._reduce_function_deserialized is None:
self._reduce_function_deserialized = f_pickle.loads(
self.reduce_function_bytes
)
return self._reduce_function_deserialized
class _Operand:
def __init__(self, namespace, name, partition):
self.namespace = namespace
self.name = name
self.partition = partition
def as_env(self, write=False):
return _get_env(self.namespace, self.name, str(self.partition), write=write)
class _UnaryProcess:
def __init__(self, task_info: _TaskInfo, operand: _Operand):
self.info = task_info
self.operand = operand
def output_operand(self):
return _Operand(
self.info.task_id, self.info.function_id, self.operand.partition
)
def get_func(self):
return self.info.get_func()
class _MapReduceProcess:
def __init__(self, task_info: _MapReduceTaskInfo, operand: _Operand):
self.info = task_info
self.operand = operand
def output_operand(self):
return _Operand(
self.info.task_id, self.info.function_id, self.operand.partition
)
def get_mapper(self):
return self.info.get_mapper()
def get_reducer(self):
return self.info.get_reducer()
class _BinaryProcess:
def __init__(self, task_info: _TaskInfo, left: _Operand, right: _Operand):
self.info = task_info
self.left = left
self.right = right
def output_operand(self):
return _Operand(self.info.task_id, self.info.function_id, self.left.partition)
def get_func(self):
return self.info.get_func()
def _get_env(*args, write=False):
_path = _get_storage_dir(*args)
return _open_env(_path, write=write)
def _open_env(path, write=False):
path.mkdir(parents=True, exist_ok=True)
t = 0
while t < 100:
try:
env = lmdb.open(
path.as_posix(),
create=True,
max_dbs=1,
max_readers=1024,
lock=write,
sync=True,
map_size=10_737_418_240,
)
return env
except lmdb.Error as e:
if "No such file or directory" in e.args[0]:
time.sleep(0.01)
t += 1
else:
raise e
raise lmdb.Error(f"No such file or directory: {path}, with {t} times retry")
def _hash_key_to_partition(key, partitions):
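    """
    Map a key to a partition with jump consistent hashing (Lamping & Veach),
    seeded by the SHA-1 digest of the serialized key bytes.
    """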
_key = hashlib.sha1(key).digest()
if isinstance(_key, bytes):
_key = int.from_bytes(_key, byteorder="little", signed=False)
if partitions < 1:
raise ValueError("partitions must be a positive number")
b, j = -1, 0
while j < partitions:
b = int(j)
_key = ((_key * 2862933555777941757) + 1) & 0xFFFFFFFFFFFFFFFF
j = float(b + 1) * (float(1 << 31) / float((_key >> 33) + 1))
return int(b)
def _do_map(p: _UnaryProcess):
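    # map may change keys, so every output pair is re-hashed into its
    # destination partition (a full shuffle); contrast with _do_map_values,
    # which keeps keys, and therefore partitions, in place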
rtn = p.output_operand()
with ExitStack() as s:
source_env = s.enter_context(p.operand.as_env())
partitions = _get_from_meta_table(f"{p.operand.namespace}.{p.operand.name}")
txn_map = {}
for partition in range(partitions):
env = s.enter_context(
_get_env(rtn.namespace, rtn.name, str(partition), write=True)
)
txn_map[partition] = s.enter_context(env.begin(write=True))
source_txn = s.enter_context(source_env.begin())
cursor = s.enter_context(source_txn.cursor())
for k_bytes, v_bytes in cursor:
k, v = deserialize(k_bytes), deserialize(v_bytes)
k1, v1 = p.get_func()(k, v)
k1_bytes, v1_bytes = serialize(k1), serialize(v1)
partition = _hash_key_to_partition(k1_bytes, partitions)
txn_map[partition].put(k1_bytes, v1_bytes)
return rtn
def _generator_from_cursor(cursor):
for k, v in cursor:
yield deserialize(k), deserialize(v)
def _do_apply_partitions(p: _UnaryProcess):
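    # the whole partition is folded into a single value, stored under the last
    # key of the partition; an empty partition produces no output entry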
with ExitStack() as s:
rtn = p.output_operand()
source_env = s.enter_context(p.operand.as_env())
dst_env = s.enter_context(rtn.as_env(write=True))
source_txn = s.enter_context(source_env.begin())
dst_txn = s.enter_context(dst_env.begin(write=True))
cursor = s.enter_context(source_txn.cursor())
v = p.get_func()(_generator_from_cursor(cursor))
if cursor.last():
k_bytes = cursor.key()
dst_txn.put(k_bytes, serialize(v))
return rtn
def _do_map_partitions(p: _UnaryProcess):
with ExitStack() as s:
rtn = p.output_operand()
source_env = s.enter_context(p.operand.as_env())
dst_env = s.enter_context(rtn.as_env(write=True))
source_txn = s.enter_context(source_env.begin())
dst_txn = s.enter_context(dst_env.begin(write=True))
cursor = s.enter_context(source_txn.cursor())
v = p.get_func()(_generator_from_cursor(cursor))
if isinstance(v, Iterable):
for k1, v1 in v:
dst_txn.put(serialize(k1), serialize(v1))
else:
k_bytes = cursor.key()
dst_txn.put(k_bytes, serialize(v))
return rtn
def _do_map_partitions_with_index(p: _UnaryProcess):
with ExitStack() as s:
rtn = p.output_operand()
source_env = s.enter_context(p.operand.as_env())
dst_env = s.enter_context(rtn.as_env(write=True))
source_txn = s.enter_context(source_env.begin())
dst_txn = s.enter_context(dst_env.begin(write=True))
cursor = s.enter_context(source_txn.cursor())
v = p.get_func()(p.operand.partition, _generator_from_cursor(cursor))
if isinstance(v, Iterable):
for k1, v1 in v:
dst_txn.put(serialize(k1), serialize(v1))
else:
k_bytes = cursor.key()
dst_txn.put(k_bytes, serialize(v))
return rtn
def _do_map_reduce_in_partitions(p: _MapReduceProcess):
rtn = p.output_operand()
with ExitStack() as s:
source_env = s.enter_context(p.operand.as_env())
partitions = _get_from_meta_table(f"{p.operand.namespace}.{p.operand.name}")
txn_map = {}
for partition in range(partitions):
env = s.enter_context(
_get_env(rtn.namespace, rtn.name, str(partition), write=True)
)
txn_map[partition] = s.enter_context(env.begin(write=True))
source_txn = s.enter_context(source_env.begin())
cursor = s.enter_context(source_txn.cursor())
mapped = p.get_mapper()(_generator_from_cursor(cursor))
if not isinstance(mapped, Iterable):
raise ValueError("mapper function should return a iterable of pair")
reducer = p.get_reducer()
for k, v in mapped:
k_bytes = serialize(k)
partition = _hash_key_to_partition(k_bytes, partitions)
# todo: not atomic, fix me
pre_v = txn_map[partition].get(k_bytes, None)
if pre_v is None:
txn_map[partition].put(k_bytes, serialize(v))
else:
txn_map[partition].put(
k_bytes, serialize(reducer(deserialize(pre_v), v))
)
return rtn
def _do_map_values(p: _UnaryProcess):
rtn = p.output_operand()
with ExitStack() as s:
source_env = s.enter_context(p.operand.as_env())
dst_env = s.enter_context(rtn.as_env(write=True))
source_txn = s.enter_context(source_env.begin())
dst_txn = s.enter_context(dst_env.begin(write=True))
cursor = s.enter_context(source_txn.cursor())
for k_bytes, v_bytes in cursor:
v = deserialize(v_bytes)
v1 = p.get_func()(v)
dst_txn.put(k_bytes, serialize(v1))
return rtn
def _do_flat_map(p: _UnaryProcess):
rtn = p.output_operand()
with ExitStack() as s:
source_env = s.enter_context(p.operand.as_env())
dst_env = s.enter_context(rtn.as_env(write=True))
source_txn = s.enter_context(source_env.begin())
dst_txn = s.enter_context(dst_env.begin(write=True))
cursor = s.enter_context(source_txn.cursor())
for k_bytes, v_bytes in cursor:
k = deserialize(k_bytes)
v = deserialize(v_bytes)
map_result = p.get_func()(k, v)
for result_k, result_v in map_result:
dst_txn.put(serialize(result_k), serialize(result_v))
return rtn
def _do_reduce(p: _UnaryProcess):
value = None
with ExitStack() as s:
source_env = s.enter_context(p.operand.as_env())
source_txn = s.enter_context(source_env.begin())
cursor = s.enter_context(source_txn.cursor())
for k_bytes, v_bytes in cursor:
v = deserialize(v_bytes)
if value is None:
value = v
else:
value = p.get_func()(value, v)
return value
def _do_glom(p: _UnaryProcess):
rtn = p.output_operand()
with ExitStack() as s:
source_env = s.enter_context(p.operand.as_env())
dst_env = s.enter_context(rtn.as_env(write=True))
source_txn = s.enter_context(source_env.begin())
dest_txn = s.enter_context(dst_env.begin(write=True))
cursor = s.enter_context(source_txn.cursor())
v_list = []
k_bytes = None
for k, v in cursor:
v_list.append((deserialize(k), deserialize(v)))
k_bytes = k
if k_bytes is not None:
dest_txn.put(k_bytes, serialize(v_list))
return rtn
def _do_sample(p: _UnaryProcess):
rtn = p.output_operand()
fraction, seed = deserialize(p.info.function_bytes)
with ExitStack() as s:
source_env = s.enter_context(p.operand.as_env())
dst_env = s.enter_context(rtn.as_env(write=True))
source_txn = s.enter_context(source_env.begin())
dst_txn = s.enter_context(dst_env.begin(write=True))
cursor = s.enter_context(source_txn.cursor())
cursor.first()
random_state = np.random.RandomState(seed)
for k, v in cursor:
# noinspection PyArgumentList
if random_state.rand() < fraction:
dst_txn.put(k, v)
return rtn
def _do_filter(p: _UnaryProcess):
rtn = p.output_operand()
with ExitStack() as s:
source_env = s.enter_context(p.operand.as_env())
dst_env = s.enter_context(rtn.as_env(write=True))
source_txn = s.enter_context(source_env.begin())
dst_txn = s.enter_context(dst_env.begin(write=True))
cursor = s.enter_context(source_txn.cursor())
for k_bytes, v_bytes in cursor:
k = deserialize(k_bytes)
v = deserialize(v_bytes)
if p.get_func()(k, v):
dst_txn.put(k_bytes, v_bytes)
return rtn
def _do_subtract_by_key(p: _BinaryProcess):
rtn = p.output_operand()
with ExitStack() as s:
left_op = p.left
right_op = p.right
right_env = s.enter_context(right_op.as_env())
left_env = s.enter_context(left_op.as_env())
dst_env = s.enter_context(rtn.as_env(write=True))
left_txn = s.enter_context(left_env.begin())
right_txn = s.enter_context(right_env.begin())
dst_txn = s.enter_context(dst_env.begin(write=True))
cursor = s.enter_context(left_txn.cursor())
for k_bytes, left_v_bytes in cursor:
right_v_bytes = right_txn.get(k_bytes)
if right_v_bytes is None:
dst_txn.put(k_bytes, left_v_bytes)
return rtn
def _do_join(p: _BinaryProcess):
rtn = p.output_operand()
with ExitStack() as s:
right_env = s.enter_context(p.right.as_env())
left_env = s.enter_context(p.left.as_env())
dst_env = s.enter_context(rtn.as_env(write=True))
left_txn = s.enter_context(left_env.begin())
right_txn = s.enter_context(right_env.begin())
dst_txn = s.enter_context(dst_env.begin(write=True))
cursor = s.enter_context(left_txn.cursor())
for k_bytes, v1_bytes in cursor:
v2_bytes = right_txn.get(k_bytes)
if v2_bytes is None:
continue
v1 = deserialize(v1_bytes)
v2 = deserialize(v2_bytes)
v3 = p.get_func()(v1, v2)
dst_txn.put(k_bytes, serialize(v3))
return rtn
def _do_union(p: _BinaryProcess):
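    # first pass: merge keys present on the left, resolving collisions with
    # func; second pass: copy over keys that exist only on the right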
rtn = p.output_operand()
with ExitStack() as s:
left_env = s.enter_context(p.left.as_env())
right_env = s.enter_context(p.right.as_env())
dst_env = s.enter_context(rtn.as_env(write=True))
left_txn = s.enter_context(left_env.begin())
right_txn = s.enter_context(right_env.begin())
dst_txn = s.enter_context(dst_env.begin(write=True))
# process left op
with left_txn.cursor() as left_cursor:
for k_bytes, left_v_bytes in left_cursor:
right_v_bytes = right_txn.get(k_bytes)
if right_v_bytes is None:
dst_txn.put(k_bytes, left_v_bytes)
else:
left_v = deserialize(left_v_bytes)
right_v = deserialize(right_v_bytes)
final_v = p.get_func()(left_v, right_v)
dst_txn.put(k_bytes, serialize(final_v))
# process right op
with right_txn.cursor() as right_cursor:
for k_bytes, right_v_bytes in right_cursor:
final_v_bytes = dst_txn.get(k_bytes)
if final_v_bytes is None:
dst_txn.put(k_bytes, right_v_bytes)
return rtn
def _kv_to_bytes(k, v):
return serialize(k), serialize(v)
def _k_to_bytes(k):
return serialize(k)
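# A minimal usage sketch (illustrative only; the names `_demo_session`,
# `_pairs`, and `_doubled` below are invented for this demo and are not part
# of the module's API): create a session, load a few key/value pairs, and run
# a mapValues/reduce round trip through the process pool.
if __name__ == "__main__":
    _demo_session = Session(session_id=str(uuid.uuid1()))
    try:
        _pairs = _demo_session.parallelize(
            [("a", 1), ("b", 2), ("c", 3)], partition=2, include_key=True
        )
        _doubled = _pairs.mapValues(lambda v: v * 2)
        print(sorted(_doubled.collect()))  # [('a', 2), ('b', 4), ('c', 6)]
        print(_doubled.reduce(lambda x, y: x + y))  # 12
    finally:
        _demo_session.stop()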