#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
- import typing
- from fate_arch.abc import CTableABC
- from fate_arch.common import log
- from fate_arch.common.profile import computing_profile
- from fate_arch.computing._type import ComputingEngine
# Module-level logger, obtained from the `log` helper imported from
# fate_arch.common.
LOGGER = log.getLogger()
class Table(CTableABC):
    """Computing table for the EggRoll engine.

    The table is a thin wrapper around a roll-pair object (``rp``): every
    operation delegates to a method of that object, and operations that
    produce a new roll pair wrap the result in a fresh ``Table``.  The exact
    semantics of each roll-pair call are defined by EggRoll, not here.
    """

    def __init__(self, rp):
        """Wrap the roll-pair object ``rp``."""
        self._rp = rp
        self._engine = ComputingEngine.EGGROLL
        # Filled on the first call to ``count`` and reused afterwards.
        self._count = None

    @property
    def engine(self):
        """Return the computing engine identifier (``ComputingEngine.EGGROLL``)."""
        return self._engine

    @property
    def partitions(self):
        """Return the value of the roll pair's ``get_partitions()``."""
        return self._rp.get_partitions()

    def copy(self):
        """Return a new ``Table`` built by an identity ``map_values`` over this one."""
        return Table(self._rp.map_values(lambda x: x))

    @computing_profile
    def save(self, address, partitions, schema: dict, **kwargs):
        """Persist the table to ``address``.

        Args:
            address: destination address.  ``EggRollAddress`` and
                ``PathAddress`` are handled; anything else is rejected.
            partitions: forwarded to ``save_as`` as ``partition``.
            schema: dict that is updated in place with ``self.schema`` after
                an EggRoll save (``self.schema`` is not defined in this
                class; presumably inherited from ``CTableABC`` - confirm).
            **kwargs: ``options`` (dict forwarded to ``save_as``) and
                ``store_type`` (defaults to ``EggRollStoreType.ROLLPAIR_LMDB``).

        Returns:
            ``None`` for an ``EggRollAddress``; a ``LocalData`` built from
            ``address.path`` for a ``PathAddress``.

        Raises:
            NotImplementedError: if the address type is not supported.
        """
        # Work on a private copy so the caller's ``options`` dict is never
        # mutated, and tolerate an explicit ``options=None``.
        options = dict(kwargs.get("options") or {})

        # Imports are kept local to the method, as in the original code.
        from fate_arch.common.address import EggRollAddress
        from fate_arch.storage import EggRollStoreType

        if isinstance(address, EggRollAddress):
            # ``store_type`` always overrides whatever ``options`` carried.
            options["store_type"] = kwargs.get("store_type", EggRollStoreType.ROLLPAIR_LMDB)
            self._rp.save_as(
                name=address.name,
                namespace=address.namespace,
                partition=partitions,
                options=options,
            )
            schema.update(self.schema)
            return

        from fate_arch.common.address import PathAddress

        if isinstance(address, PathAddress):
            from fate_arch.computing.non_distributed import LocalData

            # NOTE(review): no roll-pair call happens in this branch; only a
            # LocalData handle for the path is returned.
            return LocalData(address.path)

        raise NotImplementedError(f"address type {type(address)} not supported with eggroll backend")

    @computing_profile
    def collect(self, **kwargs) -> list:
        """Return the result of the roll pair's ``get_all()``."""
        return self._rp.get_all()

    @computing_profile
    def count(self, **kwargs) -> int:
        """Return the number of entries, caching the roll pair's ``count()``."""
        if self._count is None:
            self._count = self._rp.count()
        return self._count

    @computing_profile
    def take(self, n=1, **kwargs):
        """Return the roll pair's ``take`` of ``n`` entries (keys and values)."""
        return self._rp.take(n=n, options={"keys_only": False})

    @computing_profile
    def first(self):
        """Return the roll pair's ``first`` entry (key and value)."""
        return self._rp.first(options={"keys_only": False})

    @computing_profile
    def map(self, func, **kwargs):
        """Return a new ``Table`` from the roll pair's ``map(func)``."""
        return Table(self._rp.map(func))

    @computing_profile
    def mapValues(self, func: typing.Callable[[typing.Any], typing.Any], **kwargs):
        """Return a new ``Table`` from the roll pair's ``map_values(func)``."""
        return Table(self._rp.map_values(func))

    @computing_profile
    def applyPartitions(self, func):
        """Return a new ``Table`` from the roll pair's ``collapse_partitions(func)``."""
        return Table(self._rp.collapse_partitions(func))

    @computing_profile
    def mapPartitions(self, func, use_previous_behavior=True, preserves_partitioning=False, **kwargs):
        """Apply ``func`` partition-wise.

        With ``use_previous_behavior`` set to exactly ``True`` (the default)
        a deprecation warning is logged and the call is routed to
        ``applyPartitions``.  Otherwise the roll pair's ``map_partitions`` is
        used, with shuffling enabled unless ``preserves_partitioning`` is set.
        """
        if use_previous_behavior is True:
            LOGGER.warning(
                "please use `applyPartitions` instead of `mapPartitions` "
                "if the previous behavior was expected. "
                "The previous behavior will not work in future"
            )
            return self.applyPartitions(func)

        return Table(self._rp.map_partitions(func, options={"shuffle": not preserves_partitioning}))

    @computing_profile
    def mapReducePartitions(self, mapper, reducer, **kwargs):
        """Return a new ``Table`` from ``map_partitions`` with ``reducer`` as ``reduce_op``."""
        return Table(self._rp.map_partitions(func=mapper, reduce_op=reducer))

    @computing_profile
    def mapPartitionsWithIndex(self, func, preserves_partitioning=False, **kwargs):
        """Return a new ``Table`` from the roll pair's ``map_partitions_with_index``.

        Shuffling is enabled unless ``preserves_partitioning`` is set.
        """
        return Table(self._rp.map_partitions_with_index(func, options={"shuffle": not preserves_partitioning}))

    @computing_profile
    def reduce(self, func, **kwargs):
        """Return the result of the roll pair's ``reduce(func)``."""
        return self._rp.reduce(func)

    @computing_profile
    def join(self, other: 'Table', func, **kwargs):
        """Return a new ``Table`` from joining with ``other`` using ``func``."""
        return Table(self._rp.join(other._rp, func=func))

    @computing_profile
    def glom(self, **kwargs):
        """Return a new ``Table`` from the roll pair's ``glom()``."""
        return Table(self._rp.glom())

    @computing_profile
    def sample(self, *, fraction: typing.Optional[float] = None, num: typing.Optional[int] = None, seed=None):
        """Return a sampled ``Table``, by ``fraction`` or by exact count ``num``.

        Args:
            fraction: forwarded to the roll pair's ``sample``.  Takes
                precedence when both ``fraction`` and ``num`` are given.
            num: exact number of entries wanted.  The table is sampled with
                a growing fraction until at least ``num`` entries come back,
                then surplus entries are deleted from the sample.
            seed: forwarded to the roll pair's ``sample``.

        Raises:
            ValueError: if ``num`` exceeds the table size, or if neither
                ``fraction`` nor ``num`` is given.
        """
        if fraction is not None:
            return Table(self._rp.sample(fraction=fraction, seed=seed))

        if num is not None:
            total = self._rp.count()
            if num > total:
                raise ValueError(f"not enough data to sample, own {total} but required {num}")

            # An empty table would otherwise divide by zero here.
            frac = num / float(total) if total else 0.0
            while True:
                sampled_table = self._rp.sample(fraction=frac, seed=seed)
                sampled_count = sampled_table.count()
                if sampled_count >= num:
                    break
                # Too few entries came back: retry with a 10% larger fraction.
                frac *= 1.1

            if sampled_count > num:
                # Trim the surplus so exactly ``num`` entries remain.
                for key, _ in sampled_table.take(sampled_count - num):
                    sampled_table.delete(key)

            return Table(sampled_table)

        raise ValueError(f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}")

    @computing_profile
    def subtractByKey(self, other: 'Table', **kwargs):
        """Return a new ``Table`` from the roll pair's ``subtract_by_key`` with ``other``."""
        return Table(self._rp.subtract_by_key(other._rp))

    @computing_profile
    def filter(self, func, **kwargs):
        """Return a new ``Table`` from the roll pair's ``filter(func)``."""
        return Table(self._rp.filter(func))

    @computing_profile
    def union(self, other: 'Table', func=lambda v1, v2: v1, **kwargs):
        """Return a new ``Table`` from the roll pair's ``union`` with ``other``.

        ``func`` is forwarded to ``union``; the default returns its first
        argument.
        """
        return Table(self._rp.union(other._rp, func=func))

    @computing_profile
    def flatMap(self, func, **kwargs):
        """Return a new ``Table`` from ``flat_map(func)`` followed by a shuffle."""
        flat_mapped = self._rp.flat_map(func)
        # The identity map exists only to trigger a shuffle of the result.
        shuffled = flat_mapped.map(lambda k, v: (k, v))
        return Table(shuffled)
|