quantile_summaries.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import math
  19. import numpy as np
  20. from federatedml.util import consts, LOGGER
"""
Structure of a compressed sample: to save memory, each sample is stored as a
plain tuple (value, g, delta) in FATE >= v1.8.
"""

"""
Former class-based representation of a sample, kept here for reference only:

class Stats(object):
    def __init__(self, value, g: int, delta: int):
        self.value = value
        self.g = g
        self.delta = delta
"""
  31. class QuantileSummaries(object):
  32. def __init__(self, compress_thres=consts.DEFAULT_COMPRESS_THRESHOLD,
  33. head_size=consts.DEFAULT_HEAD_SIZE,
  34. error=consts.DEFAULT_RELATIVE_ERROR,
  35. abnormal_list=None):
  36. self.compress_thres = compress_thres
  37. self.head_size = head_size
  38. self.error = error
  39. self.head_sampled = []
  40. self.sampled = [] # list of Stats
  41. self.count = 0 # Total observations appeared
  42. self.missing_count = 0
  43. if abnormal_list is None:
  44. self.abnormal_list = []
  45. else:
  46. self.abnormal_list = abnormal_list
  47. # insert a number
  48. def insert(self, x):
  49. """
  50. Insert an observation of data. First store in a array buffer. If the buffer is full,
  51. do a batch insert. If the size of sampled list reach compress_thres, compress this list.
  52. Parameters
  53. ----------
  54. x : float
  55. The observation that prepare to insert
  56. """
  57. if x in self.abnormal_list or (isinstance(x, float) and np.isnan(x)):
  58. self.missing_count += 1
  59. return
  60. x = float(x)
  61. self.head_sampled.append(x)
  62. if len(self.head_sampled) >= self.head_size:
  63. self._insert_head_buffer()
  64. if len(self.sampled) >= self.compress_thres:
  65. self.compress()
  66. def _insert_head_buffer(self):
  67. if not len(self.head_sampled): # If empty
  68. return
  69. current_count = self.count
  70. sorted_head = sorted(self.head_sampled)
  71. head_len = len(sorted_head)
  72. sample_len = len(self.sampled)
  73. new_sampled = []
  74. sample_idx = 0
  75. ops_idx = 0
  76. while ops_idx < head_len:
  77. current_sample = sorted_head[ops_idx]
  78. while sample_idx < sample_len and self.sampled[sample_idx][0] <= current_sample:
  79. new_sampled.append(self.sampled[sample_idx])
  80. sample_idx += 1
  81. current_count += 1
  82. # If it is the first one to insert or if it is the last one
  83. if not new_sampled or (sample_idx == sample_len and
  84. ops_idx == head_len - 1):
  85. delta = 0
  86. else:
  87. # delta = math.floor(2 * self.error * current_count) - 1
  88. delta = math.floor(2 * self.error * current_count)
  89. new_sampled.append((current_sample, 1, delta))
  90. ops_idx += 1
  91. new_sampled += self.sampled[sample_idx:]
  92. self.sampled = new_sampled
  93. self.head_sampled = []
  94. self.count = current_count
  95. def compress(self):
  96. self._insert_head_buffer()
  97. # merge_threshold = math.floor(2 * self.error * self.count) - 1
  98. merge_threshold = 2 * self.error * self.count
  99. compressed = self._compress_immut(merge_threshold)
  100. self.sampled = compressed
  101. def merge(self, other):
  102. """
  103. merge current summeries with the other one.
  104. Parameters
  105. ----------
  106. other : QuantileSummaries
  107. The summaries to be merged
  108. """
  109. if other.head_sampled:
  110. # other._insert_head_buffer()
  111. other.compress()
  112. if self.head_sampled:
  113. # self._insert_head_buffer()
  114. self.compress()
  115. if other.count == 0:
  116. return self
  117. if self.count == 0:
  118. return other
  119. # merge two sorted array
  120. new_sample = []
  121. i, j = 0, 0
  122. self_sample_len = len(self.sampled)
  123. other_sample_len = len(other.sampled)
  124. while i < self_sample_len and j < other_sample_len:
  125. if self.sampled[i][0] < other.sampled[j][0]:
  126. new_sample.append(self.sampled[i])
  127. i += 1
  128. else:
  129. new_sample.append(other.sampled[j])
  130. j += 1
  131. new_sample += self.sampled[i:]
  132. new_sample += other.sampled[j:]
  133. res_summary = self.__class__(compress_thres=self.compress_thres,
  134. head_size=self.head_size,
  135. error=self.error,
  136. abnormal_list=self.abnormal_list)
  137. res_summary.count = self.count + other.count
  138. res_summary.missing_count = self.missing_count + other.missing_count
  139. res_summary.sampled = new_sample
  140. # self.sampled = new_sample
  141. # self.count += other.count
  142. # merge_threshold = math.floor(2 * self.error * self.count) - 1
  143. merge_threshold = 2 * self.error * res_summary.count
  144. res_summary.sampled = res_summary._compress_immut(merge_threshold)
  145. return res_summary
  146. def query(self, quantile):
  147. """
  148. Given the queried quantile, return the approximation guaranteed result
  149. Parameters
  150. ----------
  151. quantile : float [0.0, 1.0]
  152. The target quantile
  153. Returns
  154. -------
  155. float, the corresponding value result.
  156. """
  157. if self.head_sampled:
  158. # self._insert_head_buffer()
  159. self.compress()
  160. if quantile < 0 or quantile > 1:
  161. raise ValueError("Quantile should be in range [0.0, 1.0]")
  162. if self.count == 0:
  163. return 0
  164. if quantile <= self.error:
  165. return self.sampled[0][0]
  166. if quantile >= 1 - self.error:
  167. return self.sampled[-1][0]
  168. rank = math.ceil(quantile * self.count)
  169. target_error = math.ceil(self.error * self.count)
  170. min_rank = 0
  171. i = 1
  172. while i < len(self.sampled) - 1:
  173. cur_sample = self.sampled[i]
  174. min_rank += cur_sample[1]
  175. max_rank = min_rank + cur_sample[2]
  176. if max_rank - target_error <= rank <= min_rank + target_error:
  177. return cur_sample[0]
  178. i += 1
  179. return self.sampled[-1][0]
  180. def query_percentile_rate_list(self, percentile_rate_list):
  181. if self.head_sampled:
  182. self.compress()
  183. if np.min(percentile_rate_list) < 0 or np.max(percentile_rate_list) > 1:
  184. raise ValueError("Quantile should be in range [0.0, 1.0]")
  185. if self.count == 0:
  186. return [0] * len(percentile_rate_list)
  187. split_points = []
  188. i, j = 0, len(percentile_rate_list) - 1
  189. while i < len(percentile_rate_list) and percentile_rate_list[i] <= self.error:
  190. split_points.append(self.sampled[0][0])
  191. # split_points.append(self.sampled[0].value)
  192. i += 1
  193. while j >= 0 and percentile_rate_list[i] >= 1 - self.error:
  194. j -= 1
  195. k = 1
  196. min_rank = 0
  197. while i <= j:
  198. quantile = percentile_rate_list[i]
  199. rank = math.ceil(quantile * self.count)
  200. target_error = math.ceil(self.error * self.count)
  201. while k < len(self.sampled) - 1:
  202. # cur_sample = self.sampled[k]
  203. # min_rank += cur_sample.g
  204. # max_rank = min_rank + cur_sample.delta
  205. cur_sample_value = self.sampled[k][0]
  206. min_rank += self.sampled[k][1]
  207. max_rank = min_rank + self.sampled[k][2]
  208. if max_rank - target_error <= rank <= min_rank + target_error:
  209. split_points.append(cur_sample_value)
  210. min_rank -= self.sampled[k][1]
  211. break
  212. k += 1
  213. if k == len(self.sampled) - 1:
  214. # split_points.append(self.sampled[-1].value)
  215. split_points.append(self.sampled[-1][0])
  216. i += 1
  217. while j + 1 < len(percentile_rate_list):
  218. j += 1
  219. split_points.append(self.sampled[-1][0])
  220. assert len(percentile_rate_list) == len(split_points)
  221. return split_points
  222. def value_to_rank(self, value):
  223. min_rank, max_rank = 0, 0
  224. for sample in self.sampled:
  225. if sample[0] < value:
  226. min_rank += sample[1]
  227. max_rank = min_rank + sample[2]
  228. else:
  229. return (min_rank + max_rank) // 2
  230. return (min_rank + max_rank) // 2
  231. def query_value_list(self, values):
  232. """
  233. Given a sorted value list, return the rank of each element in this list
  234. """
  235. self.compress()
  236. res = []
  237. min_rank, max_rank = 0, 0
  238. idx = 0
  239. sample_idx = 0
  240. while sample_idx < len(self.sampled):
  241. v = values[idx]
  242. sample = self.sampled[sample_idx]
  243. if sample[0] <= v:
  244. min_rank += sample[1]
  245. max_rank = min_rank + sample[2]
  246. sample_idx += 1
  247. else:
  248. res.append((min_rank + max_rank) // 2)
  249. idx += 1
  250. if idx >= len(values):
  251. break
  252. while idx < len(values):
  253. res.append((min_rank + max_rank) // 2)
  254. idx += 1
  255. return res
  256. def _compress_immut(self, merge_threshold):
  257. if not self.sampled:
  258. return self.sampled
  259. res = []
  260. # Start from the last element
  261. head = self.sampled[-1]
  262. sum_g_delta = head[1] + head[2]
  263. i = len(self.sampled) - 2 # Do not merge the last element
  264. while i >= 1:
  265. this_sample = self.sampled[i]
  266. if this_sample[1] + sum_g_delta < merge_threshold:
  267. head = (head[0], head[1] + this_sample[1], head[2])
  268. sum_g_delta += this_sample[1]
  269. else:
  270. res.append(head)
  271. head = this_sample
  272. sum_g_delta = head[1] + head[2]
  273. i -= 1
  274. res.append(head)
  275. # If head of current sample is smaller than this new res's head
  276. # Add current head into res
  277. current_head = self.sampled[0]
  278. if current_head[0] <= head[0] and len(self.sampled) > 1:
  279. res.append(current_head)
  280. # Python do not support prepend, thus, use reverse instead
  281. res.reverse()
  282. return res
  283. class SparseQuantileSummaries(QuantileSummaries):
  284. def __init__(self, compress_thres=consts.DEFAULT_COMPRESS_THRESHOLD,
  285. head_size=consts.DEFAULT_HEAD_SIZE,
  286. error=consts.DEFAULT_RELATIVE_ERROR,
  287. abnormal_list=None):
  288. super(SparseQuantileSummaries, self).__init__(compress_thres, head_size, error, abnormal_list)
  289. # Compare with the sparse point, static the number of each part.
  290. self.smaller_num = 0
  291. self.bigger_num = 0
  292. self._total_count = 0
  293. def set_total_count(self, total_count):
  294. self._total_count = total_count
  295. return self
  296. @property
  297. def summary_count(self):
  298. return self._total_count - self.missing_count
  299. def insert(self, x):
  300. if x in self.abnormal_list or np.isnan(x):
  301. self.missing_count += 1
  302. return
  303. if x < consts.FLOAT_ZERO:
  304. self.smaller_num += 1
  305. elif x >= consts.FLOAT_ZERO:
  306. self.bigger_num += 1
  307. super(SparseQuantileSummaries, self).insert(x)
  308. def query(self, quantile):
  309. if self.zero_lower_bound < quantile < self.zero_upper_bound:
  310. return 0.0
  311. non_zero_quantile = self._convert_query_percentile(quantile)
  312. result = super(SparseQuantileSummaries, self).query(non_zero_quantile)
  313. return result
  314. def query_percentile_rate_list(self, percentile_rate_list):
  315. result = []
  316. non_zero_quantile_list = list()
  317. for quantile in percentile_rate_list:
  318. if self.zero_lower_bound < quantile < self.zero_upper_bound:
  319. result.append(0.0)
  320. else:
  321. non_zero_quantile_list.append(self._convert_query_percentile(quantile))
  322. if non_zero_quantile_list:
  323. result += super(SparseQuantileSummaries, self).query_percentile_rate_list(non_zero_quantile_list)
  324. return result
  325. def value_to_rank(self, value):
  326. quantile_rank = super().value_to_rank(value)
  327. zeros_count = self.zero_counts
  328. if value > 0:
  329. return quantile_rank + zeros_count
  330. elif value < 0:
  331. return quantile_rank
  332. else:
  333. return quantile_rank + zeros_count // 2
  334. def merge(self, other):
  335. assert isinstance(other, SparseQuantileSummaries)
  336. res_summary = super(SparseQuantileSummaries, self).merge(other)
  337. res_summary.smaller_num = self.smaller_num + other.smaller_num
  338. res_summary.bigger_num = self.bigger_num + other.bigger_num
  339. return res_summary
  340. def _convert_query_percentile(self, quantile):
  341. zeros_count = self.zero_counts
  342. if zeros_count == 0:
  343. return quantile
  344. if quantile <= self.zero_lower_bound:
  345. return ((self._total_count - self.missing_count) / self.count) * quantile
  346. return (quantile - self.zero_upper_bound + self.zero_lower_bound) / (
  347. 1 - self.zero_upper_bound + self.zero_lower_bound)
  348. @property
  349. def zero_lower_bound(self):
  350. if self.smaller_num == 0:
  351. return 0.0
  352. return self.smaller_num / (self._total_count - self.missing_count)
  353. @property
  354. def zero_upper_bound(self):
  355. if self.bigger_num == 0:
  356. return self._total_count - self.missing_count
  357. return (self.smaller_num + self.zero_counts) / (self._total_count - self.missing_count)
  358. @property
  359. def zero_counts(self):
  360. return self._total_count - self.smaller_num - self.bigger_num - self.missing_count
  361. def query_value_list(self, values):
  362. summary_ranks = super().query_value_list(values)
  363. res = []
  364. for v, r in zip(values, summary_ranks):
  365. if v == 0:
  366. res.append(self.smaller_num)
  367. elif v < 0:
  368. res.append(r)
  369. else:
  370. res.append(r + self.zero_counts)
  371. return res
  372. def quantile_summary_factory(is_sparse, param_dict):
  373. if is_sparse:
  374. return SparseQuantileSummaries(**param_dict)
  375. else:
  376. return QuantileSummaries(**param_dict)