_table.py

#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import typing
import uuid
from itertools import chain

import pyspark
from pyspark.rddsampler import RDDSamplerBase
from scipy.stats import hypergeom

from fate_arch.abc import CTableABC
from fate_arch.common import log, hdfs_utils, hive_utils
from fate_arch.common.profile import computing_profile
from fate_arch.computing._type import ComputingEngine
from fate_arch.computing.spark._materialize import materialize, unmaterialize

LOGGER = log.getLogger()


class Table(CTableABC):
    def __init__(self, rdd):
        self._rdd: pyspark.RDD = rdd
        self._engine = ComputingEngine.SPARK

        self._count = None

    @property
    def engine(self):
        return self._engine

    def __getstate__(self):
        pass

    def __del__(self):
        try:
            unmaterialize(self._rdd)
            del self._rdd
        except BaseException:
            return

    def copy(self):
        """the rdd itself is immutable, yet the values it holds may still be modified in place in some cases"""
        return Table(_map_value(self._rdd, lambda x: x))

    @computing_profile
    def save(self, address, partitions, schema, **kwargs):
        from fate_arch.common.address import HDFSAddress

        if isinstance(address, HDFSAddress):
            self._rdd.map(lambda x: hdfs_utils.serialize(x[0], x[1])).repartition(
                partitions
            ).saveAsTextFile(f"{address.name_node}/{address.path}")
            schema.update(self.schema)
            return

        from fate_arch.common.address import HiveAddress, LinkisHiveAddress

        if isinstance(address, (HiveAddress, LinkisHiveAddress)):
            # df = (
            #     self._rdd.map(lambda x: hive_utils.to_row(x[0], x[1]))
            #     .repartition(partitions)
            #     .toDF()
            # )
            LOGGER.debug(f"partitions: {partitions}")
            _repartition = self._rdd.map(lambda x: hive_utils.to_row(x[0], x[1])).repartition(partitions)
            _repartition.toDF().write.saveAsTable(f"{address.database}.{address.name}")
            schema.update(self.schema)
            return

        from fate_arch.common.address import LocalFSAddress

        if isinstance(address, LocalFSAddress):
            self._rdd.map(lambda x: hdfs_utils.serialize(x[0], x[1])).repartition(
                partitions
            ).saveAsTextFile(address.path)
            schema.update(self.schema)
            return

        raise NotImplementedError(
            f"address type {type(address)} not supported with spark backend"
        )

    @property
    def partitions(self):
        return self._rdd.getNumPartitions()

    @computing_profile
    def map(self, func, **kwargs):
        return from_rdd(_map(self._rdd, func))

    @computing_profile
    def mapValues(self, func, **kwargs):
        return from_rdd(_map_value(self._rdd, func))

    @computing_profile
    def mapPartitions(
        self, func, use_previous_behavior=True, preserves_partitioning=False, **kwargs
    ):
        if use_previous_behavior is True:
            LOGGER.warning(
                "please use `applyPartitions` instead of `mapPartitions` "
                "if the previous behavior is expected. "
                "The previous behavior will not be supported in the future."
            )
            return self.applyPartitions(func)

        return from_rdd(
            self._rdd.mapPartitions(func, preservesPartitioning=preserves_partitioning)
        )

    @computing_profile
    def mapReducePartitions(self, mapper, reducer, **kwargs):
        return from_rdd(self._rdd.mapPartitions(mapper).reduceByKey(reducer))

    @computing_profile
    def applyPartitions(self, func, **kwargs):
        return from_rdd(_map_partitions(self._rdd, func))

    @computing_profile
    def mapPartitionsWithIndex(self, func, preserves_partitioning=False, **kwargs):
        return from_rdd(
            self._rdd.mapPartitionsWithIndex(func, preservesPartitioning=preserves_partitioning)
        )

    @computing_profile
    def glom(self, **kwargs):
        return from_rdd(_glom(self._rdd))

    @computing_profile
    def sample(
        self,
        *,
        fraction: typing.Optional[float] = None,
        num: typing.Optional[int] = None,
        seed=None,
    ):
        if fraction is not None:
            return from_rdd(
                self._rdd.sample(fraction=fraction, withReplacement=False, seed=seed)
            )

        if num is not None:
            return from_rdd(_exactly_sample(self._rdd, num, seed=seed))

        raise ValueError(
            f"exactly one of `fraction` or `num` required, fraction={fraction}, num={num}"
        )

    @computing_profile
    def filter(self, func, **kwargs):
        return from_rdd(_filter(self._rdd, func))

    @computing_profile
    def flatMap(self, func, **kwargs):
        return from_rdd(_flat_map(self._rdd, func))

    @computing_profile
    def reduce(self, func, **kwargs):
        return self._rdd.values().reduce(func)

    @computing_profile
    def collect(self, **kwargs):
        # return iter(self._rdd.collect())
        return self._rdd.toLocalIterator()

    @computing_profile
    def take(self, n=1, **kwargs):
        _value = self._rdd.take(n)
        if kwargs.get("filter", False):
            self._rdd = self._rdd.filter(lambda xy: xy not in [_xy for _xy in _value])
        return _value

    @computing_profile
    def first(self, **kwargs):
        return self.take(1)[0]

    @computing_profile
    def count(self, **kwargs):
        if self._count is None:
            self._count = self._rdd.count()
        return self._count

    @computing_profile
    def join(self, other: "Table", func=None, **kwargs):
        return from_rdd(_join(self._rdd, other._rdd, func=func))

    @computing_profile
    def subtractByKey(self, other: "Table", **kwargs):
        return from_rdd(_subtract_by_key(self._rdd, other._rdd))

    @computing_profile
    def union(self, other: "Table", func=None, **kwargs):
        return from_rdd(_union(self._rdd, other._rdd, func))
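

# A minimal usage sketch (illustrative only, not part of the module): assuming a
# live SparkContext and small in-memory key/value pairs, a Table built with
# `from_rdd` supports the key-aware operators defined above, e.g.
#
#     sc = pyspark.SparkContext.getOrCreate()
#     left = from_rdd(sc.parallelize([("a", 1), ("b", 2)]))
#     right = from_rdd(sc.parallelize([("b", 20), ("c", 30)]))
#     doubled = left.mapValues(lambda v: v * 2)               # ("a", 2), ("b", 4)
#     joined = left.join(right, func=lambda v1, v2: v1 + v2)  # ("b", 22)
#     print(list(joined.collect()))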


def from_hdfs(paths: str, partitions, in_serialized=True, id_delimiter=None):
    # noinspection PyPackageRequirements
    from pyspark import SparkContext

    sc = SparkContext.getOrCreate()
    fun = hdfs_utils.deserialize if in_serialized else lambda x: (x.partition(id_delimiter)[0],
                                                                  x.partition(id_delimiter)[2])
    rdd = materialize(
        sc.textFile(paths, partitions)
        .map(fun)
        .repartition(partitions)
    )
    return Table(rdd=rdd)


def from_localfs(paths: str, partitions, in_serialized=True, id_delimiter=None):
    # noinspection PyPackageRequirements
    from pyspark import SparkContext

    sc = SparkContext.getOrCreate()
    fun = hdfs_utils.deserialize if in_serialized else lambda x: (x.partition(id_delimiter)[0],
                                                                  x.partition(id_delimiter)[2])
    rdd = materialize(
        sc.textFile(paths, partitions)
        .map(fun)
        .repartition(partitions)
    )
    return Table(rdd=rdd)


def from_hive(tb_name, db_name, partitions):
    from pyspark.sql import SparkSession

    session = SparkSession.builder.enableHiveSupport().getOrCreate()
    rdd = materialize(
        session.sql(f"select * from {db_name}.{tb_name}")
        .rdd.map(hive_utils.from_row)
        .repartition(partitions)
    )
    return Table(rdd=rdd)


def from_rdd(rdd):
    rdd = materialize(rdd)
    return Table(rdd=rdd)
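

# Illustrative only: a hedged sketch of loading a non-serialized, delimiter-split
# text file through `from_hdfs` (the name node URL and path are placeholders):
#
#     table = from_hdfs(
#         "hdfs://namenode:9000/data/guest.csv",
#         partitions=4,
#         in_serialized=False,
#         id_delimiter=",",
#     )
#     print(table.partitions, table.take(3))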


def _fail_on_stopiteration(fn):
    # noinspection PyPackageRequirements
    from pyspark import util

    return util.fail_on_stopiteration(fn)


def _map(rdd, func):
    def _fn(x):
        return func(x[0], x[1])

    def _func(_, iterator):
        return map(_fail_on_stopiteration(_fn), iterator)

    return rdd.mapPartitionsWithIndex(_func, preservesPartitioning=False)


def _map_value(rdd, func):
    def _fn(x):
        return x[0], func(x[1])

    def _func(_, iterator):
        return map(_fail_on_stopiteration(_fn), iterator)

    return rdd.mapPartitionsWithIndex(_func, preservesPartitioning=True)


def _map_partitions(rdd, func):
    def _func(_, iterator):
        return [(str(uuid.uuid1()), func(iterator))]

    return rdd.mapPartitionsWithIndex(_func, preservesPartitioning=False)


def _join(rdd, other, func=None):
    num_partitions = max(rdd.getNumPartitions(), other.getNumPartitions())
    rtn_rdd = rdd.join(other, numPartitions=num_partitions)

    if func is not None:
        rtn_rdd = _map_value(rtn_rdd, lambda x: func(x[0], x[1]))

    return rtn_rdd


def _glom(rdd):
    def _func(_, iterator):
        yield list(iterator)

    return rdd.mapPartitionsWithIndex(_func)


def _exactly_sample(rdd, num: int, seed: int):
    split_size = rdd.mapPartitionsWithIndex(
        lambda s, it: [(s, sum(1 for _ in it))]
    ).collectAsMap()
    total = sum(split_size.values())

    if num > total:
        raise ValueError(f"not enough data to sample, have {total} but require {num}")

    # randomly draw the sample size of each split; the draws sum to exactly `num`
    sampled_size = {}
    for split, size in split_size.items():
        sampled_size[split] = hypergeom.rvs(M=total, n=size, N=num)
        total = total - size
        num = num - sampled_size[split]

    return rdd.mapPartitionsWithIndex(
        _ReservoirSample(split_sample_size=sampled_size, seed=seed).func,
        preservesPartitioning=True,
    )
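

# Note on the exact-sampling scheme: `_exactly_sample` uses the hypergeometric
# draws above to decide how many of the requested records each partition should
# contribute, and `_ReservoirSample` then keeps exactly that many per partition
# via reservoir sampling: the first `size` records fill the reservoir, and each
# later record replaces a uniformly chosen slot with probability size / counter.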


class _ReservoirSample(RDDSamplerBase):
    def __init__(self, split_sample_size, seed):
        RDDSamplerBase.__init__(self, False, seed)
        self._split_sample_size = split_sample_size
        self._counter = 0
        self._sample = []

    def func(self, split, iterator):
        self.initRandomGenerator(split)
        size = self._split_sample_size[split]
        for obj in iterator:
            self._counter += 1
            if len(self._sample) < size:
                self._sample.append(obj)
                continue
            randint = self._random.randint(1, self._counter)
            if randint <= size:
                self._sample[randint - 1] = obj

        return self._sample


def _filter(rdd, func):
    def _fn(x):
        return func(x[0], x[1])

    def _func(_, iterator):
        return filter(_fail_on_stopiteration(_fn), iterator)

    return rdd.mapPartitionsWithIndex(_func, preservesPartitioning=True)


def _subtract_by_key(rdd, other):
    return rdd.subtractByKey(other, rdd.getNumPartitions())


def _union(rdd, other, func):
    num_partition = max(rdd.getNumPartitions(), other.getNumPartitions())

    if func is None:
        return rdd.union(other).coalesce(num_partition)
    else:

        def _func(pair):
            iter1, iter2 = pair
            val1 = list(iter1)
            val2 = list(iter2)
            if not val1:
                return val2[0]
            if not val2:
                return val1[0]
            return func(val1[0], val2[0])

        return _map_value(rdd.cogroup(other, num_partition), _func)


def _flat_map(rdd, func):
    def _fn(x):
        return func(x[0], x[1])

    def _func(_, iterator):
        return chain.from_iterable(map(_fail_on_stopiteration(_fn), iterator))

    return rdd.mapPartitionsWithIndex(_func, preservesPartitioning=False)