hetero_binning_guest.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. import copy
  17. import functools
  18. import numpy as np
  19. from federatedml.cipher_compressor.packer import GuestIntegerPacker
  20. from federatedml.feature.binning.iv_calculator import IvCalculator
  21. from federatedml.secureprotol.encrypt_mode import EncryptModeCalculator
  22. from federatedml.feature.binning.optimal_binning.optimal_binning import OptimalBinning
  23. from federatedml.feature.hetero_feature_binning.base_feature_binning import BaseFeatureBinning
  24. from federatedml.secureprotol import PaillierEncrypt
  25. from federatedml.secureprotol.fate_paillier import PaillierEncryptedNumber
  26. from federatedml.statistic import data_overview
  27. from federatedml.statistic import statics
  28. from federatedml.util import LOGGER
  29. from federatedml.util import consts
class HeteroFeatureBinningGuest(BaseFeatureBinning):
    """Guest-side hetero feature binning component, extending BaseFeatureBinning."""

    def __init__(self):
        super().__init__()
        # Packer used to pack and encrypt the labels sent to the hosts and to
        # decrypt and unpack the sums they return; it is created in federated_iv.
        self._packer: GuestIntegerPacker = None
  34. def fit(self, data_instances):
  35. """
  36. Apply binning method for both data instances in local party as well as the other one. Afterwards, calculate
  37. the specific metric value for specific columns. Currently, iv is support for binary labeled data only.
  38. """
  39. LOGGER.info("Start feature binning fit and transform")
  40. self._abnormal_detection(data_instances)
  41. # self._parse_cols(data_instances)
  42. self._setup_bin_inner_param(data_instances, self.model_param)
  43. split_points_obj = None
  44. if self.model_param.method == consts.OPTIMAL:
  45. has_missing_value = self.iv_calculator.check_containing_missing_value(data_instances)
  46. for idx in self.bin_inner_param.bin_indexes:
  47. if idx in has_missing_value:
  48. raise ValueError(f"Optimal Binning do not support missing value now.")
  49. if self.model_param.split_points_by_col_name or self.model_param.split_points_by_index:
  50. split_points = self._get_manual_split_points(data_instances)
  51. self.use_manual_split_points = True
  52. for col_name, sp in split_points.items():
  53. self.binning_obj.bin_results.put_col_split_points(col_name, sp)
  54. else:
  55. split_points = self.binning_obj.fit_split_points(data_instances)
  56. split_points_obj = self.binning_obj.bin_results
  57. if self.model_param.skip_static:
  58. self.transform_data(data_instances)
  59. return self.data_output
  60. label_counts_dict, label_counts, label_table = self.stat_label(data_instances)
  61. self.bin_result = self.cal_local_iv(data_instances, split_points, label_counts, label_table)
  62. if self.model_param.method == consts.OPTIMAL and split_points_obj is not None:
  63. # LOGGER.debug(f"set optimal metric array")
  64. self.set_optimal_metric_array(split_points_obj.all_optimal_metric)
  65. if self.model_param.local_only:
  66. self.transform_data(data_instances)
  67. self.set_summary(self.bin_result.summary())
  68. return self.data_output
  69. self.host_results = self.federated_iv(
  70. data_instances=data_instances,
  71. label_table=label_table,
  72. result_counts=label_counts_dict,
  73. label_elements=self.labels,
  74. label_counts=label_counts)
  75. total_summary = self.bin_result.summary()
  76. for host_res in self.host_results:
  77. total_summary = self._merge_summary(total_summary, host_res.summary())
  78. self.set_schema(data_instances)
  79. self.transform_data(data_instances)
  80. LOGGER.info("Finish feature binning fit and transform")
  81. self.set_summary(total_summary)
  82. return self.data_output
  83. def transform(self, data_instances):
  84. if self.model_param.skip_static:
  85. self.transform_data(data_instances)
  86. return self.data_output
  87. has_label = True
  88. if data_instances.first()[1].label is None:
  89. has_label = False
  90. self.transfer_variable.transform_stage_has_label.remote(has_label,
  91. role=consts.HOST,
  92. idx=-1)
  93. if not has_label:
  94. self.transform_data(data_instances)
  95. return self.data_output
  96. self._setup_bin_inner_param(data_instances, self.model_param)
  97. label_counts_dict, label_counts, label_table = self.stat_label(data_instances)
  98. if (set(self.labels) & set(label_counts_dict)) != set(label_counts_dict):
  99. raise ValueError(f"Label {set(self.labels) - set(label_counts_dict)} can not be recognized")
  100. split_points = self.binning_obj.bin_results.all_split_points
  101. self.transform_bin_result = self.cal_local_iv(data_instances, split_points, label_counts, label_table)
  102. if self.model_param.local_only:
  103. self.transform_data(data_instances)
  104. self.set_summary(self.bin_result.summary())
  105. return self.data_output
  106. self.transform_host_results = self.federated_iv(data_instances=data_instances,
  107. label_table=label_table,
  108. result_counts=label_counts_dict,
  109. label_elements=self.labels,
  110. label_counts=label_counts)
  111. total_summary = self.transform_bin_result.summary()
  112. for host_res in self.transform_host_results:
  113. total_summary = self._merge_summary(total_summary, host_res.summary())
  114. self.set_schema(data_instances)
  115. self.transform_data(data_instances)
  116. LOGGER.info("Finish feature binning fit and transform")
  117. self.set_summary(total_summary)
  118. return self.data_output
  119. def stat_label(self, data_instances):
  120. label_counts_dict = data_overview.get_label_count(data_instances)
  121. if len(label_counts_dict) > 2:
  122. if self.model_param.method == consts.OPTIMAL:
  123. raise ValueError("Have not supported optimal binning in multi-class data yet")
  124. if self._stage == "fit":
  125. self.labels = list(label_counts_dict.keys())
  126. self.labels.sort()
  127. self.labels.reverse()
  128. label_counts = [label_counts_dict.get(k, 0) for k in self.labels]
  129. label_table = IvCalculator.convert_label(data_instances, self.labels)
  130. return label_counts_dict, label_counts, label_table
  131. def cal_local_iv(self, data_instances, split_points, label_counts, label_table):
  132. bin_result = self.iv_calculator.cal_local_iv(data_instances=data_instances,
  133. split_points=split_points,
  134. labels=self.labels,
  135. label_counts=label_counts,
  136. bin_cols_map=self.bin_inner_param.get_need_cal_iv_cols_map(),
  137. label_table=label_table)
  138. return bin_result
  139. def federated_iv(self, data_instances, label_table, result_counts, label_elements, label_counts):
  140. if self.model_param.encrypt_param.method == consts.PAILLIER:
  141. paillier_encryptor = PaillierEncrypt()
  142. paillier_encryptor.generate_key(self.model_param.encrypt_param.key_length)
  143. else:
  144. raise NotImplementedError("encrypt method not supported yet")
  145. self._packer = GuestIntegerPacker(pack_num=len(self.labels), pack_num_range=label_counts,
  146. encrypter=paillier_encryptor)
  147. converted_label_table = label_table.mapValues(lambda x: [int(i) for i in x])
  148. encrypted_label_table = self._packer.pack_and_encrypt(converted_label_table)
  149. self.transfer_variable.encrypted_label.remote(encrypted_label_table,
  150. role=consts.HOST,
  151. idx=-1)
  152. encrypted_bin_sum_infos = self.transfer_variable.encrypted_bin_sum.get(idx=-1)
  153. encrypted_bin_infos = self.transfer_variable.optimal_info.get(idx=-1)
  154. LOGGER.info("Get encrypted_bin_sum from host")
  155. host_results = []
  156. for host_idx, encrypted_bin_info in enumerate(encrypted_bin_infos):
  157. host_party_id = self.component_properties.host_party_idlist[host_idx]
  158. encrypted_bin_sum = encrypted_bin_sum_infos[host_idx]
  159. # assert 1 == 2, f"encrypted_bin_sum: {list(encrypted_bin_sum.collect())}"
  160. result_counts_table = self._packer.decrypt_cipher_package_and_unpack(encrypted_bin_sum)
  161. # LOGGER.debug(f"unpack result: {result_counts_table.first()}")
  162. bin_result = self.cal_bin_results(data_instances=data_instances,
  163. host_idx=host_idx,
  164. encrypted_bin_info=encrypted_bin_info,
  165. result_counts_table=result_counts_table,
  166. result_counts=result_counts,
  167. label_elements=label_elements)
  168. bin_result.set_role_party(role=consts.HOST, party_id=host_party_id)
  169. host_results.append(bin_result)
  170. return host_results
  171. def host_optimal_binning(self, data_instances, host_idx, encrypted_bin_info, result_counts, category_names):
  172. optimal_binning_params = encrypted_bin_info['optimal_params']
  173. host_model_params = copy.deepcopy(self.model_param)
  174. host_model_params.bin_num = optimal_binning_params.get('bin_num')
  175. host_model_params.optimal_binning_param.metric_method = optimal_binning_params.get('metric_method')
  176. host_model_params.optimal_binning_param.mixture = optimal_binning_params.get('mixture')
  177. host_model_params.optimal_binning_param.max_bin_pct = optimal_binning_params.get('max_bin_pct')
  178. host_model_params.optimal_binning_param.min_bin_pct = optimal_binning_params.get('min_bin_pct')
  179. event_total, non_event_total = self.get_histogram(data_instances)
  180. result_counts = dict(result_counts.collect())
  181. optimal_binning_cols = {x: y for x, y in result_counts.items() if x not in category_names}
  182. host_binning_obj = OptimalBinning(params=host_model_params, abnormal_list=self.binning_obj.abnormal_list)
  183. host_binning_obj.event_total = event_total
  184. host_binning_obj.non_event_total = non_event_total
  185. host_binning_obj = self.optimal_binning_sync(host_binning_obj, optimal_binning_cols, data_instances.count(),
  186. data_instances.partitions,
  187. host_idx)
  188. return host_binning_obj
  189. def cal_bin_results(self, data_instances, host_idx, encrypted_bin_info, result_counts_table,
  190. result_counts, label_elements):
  191. host_bin_methods = encrypted_bin_info['bin_method']
  192. category_names = encrypted_bin_info['category_names']
  193. result_counts_dict = dict(result_counts_table.collect())
  194. host_party_id = self.component_properties.host_party_idlist[host_idx]
  195. if host_bin_methods == consts.OPTIMAL and self._stage == "fit":
  196. if len(result_counts) > 2:
  197. raise ValueError("Have not supported optimal binning in multi-class data yet")
  198. host_binning_obj = self.host_optimal_binning(data_instances, host_idx,
  199. encrypted_bin_info, result_counts_table,
  200. category_names)
  201. optimal_counts = {}
  202. for col_name, bucket_list in host_binning_obj.bucket_lists.items():
  203. optimal_counts[col_name] = [np.array([b.event_count, b.non_event_count]) for b in bucket_list]
  204. for col_name, counts in result_counts_dict.items():
  205. if col_name in category_names:
  206. optimal_counts[col_name] = counts
  207. # LOGGER.debug(f"optimal_counts: {optimal_counts}")
  208. bin_res = self.iv_calculator.cal_iv_from_counts(optimal_counts, labels=label_elements,
  209. role=consts.HOST, party_id=host_party_id)
  210. else:
  211. bin_res = self.iv_calculator.cal_iv_from_counts(result_counts_table,
  212. label_elements,
  213. role=consts.HOST,
  214. party_id=host_party_id)
  215. return bin_res
  216. @staticmethod
  217. def convert_decompress_format(encrypted_bin_sum):
  218. """
  219. Parameters
  220. ----------
  221. encrypted_bin_sum : dict.
  222. {"keys": ['x1', 'x2' ...],
  223. "event_counts": [...],
  224. "non_event_counts": [...],
  225. bin_num": [...]
  226. }
  227. returns
  228. -------
  229. {'x1': [[event_count, non_event_count], [event_count, non_event_count] ... ],
  230. 'x2': [[event_count, non_event_count], [event_count, non_event_count] ... ],
  231. ...
  232. }
  233. """
  234. result = {}
  235. start = 0
  236. event_counts = [int(x) for x in encrypted_bin_sum['event_counts']]
  237. non_event_counts = [int(x) for x in encrypted_bin_sum['non_event_counts']]
  238. for idx, k in enumerate(encrypted_bin_sum["keys"]):
  239. bin_num = encrypted_bin_sum["bin_nums"][idx]
  240. result[k] = list(zip(event_counts[start: start + bin_num], non_event_counts[start: start + bin_num]))
  241. start += bin_num
  242. assert start == len(event_counts) == len(non_event_counts), \
  243. f"Length of event/non-event does not match " \
  244. f"with bin_num sums, all_counts: {start}, length of event count: {len(event_counts)}," \
  245. f"length of non_event_counts: {len(non_event_counts)}"
  246. return result
  247. @staticmethod
  248. def _merge_summary(summary_1, summary_2):
  249. def merge_single_label(s1, s2):
  250. res = {}
  251. for k, v in s1.items():
  252. if k == 'iv':
  253. v.extend(s2[k])
  254. v = sorted(v, key=lambda p: p[1], reverse=True)
  255. else:
  256. v.update(s2[k])
  257. res[k] = v
  258. return res
  259. res = {}
  260. for label, s1 in summary_1.items():
  261. s2 = summary_2.get(label)
  262. res[label] = merge_single_label(s1, s2)
  263. return res
  264. @staticmethod
  265. def encrypt(x, cipher):
  266. if not isinstance(x, np.ndarray):
  267. return cipher.encrypt(x)
  268. res = []
  269. for idx, value in enumerate(x):
  270. res.append(cipher.encrypt(value))
  271. return np.array(res)
  272. @staticmethod
  273. def __decrypt_bin_sum(encrypted_bin_sum, cipher):
  274. def decrypt(values):
  275. res = []
  276. for counts in values:
  277. for idx, c in enumerate(counts):
  278. if isinstance(c, PaillierEncryptedNumber):
  279. counts[idx] = cipher.decrypt(c)
  280. res.append(counts)
  281. return res
  282. return encrypted_bin_sum.mapValues(decrypt)
  283. @staticmethod
  284. def load_data(data_instance):
  285. data_instance = copy.deepcopy(data_instance)
  286. # Here suppose this is a binary question and the event label is 1
  287. if data_instance.label != 1:
  288. data_instance.label = 0
  289. return data_instance
  290. def optimal_binning_sync(self, host_binning_obj, result_counts, sample_count, partitions, host_idx):
  291. LOGGER.debug("Start host party optimal binning train")
  292. bucket_table = host_binning_obj.bin_sum_to_bucket_list(result_counts, partitions)
  293. host_binning_obj.fit_buckets(bucket_table, sample_count)
  294. encoded_split_points = host_binning_obj.bin_results.all_split_points
  295. self.transfer_variable.bucket_idx.remote(encoded_split_points,
  296. role=consts.HOST,
  297. idx=host_idx)
  298. return host_binning_obj
  299. @staticmethod
  300. def get_histogram(data_instances):
  301. static_obj = statics.MultivariateStatisticalSummary(data_instances, cols_index=-1)
  302. label_historgram = static_obj.get_label_histogram()
  303. event_total = label_historgram.get(1, 0)
  304. non_event_total = label_historgram.get(0, 0)
  305. if event_total == 0 or non_event_total == 0:
  306. LOGGER.warning(f"event_total or non_event_total might have errors, event_total: {event_total},"
  307. f" non_event_total: {non_event_total}")
  308. return event_total, non_event_total