| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199 |
- #
- # Copyright 2019 The Eggroll Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import itertools
- import typing
- from fate_arch.abc import CTableABC
- from fate_arch.common import log
- from fate_arch.common.profile import computing_profile
- from fate_arch.computing._type import ComputingEngine
# Module-level logger; Table.mapPartitions uses it to emit its deprecation warning.
- LOGGER = log.getLogger()
class Table(CTableABC):
    """Standalone-engine implementation of the ``CTableABC`` table interface.

    Every operation is delegated to the backend table object stored in
    ``self._table``. Operations that yield a new backend table wrap it in a
    fresh ``Table`` before returning it.
    """

    def __init__(self, table):
        """Wrap ``table``, the backend table object all methods delegate to."""
        self._table = table
        self._engine = ComputingEngine.STANDALONE
        # Filled in lazily by the first call to `count()`.
        self._count = None

    @property
    def engine(self):
        """Return the engine tag set in ``__init__`` (``ComputingEngine.STANDALONE``)."""
        return self._engine

    def __getstate__(self):
        """Return ``None`` so that pickling records no instance state.

        NOTE(review): presumably deliberate, to keep the backend table handle
        out of pickled payloads -- confirm with the callers that pickle tables.
        """
        return None

    @property
    def partitions(self):
        """Return the ``partitions`` attribute of the backend table."""
        return self._table.partitions

    def copy(self):
        """Return a new ``Table`` built by an identity ``mapValues`` on the backend."""
        return Table(self._table.mapValues(lambda x: x))

    @computing_profile
    def save(self, address, partitions, schema, **kwargs):
        """Save this table to ``address``.

        Args:
            address: destination; ``StandaloneAddress`` and ``PathAddress``
                are the only types handled.
            partitions: forwarded to the backend ``save_as`` as ``partition``.
            schema: mapping updated in place with ``self.schema`` after a
                ``StandaloneAddress`` save. ``self.schema`` is not defined in
                this class (presumably inherited from ``CTableABC``).
            **kwargs: accepted but not used.

        Returns:
            ``None`` for a ``StandaloneAddress``; a ``LocalData`` built from
            ``address.path`` for a ``PathAddress``.

        Raises:
            NotImplementedError: for any other address type.
        """
        # Imports are kept function-local, exactly as in the original code.
        from fate_arch.common.address import StandaloneAddress

        if isinstance(address, StandaloneAddress):
            self._table.save_as(
                name=address.name,
                namespace=address.namespace,
                partition=partitions,
                need_cleanup=False,
            )
            schema.update(self.schema)
            return

        from fate_arch.common.address import PathAddress

        if isinstance(address, PathAddress):
            from fate_arch.computing.non_distributed import LocalData

            # This branch never touches `self._table` or `schema`.
            return LocalData(address.path)

        raise NotImplementedError(
            f"address type {type(address)} not supported with standalone backend"
        )

    @computing_profile
    def count(self) -> int:
        """Return the backend row count, computed once and cached in ``_count``."""
        if self._count is None:
            self._count = self._table.count()
        return self._count

    @computing_profile
    def collect(self, **kwargs):
        """Return whatever the backend ``collect(**kwargs)`` returns."""
        return self._table.collect(**kwargs)

    @computing_profile
    def take(self, n=1, **kwargs):
        """Return whatever the backend ``take(n=n, **kwargs)`` returns."""
        return self._table.take(n=n, **kwargs)

    @computing_profile
    def first(self, **kwargs):
        """Return the first item produced by the backend ``collect(**kwargs)``.

        Raises:
            RuntimeError: if ``collect`` produces no items.
        """
        # islice stops after one item, so the rest of the iterable is not consumed.
        head = list(itertools.islice(self._table.collect(**kwargs), 1))
        if not head:
            raise RuntimeError("table is empty")
        return head[0]

    @computing_profile
    def reduce(self, func, **kwargs):
        """Return the backend ``reduce(func)`` result; ``**kwargs`` is not forwarded."""
        return self._table.reduce(func)

    @computing_profile
    def map(self, func):
        """Wrap the backend ``map(func)`` result in a new ``Table``."""
        return Table(self._table.map(func))

    @computing_profile
    def mapValues(self, func):
        """Wrap the backend ``mapValues(func)`` result in a new ``Table``."""
        return Table(self._table.mapValues(func))

    @computing_profile
    def flatMap(self, func):
        """Wrap the backend ``flatMap(func)`` result in a new ``Table``."""
        return Table(self._table.flatMap(func))

    @computing_profile
    def applyPartitions(self, func):
        """Wrap the backend ``applyPartitions(func)`` result in a new ``Table``."""
        return Table(self._table.applyPartitions(func))

    @computing_profile
    def mapPartitions(
        self, func, use_previous_behavior=True, preserves_partitioning=False
    ):
        """Apply ``func`` partition-wise and wrap the result in a new ``Table``.

        Args:
            func: callable handed to the backend.
            use_previous_behavior: when exactly ``True`` (the default), a
                warning is logged and the call is routed to the backend
                ``applyPartitions``; otherwise the backend ``mapPartitions``
                is used.
            preserves_partitioning: forwarded to the backend ``mapPartitions``
                only; unused on the ``applyPartitions`` route.
        """
        if use_previous_behavior is True:
            LOGGER.warning(
                "please use `applyPartitions` instead of `mapPartitions` "
                "if the previous behavior was expected. "
                "The previous behavior will not work in future"
            )
            return Table(self._table.applyPartitions(func))

        return Table(
            self._table.mapPartitions(
                func, preserves_partitioning=preserves_partitioning
            )
        )

    @computing_profile
    def mapReducePartitions(self, mapper, reducer, **kwargs):
        """Wrap the backend ``mapReducePartitions(mapper, reducer)`` result.

        ``**kwargs`` is accepted but not forwarded.
        """
        return Table(self._table.mapReducePartitions(mapper, reducer))

    @computing_profile
    def mapPartitionsWithIndex(self, func, preserves_partitioning=False, **kwargs):
        """Wrap the backend ``mapPartitionsWithIndex`` result in a new ``Table``.

        ``preserves_partitioning`` is forwarded; ``**kwargs`` is not.
        """
        return Table(
            self._table.mapPartitionsWithIndex(
                func, preserves_partitioning=preserves_partitioning
            )
        )

    @computing_profile
    def glom(self):
        """Wrap the backend ``glom()`` result in a new ``Table``."""
        return Table(self._table.glom())

    @computing_profile
    def sample(
        self,
        *,
        fraction: typing.Optional[float] = None,
        num: typing.Optional[int] = None,
        seed=None,
    ):
        """Return a sampled ``Table``, selected by ``fraction`` or by ``num``.

        Args:
            fraction: forwarded to the backend ``sample``. When it is not
                ``None`` it takes precedence and ``num`` is ignored.
            num: exact number of rows wanted; used only when ``fraction`` is
                ``None``.
            seed: forwarded to every backend ``sample`` call.

        Raises:
            ValueError: if ``num`` exceeds the backend row count, or if both
                ``fraction`` and ``num`` are ``None``.
        """
        if fraction is not None:
            return Table(self._table.sample(fraction=fraction, seed=seed))

        if num is not None:
            total = self._table.count()
            if num > total:
                raise ValueError(
                    f"not enough data to sample, own {total} but required {num}"
                )

            # An empty table used to raise ZeroDivisionError here; start from
            # a zero fraction instead so `num=0` yields an empty sample.
            frac = (num / float(total)) if total else 0.0

            # Re-sample with a growing fraction until at least `num` rows come back.
            while True:
                sampled_table = self._table.sample(fraction=frac, seed=seed)
                sampled_count = sampled_table.count()
                if sampled_count >= num:
                    break
                frac += 0.1

            # Trim the surplus so exactly `num` rows remain.
            if sampled_count > num:
                surplus = sampled_table.take(sampled_count - num)
                for key, _ in surplus:
                    sampled_table.delete(key)

            return Table(sampled_table)

        raise ValueError(
            f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}"
        )

    @computing_profile
    def filter(self, func):
        """Wrap the backend ``filter(func)`` result in a new ``Table``."""
        return Table(self._table.filter(func))

    @computing_profile
    def join(self, other: "Table", func):
        """Wrap the backend ``join`` of this table with ``other._table`` using ``func``."""
        return Table(self._table.join(other._table, func))

    @computing_profile
    def subtractByKey(self, other: "Table"):
        """Wrap the backend ``subtractByKey`` of this table and ``other._table``."""
        return Table(self._table.subtractByKey(other._table))

    @computing_profile
    def union(self, other: "Table", func=lambda v1, v2: v1):
        """Wrap the backend ``union`` of this table with ``other._table``.

        ``func`` is forwarded to the backend; the default returns its first
        argument.
        """
        return Table(self._table.union(other._table, func))
|