# boosting_tree_predict.py (21 KB)


import functools
import random
from typing import List

import numpy as np

from federatedml.ensemble.basic_algorithms import HeteroDecisionTreeGuest, HeteroDecisionTreeHost, \
    HeteroFastDecisionTreeGuest, HeteroFastDecisionTreeHost
from federatedml.ensemble.basic_algorithms.decision_tree.tree_core.decision_tree import DecisionTree, Node
from federatedml.secureprotol import PaillierEncrypt
from federatedml.transfer_variable.transfer_class.hetero_secure_boosting_predict_transfer_variable import \
    HeteroSecureBoostTransferVariable
from federatedml.util import LOGGER
from federatedml.util import consts
  13. """
  14. Hetero guest predict utils
  15. """
  16. def generate_leaf_pos_dict(x, tree_num, np_int_type=np.int8):
  17. """
  18. x: just occupy the first parameter position
  19. return: a numpy array record sample pos, and a counter counting how many trees reach a leaf node
  20. """
  21. node_pos = np.zeros(tree_num, dtype=np_int_type)
  22. reach_leaf_node = np.zeros(tree_num, dtype=np.bool)
  23. return {'node_pos': node_pos, 'reach_leaf_node': reach_leaf_node}
  24. def guest_traverse_a_tree(tree: HeteroDecisionTreeGuest, sample, cur_node_idx):
  25. reach_leaf = False
  26. # only need nid here, predict state is not needed
  27. rs = tree.traverse_tree(tree_=tree.tree_node, data_inst=sample, predict_state=(cur_node_idx, -1),
  28. decoder=tree.decode, sitename=tree.sitename, use_missing=tree.use_missing,
  29. split_maskdict=tree.split_maskdict, missing_dir_maskdict=tree.missing_dir_maskdict,
  30. zero_as_missing=tree.zero_as_missing, return_leaf_id=True)
  31. if not isinstance(rs, tuple):
  32. reach_leaf = True
  33. leaf_id = rs
  34. return leaf_id, reach_leaf
  35. else:
  36. cur_node_idx = rs[0]
  37. return cur_node_idx, reach_leaf
  38. def guest_traverse_trees(node_pos, sample, trees: List[HeteroDecisionTreeGuest]):
  39. if node_pos['reach_leaf_node'].all():
  40. return node_pos
  41. for t_idx, tree in enumerate(trees):
  42. cur_node_idx = node_pos['node_pos'][t_idx]
  43. # reach leaf
  44. if cur_node_idx == -1:
  45. continue
  46. rs, reach_leaf = guest_traverse_a_tree(tree, sample, cur_node_idx)
  47. if reach_leaf:
  48. node_pos['reach_leaf_node'][t_idx] = True
  49. node_pos['node_pos'][t_idx] = rs
  50. return node_pos
  51. def merge_predict_pos(node_pos1, node_pos2):
  52. pos_arr1 = node_pos1['node_pos']
  53. pos_arr2 = node_pos2['node_pos']
  54. stack_arr = np.stack([pos_arr1, pos_arr2])
  55. node_pos1['node_pos'] = np.max(stack_arr, axis=0)
  56. return node_pos1
  57. def add_y_hat(leaf_pos, init_score, learning_rate, trees: List[HeteroDecisionTreeGuest], multi_class_num=None):
  58. # finally node pos will hold weights
  59. weights = []
  60. for leaf_idx, tree in zip(leaf_pos, trees):
  61. weights.append(tree.tree_node[int(leaf_idx)].weight)
  62. weights = np.array(weights)
  63. if multi_class_num > 2:
  64. weights = weights.reshape((-1, multi_class_num))
  65. return np.sum(weights * learning_rate, axis=0) + init_score
  66. def get_predict_scores(
  67. leaf_pos,
  68. learning_rate,
  69. init_score,
  70. trees: List[HeteroDecisionTreeGuest],
  71. multi_class_num=-1,
  72. predict_cache=None):
  73. if predict_cache:
  74. init_score = 0 # prevent init_score re-add
  75. predict_func = functools.partial(add_y_hat,
  76. learning_rate=learning_rate, init_score=init_score, trees=trees,
  77. multi_class_num=multi_class_num)
  78. predict_result = leaf_pos.mapValues(predict_func)
  79. if predict_cache:
  80. predict_result = predict_result.join(predict_cache, lambda v1, v2: v1 + v2)
  81. return predict_result
  82. def save_leaf_pos_helper(v1, v2):
  83. reach_leaf_idx = v2['reach_leaf_node']
  84. select_idx = reach_leaf_idx & (v2['node_pos'] != -1) # reach leaf and are not recorded( if recorded idx is -1)
  85. v1[select_idx] = v2['node_pos'][select_idx]
  86. return v1
  87. def mask_leaf_pos(v):
  88. reach_leaf_idx = v['reach_leaf_node']
  89. v['node_pos'][reach_leaf_idx] = -1
  90. return v
  91. def save_leaf_pos_and_mask_leaf_pos(node_pos_tb, final_leaf_pos):
  92. # save leaf pos
  93. saved_leaf_pos = final_leaf_pos.join(node_pos_tb, save_leaf_pos_helper)
  94. rest_part = final_leaf_pos.subtractByKey(saved_leaf_pos)
  95. final_leaf_pos = saved_leaf_pos.union(rest_part)
  96. # mask leaf pos
  97. node_pos_tb = node_pos_tb.mapValues(mask_leaf_pos)
  98. return node_pos_tb, final_leaf_pos
  99. def merge_leaf_pos(pos1, pos2):
  100. return pos1 + pos2
  101. def traverse_guest_local_trees(node_pos, sample, trees: List[HeteroFastDecisionTreeGuest]):
  102. """
  103. in mix mode, a sample can reach leaf directly
  104. """
  105. for t_idx, tree in enumerate(trees):
  106. cur_node_idx = node_pos[t_idx]
  107. if not tree.use_guest_feat_only_predict_mode:
  108. continue
  109. rs, reach_leaf = guest_traverse_a_tree(tree, sample, cur_node_idx)
  110. node_pos[t_idx] = rs
  111. return node_pos
  112. """
  113. Hetero guest predict function
  114. """
  115. def get_dtype(max_int):
  116. if max_int < (2**8) / 2:
  117. return np.int8
  118. elif max_int < (2**16) / 2:
  119. return np.int16
  120. else:
  121. return np.int64
  122. def sbt_guest_predict(data_inst, transfer_var: HeteroSecureBoostTransferVariable,
  123. trees: List[HeteroDecisionTreeGuest], learning_rate, init_score, booster_dim,
  124. predict_cache=None, pred_leaf=False):
  125. tree_num = len(trees)
  126. max_depth = trees[0].max_depth
  127. max_int = 2 ** max_depth
  128. dtype = get_dtype(max_int)
  129. LOGGER.debug('chosen np dtype is {}'.format(dtype))
  130. generate_func = functools.partial(generate_leaf_pos_dict, tree_num=tree_num, np_int_type=dtype)
  131. node_pos_tb = data_inst.mapValues(generate_func) # record node pos
  132. final_leaf_pos = data_inst.mapValues(lambda x: np.zeros(tree_num, dtype=dtype) + np.nan) # record final leaf pos
  133. traverse_func = functools.partial(guest_traverse_trees, trees=trees)
  134. comm_round = 0
  135. while True:
  136. # LOGGER.info('cur predict round is {}'.format(comm_round))
  137. node_pos_tb = node_pos_tb.join(data_inst, traverse_func)
  138. node_pos_tb, final_leaf_pos = save_leaf_pos_and_mask_leaf_pos(node_pos_tb, final_leaf_pos)
  139. # remove sample that reaches leaves of all trees
  140. reach_leaf_samples = node_pos_tb.filter(lambda key, value: value['reach_leaf_node'].all())
  141. node_pos_tb = node_pos_tb.subtractByKey(reach_leaf_samples)
  142. if node_pos_tb.count() == 0:
  143. transfer_var.predict_stop_flag.remote(True, idx=-1, suffix=(comm_round,))
  144. break
  145. transfer_var.predict_stop_flag.remote(False, idx=-1, suffix=(comm_round,))
  146. transfer_var.guest_predict_data.remote(node_pos_tb, idx=-1, suffix=(comm_round,))
  147. host_pos_tbs = transfer_var.host_predict_data.get(idx=-1, suffix=(comm_round,))
  148. for host_pos_tb in host_pos_tbs:
  149. node_pos_tb = node_pos_tb.join(host_pos_tb, merge_predict_pos)
  150. comm_round += 1
  151. if pred_leaf: # return leaf position only
  152. return final_leaf_pos
  153. else: # get final predict scores from leaf pos
  154. predict_result = get_predict_scores(leaf_pos=final_leaf_pos, learning_rate=learning_rate,
  155. init_score=init_score, trees=trees,
  156. multi_class_num=booster_dim, predict_cache=predict_cache)
  157. return predict_result
  158. def mix_sbt_guest_predict(data_inst, transfer_var: HeteroSecureBoostTransferVariable,
  159. trees: List[HeteroDecisionTreeGuest], learning_rate, init_score, booster_dim,
  160. predict_cache=None, pred_leaf=False):
  161. LOGGER.info('running mix mode predict')
  162. tree_num = len(trees)
  163. node_pos = data_inst.mapValues(lambda x: np.zeros(tree_num, dtype=np.int64))
  164. # traverse local trees
  165. traverse_func = functools.partial(traverse_guest_local_trees, trees=trees)
  166. guest_leaf_pos = node_pos.join(data_inst, traverse_func)
  167. # get leaf node from other host parties
  168. host_leaf_pos_list = transfer_var.host_predict_data.get(idx=-1)
  169. for host_leaf_pos in host_leaf_pos_list:
  170. guest_leaf_pos = guest_leaf_pos.join(host_leaf_pos, merge_leaf_pos)
  171. if pred_leaf: # predict leaf, return leaf position only
  172. return guest_leaf_pos
  173. else:
  174. predict_result = get_predict_scores(leaf_pos=guest_leaf_pos, learning_rate=learning_rate,
  175. init_score=init_score, trees=trees,
  176. multi_class_num=booster_dim, predict_cache=predict_cache)
  177. return predict_result
  178. """
  179. Hetero host predict utils
  180. """
  181. def host_traverse_a_tree(tree: HeteroDecisionTreeHost, sample, cur_node_idx):
  182. nid, _ = tree.traverse_tree(predict_state=(cur_node_idx, -1), data_inst=sample,
  183. decoder=tree.decode, split_maskdict=tree.split_maskdict,
  184. missing_dir_maskdict=tree.missing_dir_maskdict, sitename=tree.sitename,
  185. tree_=tree.tree_node, zero_as_missing=tree.zero_as_missing,
  186. use_missing=tree.use_missing)
  187. return nid, _
  188. def host_traverse_trees(sample, leaf_pos, trees: List[HeteroDecisionTreeHost]):
  189. for t_idx, tree in enumerate(trees):
  190. cur_node_idx = leaf_pos['node_pos'][t_idx]
  191. # idx is set as -1 when a sample reaches leaf
  192. if cur_node_idx == -1:
  193. continue
  194. nid, _ = host_traverse_a_tree(tree, sample, cur_node_idx)
  195. leaf_pos['node_pos'][t_idx] = nid
  196. return leaf_pos
  197. def traverse_host_local_trees(node_pos, sample, trees: List[HeteroFastDecisionTreeHost]):
  198. """
  199. in mix mode, a sample can reach leaf directly
  200. """
  201. for i in range(len(trees)):
  202. tree = trees[i]
  203. if len(tree.tree_node) == 0: # this tree belongs to other party because it has no tree node
  204. continue
  205. leaf_id = tree.host_local_traverse_tree(sample, tree.tree_node, use_missing=tree.use_missing,
  206. zero_as_missing=tree.zero_as_missing)
  207. node_pos[i] = leaf_id
  208. return node_pos
  209. """
  210. Hetero host predict function
  211. """
  212. def sbt_host_predict(data_inst, transfer_var: HeteroSecureBoostTransferVariable, trees: List[HeteroDecisionTreeHost]):
  213. comm_round = 0
  214. traverse_func = functools.partial(host_traverse_trees, trees=trees)
  215. while True:
  216. LOGGER.debug('cur predict round is {}'.format(comm_round))
  217. stop_flag = transfer_var.predict_stop_flag.get(idx=0, suffix=(comm_round,))
  218. if stop_flag:
  219. break
  220. guest_node_pos = transfer_var.guest_predict_data.get(idx=0, suffix=(comm_round,))
  221. host_node_pos = data_inst.join(guest_node_pos, traverse_func)
  222. if guest_node_pos.count() != host_node_pos.count():
  223. raise ValueError('sample count mismatch: guest table {}, host table {}'.format(guest_node_pos.count(),
  224. host_node_pos.count()))
  225. transfer_var.host_predict_data.remote(host_node_pos, idx=-1, suffix=(comm_round,))
  226. comm_round += 1
  227. def mix_sbt_host_predict(data_inst, transfer_var: HeteroSecureBoostTransferVariable,
  228. trees: List[HeteroDecisionTreeHost]):
  229. LOGGER.info('running mix mode predict')
  230. tree_num = len(trees)
  231. node_pos = data_inst.mapValues(lambda x: np.zeros(tree_num, dtype=np.int64))
  232. local_traverse_func = functools.partial(traverse_host_local_trees, trees=trees)
  233. leaf_pos = node_pos.join(data_inst, local_traverse_func)
  234. transfer_var.host_predict_data.remote(leaf_pos, idx=0, role=consts.GUEST)
  235. """
  236. Fed-EINI predict func
  237. """
  238. def get_leaf_idx_map(trees):
  239. id_pos_map_list = []
  240. for tree in trees:
  241. array_idx = 0
  242. id_pos_map = {}
  243. for node in tree.tree_node:
  244. if node.is_leaf:
  245. id_pos_map[node.id] = array_idx
  246. array_idx += 1
  247. id_pos_map_list.append(id_pos_map)
  248. return id_pos_map_list
  249. def go_to_children_branches(data_inst, tree_node, tree, sitename: str, candidate_list: List):
  250. if tree_node.is_leaf:
  251. candidate_list.append(tree_node)
  252. else:
  253. tree_node_list = tree.tree_node
  254. if tree_node.sitename != sitename:
  255. go_to_children_branches(data_inst, tree_node_list[tree_node.left_nodeid],
  256. tree, sitename, candidate_list)
  257. go_to_children_branches(data_inst, tree_node_list[tree_node.right_nodeid],
  258. tree, sitename, candidate_list)
  259. else:
  260. next_layer_node_id = tree.go_next_layer(tree_node, data_inst, use_missing=tree.use_missing,
  261. zero_as_missing=tree.zero_as_missing, decoder=tree.decode,
  262. split_maskdict=tree.split_maskdict,
  263. missing_dir_maskdict=tree.missing_dir_maskdict,
  264. bin_sparse_point=None
  265. )
  266. go_to_children_branches(data_inst, tree_node_list[next_layer_node_id], tree, sitename, candidate_list)
  267. def generate_leaf_candidates_guest(data_inst, sitename, trees, node_pos_map_list,
  268. init_score, learning_rate, booster_dim):
  269. candidate_nodes_of_all_tree = []
  270. if booster_dim > 2:
  271. epoch_num = len(trees) // booster_dim
  272. else:
  273. epoch_num = len(trees)
  274. init_score = init_score / epoch_num
  275. score_idx = 0
  276. for tree, node_pos_map in zip(trees, node_pos_map_list):
  277. if booster_dim > 2:
  278. tree_init_score = init_score[score_idx]
  279. score_idx += 1
  280. if score_idx == booster_dim:
  281. score_idx = 0
  282. else:
  283. tree_init_score = init_score
  284. candidate_list = []
  285. go_to_children_branches(data_inst, tree.tree_node[0], tree, sitename, candidate_list)
  286. # check if it is mo tree:
  287. if len(candidate_list) < 1:
  288. raise ValueError('incorrect candidate list length,: {}'.format(len(candidate_list)))
  289. node = candidate_list[0]
  290. result_vec = np.zeros(len(node_pos_map))
  291. if isinstance(node.weight, np.ndarray):
  292. if len(node.weight) > 1:
  293. result_vec = [np.array([0 for i in range(len(node.weight))]) for i in range(len(node_pos_map))]
  294. for node in candidate_list:
  295. result_vec[node_pos_map[node.id]] = node.weight * learning_rate + tree_init_score
  296. candidate_nodes_of_all_tree.extend(result_vec)
  297. return np.array(candidate_nodes_of_all_tree)
  298. def EINI_guest_predict(data_inst, trees: List[HeteroDecisionTreeGuest], learning_rate, init_score, booster_dim,
  299. encrypt_key_length, transfer_var: HeteroSecureBoostTransferVariable,
  300. sitename=None, party_list=None, predict_cache=None, pred_leaf=False):
  301. if sitename is None:
  302. raise ValueError('input sitename is None, not able to run EINI predict algorithm')
  303. if pred_leaf:
  304. raise ValueError('EINI predict mode does not support leaf idx prediction')
  305. # EINI algorithms
  306. id_pos_map_list = get_leaf_idx_map(trees)
  307. map_func = functools.partial(generate_leaf_candidates_guest, sitename=sitename, trees=trees,
  308. node_pos_map_list=id_pos_map_list, init_score=init_score,
  309. learning_rate=learning_rate, booster_dim=booster_dim)
  310. position_vec = data_inst.mapValues(map_func)
  311. # encryption
  312. encrypter = PaillierEncrypt()
  313. encrypter.generate_key(encrypt_key_length)
  314. encrypter_vec_table = position_vec.mapValues(encrypter.recursive_encrypt)
  315. # federation part
  316. # send to first host party
  317. transfer_var.guest_predict_data.remote(encrypter_vec_table, idx=0, suffix='position_vec', role=consts.HOST)
  318. # get from last host party
  319. result_table = transfer_var.host_predict_data.get(idx=len(party_list) - 1, suffix='merge_result', role=consts.HOST)
  320. # decode result
  321. result = result_table.mapValues(encrypter.recursive_decrypt)
  322. # reformat
  323. result = result.mapValues(lambda x: np.array(x))
  324. if predict_cache:
  325. result = result.join(predict_cache, lambda v1, v2: v1 + v2)
  326. return result
  327. def generate_leaf_candidates_host(data_inst, sitename, trees, node_pos_map_list):
  328. candidate_nodes_of_all_tree = []
  329. for tree, node_pos_map in zip(trees, node_pos_map_list):
  330. result_vec = [0 for i in range(len(node_pos_map))]
  331. candidate_list = []
  332. go_to_children_branches(data_inst, tree.tree_node[0], tree, sitename, candidate_list)
  333. for node in candidate_list:
  334. result_vec[node_pos_map[node.id]] = 1 # create 0-1 vector
  335. candidate_nodes_of_all_tree.extend(result_vec)
  336. return np.array(candidate_nodes_of_all_tree)
  337. def generate_leaf_idx_dimension_map(trees, booster_dim):
  338. cur_dim = 0
  339. leaf_dim_map = {}
  340. leaf_idx = 0
  341. for tree in trees:
  342. for node in tree.tree_node:
  343. if node.is_leaf:
  344. leaf_dim_map[leaf_idx] = cur_dim
  345. leaf_idx += 1
  346. cur_dim += 1
  347. if cur_dim == booster_dim:
  348. cur_dim = 0
  349. return leaf_dim_map
  350. def merge_position_vec(host_vec, guest_encrypt_vec, booster_dim=1, leaf_idx_dim_map=None, random_mask=None):
  351. leaf_idx = -1
  352. rs = [0 for i in range(booster_dim)]
  353. for en_num, vec_value in zip(guest_encrypt_vec, host_vec):
  354. leaf_idx += 1
  355. if vec_value == 0:
  356. continue
  357. else:
  358. dim = leaf_idx_dim_map[leaf_idx]
  359. rs[dim] += en_num
  360. if random_mask:
  361. for i in range(len(rs)):
  362. rs[i] = rs[i] * random_mask # a pos random mask btw 1 and 2
  363. return rs
  364. def position_vec_element_wise_mul(guest_encrypt_vec, host_vec):
  365. new_vec = []
  366. for en_num, vec_value in zip(guest_encrypt_vec, host_vec):
  367. new_vec.append(en_num * vec_value)
  368. return new_vec
  369. def count_complexity_helper(node, node_list, host_sitename, meet_host_node):
  370. if node.is_leaf:
  371. return 1 if meet_host_node else 0
  372. if node.sitename == host_sitename:
  373. meet_host_node = True
  374. return count_complexity_helper(node_list[node.left_nodeid], node_list, host_sitename, meet_host_node) + \
  375. count_complexity_helper(node_list[node.right_nodeid], node_list, host_sitename, meet_host_node)
  376. def count_complexity(trees, sitename):
  377. tree_valid_leaves_num = []
  378. for tree in trees:
  379. valid_leaf_num = count_complexity_helper(tree.tree_node[0], tree.tree_node, sitename, False)
  380. if valid_leaf_num != 0:
  381. tree_valid_leaves_num.append(valid_leaf_num)
  382. complexity = 1
  383. for num in tree_valid_leaves_num:
  384. complexity *= num
  385. return complexity
  386. def EINI_host_predict(data_inst, trees: List[HeteroDecisionTreeHost], sitename, self_party_id, party_list,
  387. booster_dim, transfer_var: HeteroSecureBoostTransferVariable,
  388. complexity_check=False, random_mask=False):
  389. if complexity_check:
  390. complexity = count_complexity(trees, sitename)
  391. LOGGER.debug('checking EINI complexity: {}'.format(complexity))
  392. if complexity < consts.EINI_TREE_COMPLEXITY:
  393. raise ValueError('tree complexity: {}, is lower than safe '
  394. 'threshold, inference is not allowed.'.format(complexity))
  395. id_pos_map_list = get_leaf_idx_map(trees)
  396. map_func = functools.partial(generate_leaf_candidates_host, sitename=sitename, trees=trees,
  397. node_pos_map_list=id_pos_map_list)
  398. position_vec = data_inst.mapValues(map_func)
  399. booster_dim = booster_dim
  400. random_mask = random.SystemRandom().random() + 1 if random_mask else 0 # generate a random mask btw 1 and 2
  401. self_idx = party_list.index(self_party_id)
  402. if len(party_list) == 1:
  403. guest_position_vec = transfer_var.guest_predict_data.get(idx=0, suffix='position_vec')
  404. leaf_idx_dim_map = generate_leaf_idx_dimension_map(trees, booster_dim)
  405. merge_func = functools.partial(merge_position_vec, booster_dim=booster_dim,
  406. leaf_idx_dim_map=leaf_idx_dim_map, random_mask=random_mask)
  407. result_table = position_vec.join(guest_position_vec, merge_func)
  408. transfer_var.host_predict_data.remote(result_table, suffix='merge_result')
  409. else:
  410. # multi host case
  411. # if is first host party, get encrypt vec from guest, else from previous host party
  412. if self_party_id == party_list[0]:
  413. guest_position_vec = transfer_var.guest_predict_data.get(idx=0, suffix='position_vec')
  414. else:
  415. guest_position_vec = transfer_var.inter_host_data.get(idx=self_idx - 1, suffix='position_vec')
  416. if self_party_id == party_list[-1]:
  417. leaf_idx_dim_map = generate_leaf_idx_dimension_map(trees, booster_dim)
  418. func = functools.partial(merge_position_vec, booster_dim=booster_dim,
  419. leaf_idx_dim_map=leaf_idx_dim_map, random_mask=random_mask)
  420. result_table = position_vec.join(guest_position_vec, func)
  421. transfer_var.host_predict_data.remote(result_table, suffix='merge_result')
  422. else:
  423. result_table = position_vec.join(guest_position_vec, position_vec_element_wise_mul)
  424. transfer_var.inter_host_data.remote(result_table, idx=self_idx + 1, suffix='position_vec', role=consts.HOST)