_table.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import typing
  17. from fate_arch.abc import CTableABC
  18. from fate_arch.common import log
  19. from fate_arch.common.profile import computing_profile
  20. from fate_arch.computing._type import ComputingEngine
# Module-level logger; used in this file by ``Table.mapPartitions`` to emit its
# deprecation warning.
# NOTE(review): called with no name argument, so the logger name is chosen
# inside ``fate_arch.common.log.getLogger`` -- confirm which name it resolves to.
LOGGER = log.getLogger()
  22. class Table(CTableABC):
  23. def __init__(self, rp):
  24. self._rp = rp
  25. self._engine = ComputingEngine.EGGROLL
  26. self._count = None
  27. @property
  28. def engine(self):
  29. return self._engine
  30. @property
  31. def partitions(self):
  32. return self._rp.get_partitions()
  33. def copy(self):
  34. return Table(self._rp.map_values(lambda x: x))
  35. @computing_profile
  36. def save(self, address, partitions, schema: dict, **kwargs):
  37. options = kwargs.get("options", {})
  38. from fate_arch.common.address import EggRollAddress
  39. from fate_arch.storage import EggRollStoreType
  40. if isinstance(address, EggRollAddress):
  41. options["store_type"] = kwargs.get("store_type", EggRollStoreType.ROLLPAIR_LMDB)
  42. self._rp.save_as(name=address.name, namespace=address.namespace, partition=partitions, options=options)
  43. schema.update(self.schema)
  44. return
  45. from fate_arch.common.address import PathAddress
  46. if isinstance(address, PathAddress):
  47. from fate_arch.computing.non_distributed import LocalData
  48. return LocalData(address.path)
  49. raise NotImplementedError(f"address type {type(address)} not supported with eggroll backend")
  50. @computing_profile
  51. def collect(self, **kwargs) -> list:
  52. return self._rp.get_all()
  53. @computing_profile
  54. def count(self, **kwargs) -> int:
  55. if self._count is None:
  56. self._count = self._rp.count()
  57. return self._count
  58. @computing_profile
  59. def take(self, n=1, **kwargs):
  60. options = dict(keys_only=False)
  61. return self._rp.take(n=n, options=options)
  62. @computing_profile
  63. def first(self):
  64. options = dict(keys_only=False)
  65. return self._rp.first(options=options)
  66. @computing_profile
  67. def map(self, func, **kwargs):
  68. return Table(self._rp.map(func))
  69. @computing_profile
  70. def mapValues(self, func: typing.Callable[[typing.Any], typing.Any], **kwargs):
  71. return Table(self._rp.map_values(func))
  72. @computing_profile
  73. def applyPartitions(self, func):
  74. return Table(self._rp.collapse_partitions(func))
  75. @computing_profile
  76. def mapPartitions(self, func, use_previous_behavior=True, preserves_partitioning=False, **kwargs):
  77. if use_previous_behavior is True:
  78. LOGGER.warning(f"please use `applyPartitions` instead of `mapPartitions` "
  79. f"if the previous behavior was expected. "
  80. f"The previous behavior will not work in future")
  81. return self.applyPartitions(func)
  82. return Table(self._rp.map_partitions(func, options={"shuffle": not preserves_partitioning}))
  83. @computing_profile
  84. def mapReducePartitions(self, mapper, reducer, **kwargs):
  85. return Table(self._rp.map_partitions(func=mapper, reduce_op=reducer))
  86. @computing_profile
  87. def mapPartitionsWithIndex(self, func, preserves_partitioning=False, **kwargs):
  88. return Table(self._rp.map_partitions_with_index(func, options={"shuffle": not preserves_partitioning}))
  89. @computing_profile
  90. def reduce(self, func, **kwargs):
  91. return self._rp.reduce(func)
  92. @computing_profile
  93. def join(self, other: 'Table', func, **kwargs):
  94. return Table(self._rp.join(other._rp, func=func))
  95. @computing_profile
  96. def glom(self, **kwargs):
  97. return Table(self._rp.glom())
  98. @computing_profile
  99. def sample(self, *, fraction: typing.Optional[float] = None, num: typing.Optional[int] = None, seed=None):
  100. if fraction is not None:
  101. return Table(self._rp.sample(fraction=fraction, seed=seed))
  102. if num is not None:
  103. total = self._rp.count()
  104. if num > total:
  105. raise ValueError(f"not enough data to sample, own {total} but required {num}")
  106. frac = num / float(total)
  107. while True:
  108. sampled_table = self._rp.sample(fraction=frac, seed=seed)
  109. sampled_count = sampled_table.count()
  110. if sampled_count < num:
  111. frac *= 1.1
  112. else:
  113. break
  114. if sampled_count > num:
  115. drops = sampled_table.take(sampled_count - num)
  116. for k, v in drops:
  117. sampled_table.delete(k)
  118. return Table(sampled_table)
  119. raise ValueError(f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}")
  120. @computing_profile
  121. def subtractByKey(self, other: 'Table', **kwargs):
  122. return Table(self._rp.subtract_by_key(other._rp))
  123. @computing_profile
  124. def filter(self, func, **kwargs):
  125. return Table(self._rp.filter(func))
  126. @computing_profile
  127. def union(self, other: 'Table', func=lambda v1, v2: v1, **kwargs):
  128. return Table(self._rp.union(other._rp, func=func))
  129. @computing_profile
  130. def flatMap(self, func, **kwargs):
  131. flat_map = self._rp.flat_map(func)
  132. shuffled = flat_map.map(lambda k, v: (k, v)) # trigger shuffle
  133. return Table(shuffled)