_table.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
#
# Copyright 2019 The Eggroll Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. import itertools
  17. import typing
  18. from fate_arch.abc import CTableABC
  19. from fate_arch.common import log
  20. from fate_arch.common.profile import computing_profile
  21. from fate_arch.computing._type import ComputingEngine
# Module-level logger; used below to warn about the deprecated
# `mapPartitions` behaviour.
LOGGER = log.getLogger()
  23. class Table(CTableABC):
  24. def __init__(self, table):
  25. self._table = table
  26. self._engine = ComputingEngine.STANDALONE
  27. self._count = None
  28. @property
  29. def engine(self):
  30. return self._engine
  31. def __getstate__(self):
  32. pass
  33. @property
  34. def partitions(self):
  35. return self._table.partitions
  36. def copy(self):
  37. return Table(self._table.mapValues(lambda x: x))
  38. @computing_profile
  39. def save(self, address, partitions, schema, **kwargs):
  40. from fate_arch.common.address import StandaloneAddress
  41. if isinstance(address, StandaloneAddress):
  42. self._table.save_as(
  43. name=address.name,
  44. namespace=address.namespace,
  45. partition=partitions,
  46. need_cleanup=False,
  47. )
  48. schema.update(self.schema)
  49. return
  50. from fate_arch.common.address import PathAddress
  51. if isinstance(address, PathAddress):
  52. from fate_arch.computing.non_distributed import LocalData
  53. return LocalData(address.path)
  54. raise NotImplementedError(
  55. f"address type {type(address)} not supported with standalone backend"
  56. )
  57. @computing_profile
  58. def count(self) -> int:
  59. if self._count is None:
  60. self._count = self._table.count()
  61. return self._count
  62. @computing_profile
  63. def collect(self, **kwargs):
  64. return self._table.collect(**kwargs)
  65. @computing_profile
  66. def take(self, n=1, **kwargs):
  67. return self._table.take(n=n, **kwargs)
  68. @computing_profile
  69. def first(self, **kwargs):
  70. resp = list(itertools.islice(self._table.collect(**kwargs), 1))
  71. if len(resp) < 1:
  72. raise RuntimeError("table is empty")
  73. return resp[0]
  74. @computing_profile
  75. def reduce(self, func, **kwargs):
  76. return self._table.reduce(func)
  77. @computing_profile
  78. def map(self, func):
  79. return Table(self._table.map(func))
  80. @computing_profile
  81. def mapValues(self, func):
  82. return Table(self._table.mapValues(func))
  83. @computing_profile
  84. def flatMap(self, func):
  85. return Table(self._table.flatMap(func))
  86. @computing_profile
  87. def applyPartitions(self, func):
  88. return Table(self._table.applyPartitions(func))
  89. @computing_profile
  90. def mapPartitions(
  91. self, func, use_previous_behavior=True, preserves_partitioning=False
  92. ):
  93. if use_previous_behavior is True:
  94. LOGGER.warning(
  95. "please use `applyPartitions` instead of `mapPartitions` "
  96. "if the previous behavior was expected. "
  97. "The previous behavior will not work in future"
  98. )
  99. return Table(self._table.applyPartitions(func))
  100. return Table(
  101. self._table.mapPartitions(
  102. func, preserves_partitioning=preserves_partitioning
  103. )
  104. )
  105. @computing_profile
  106. def mapReducePartitions(self, mapper, reducer, **kwargs):
  107. return Table(self._table.mapReducePartitions(mapper, reducer))
  108. @computing_profile
  109. def mapPartitionsWithIndex(self, func, preserves_partitioning=False, **kwargs):
  110. return Table(
  111. self._table.mapPartitionsWithIndex(
  112. func, preserves_partitioning=preserves_partitioning
  113. )
  114. )
  115. @computing_profile
  116. def glom(self):
  117. return Table(self._table.glom())
  118. @computing_profile
  119. def sample(
  120. self,
  121. *,
  122. fraction: typing.Optional[float] = None,
  123. num: typing.Optional[int] = None,
  124. seed=None,
  125. ):
  126. if fraction is not None:
  127. return Table(self._table.sample(fraction=fraction, seed=seed))
  128. if num is not None:
  129. total = self._table.count()
  130. if num > total:
  131. raise ValueError(
  132. f"not enough data to sample, own {total} but required {num}"
  133. )
  134. frac = num / float(total)
  135. while True:
  136. sampled_table = self._table.sample(fraction=frac, seed=seed)
  137. sampled_count = sampled_table.count()
  138. if sampled_count < num:
  139. frac += 0.1
  140. else:
  141. break
  142. if sampled_count > num:
  143. drops = sampled_table.take(sampled_count - num)
  144. for k, v in drops:
  145. sampled_table.delete(k)
  146. return Table(sampled_table)
  147. raise ValueError(
  148. f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}"
  149. )
  150. @computing_profile
  151. def filter(self, func):
  152. return Table(self._table.filter(func))
  153. @computing_profile
  154. def join(self, other: "Table", func):
  155. return Table(self._table.join(other._table, func))
  156. @computing_profile
  157. def subtractByKey(self, other: "Table"):
  158. return Table(self._table.subtractByKey(other._table))
  159. @computing_profile
  160. def union(self, other: "Table", func=lambda v1, v2: v1):
  161. return Table(self._table.union(other._table, func))