123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571 |
- import functools
- import numpy as np
- import random
- from typing import List
- from federatedml.util import consts
- from federatedml.secureprotol import PaillierEncrypt
- from federatedml.ensemble.basic_algorithms import HeteroDecisionTreeGuest, HeteroDecisionTreeHost, \
- HeteroFastDecisionTreeGuest, HeteroFastDecisionTreeHost
- from federatedml.ensemble.basic_algorithms.decision_tree.tree_core.decision_tree import DecisionTree, Node
- from federatedml.util import LOGGER
- from federatedml.transfer_variable.transfer_class.hetero_secure_boosting_predict_transfer_variable import \
- HeteroSecureBoostTransferVariable
- """
- Hetero guest predict utils
- """
- def generate_leaf_pos_dict(x, tree_num, np_int_type=np.int8):
- """
- x: just occupy the first parameter position
- return: a numpy array record sample pos, and a counter counting how many trees reach a leaf node
- """
- node_pos = np.zeros(tree_num, dtype=np_int_type)
- reach_leaf_node = np.zeros(tree_num, dtype=np.bool)
- return {'node_pos': node_pos, 'reach_leaf_node': reach_leaf_node}
- def guest_traverse_a_tree(tree: HeteroDecisionTreeGuest, sample, cur_node_idx):
- reach_leaf = False
- # only need nid here, predict state is not needed
- rs = tree.traverse_tree(tree_=tree.tree_node, data_inst=sample, predict_state=(cur_node_idx, -1),
- decoder=tree.decode, sitename=tree.sitename, use_missing=tree.use_missing,
- split_maskdict=tree.split_maskdict, missing_dir_maskdict=tree.missing_dir_maskdict,
- zero_as_missing=tree.zero_as_missing, return_leaf_id=True)
- if not isinstance(rs, tuple):
- reach_leaf = True
- leaf_id = rs
- return leaf_id, reach_leaf
- else:
- cur_node_idx = rs[0]
- return cur_node_idx, reach_leaf
- def guest_traverse_trees(node_pos, sample, trees: List[HeteroDecisionTreeGuest]):
- if node_pos['reach_leaf_node'].all():
- return node_pos
- for t_idx, tree in enumerate(trees):
- cur_node_idx = node_pos['node_pos'][t_idx]
- # reach leaf
- if cur_node_idx == -1:
- continue
- rs, reach_leaf = guest_traverse_a_tree(tree, sample, cur_node_idx)
- if reach_leaf:
- node_pos['reach_leaf_node'][t_idx] = True
- node_pos['node_pos'][t_idx] = rs
- return node_pos
- def merge_predict_pos(node_pos1, node_pos2):
- pos_arr1 = node_pos1['node_pos']
- pos_arr2 = node_pos2['node_pos']
- stack_arr = np.stack([pos_arr1, pos_arr2])
- node_pos1['node_pos'] = np.max(stack_arr, axis=0)
- return node_pos1
- def add_y_hat(leaf_pos, init_score, learning_rate, trees: List[HeteroDecisionTreeGuest], multi_class_num=None):
- # finally node pos will hold weights
- weights = []
- for leaf_idx, tree in zip(leaf_pos, trees):
- weights.append(tree.tree_node[int(leaf_idx)].weight)
- weights = np.array(weights)
- if multi_class_num > 2:
- weights = weights.reshape((-1, multi_class_num))
- return np.sum(weights * learning_rate, axis=0) + init_score
- def get_predict_scores(
- leaf_pos,
- learning_rate,
- init_score,
- trees: List[HeteroDecisionTreeGuest],
- multi_class_num=-1,
- predict_cache=None):
- if predict_cache:
- init_score = 0 # prevent init_score re-add
- predict_func = functools.partial(add_y_hat,
- learning_rate=learning_rate, init_score=init_score, trees=trees,
- multi_class_num=multi_class_num)
- predict_result = leaf_pos.mapValues(predict_func)
- if predict_cache:
- predict_result = predict_result.join(predict_cache, lambda v1, v2: v1 + v2)
- return predict_result
- def save_leaf_pos_helper(v1, v2):
- reach_leaf_idx = v2['reach_leaf_node']
- select_idx = reach_leaf_idx & (v2['node_pos'] != -1) # reach leaf and are not recorded( if recorded idx is -1)
- v1[select_idx] = v2['node_pos'][select_idx]
- return v1
- def mask_leaf_pos(v):
- reach_leaf_idx = v['reach_leaf_node']
- v['node_pos'][reach_leaf_idx] = -1
- return v
- def save_leaf_pos_and_mask_leaf_pos(node_pos_tb, final_leaf_pos):
- # save leaf pos
- saved_leaf_pos = final_leaf_pos.join(node_pos_tb, save_leaf_pos_helper)
- rest_part = final_leaf_pos.subtractByKey(saved_leaf_pos)
- final_leaf_pos = saved_leaf_pos.union(rest_part)
- # mask leaf pos
- node_pos_tb = node_pos_tb.mapValues(mask_leaf_pos)
- return node_pos_tb, final_leaf_pos
- def merge_leaf_pos(pos1, pos2):
- return pos1 + pos2
- def traverse_guest_local_trees(node_pos, sample, trees: List[HeteroFastDecisionTreeGuest]):
- """
- in mix mode, a sample can reach leaf directly
- """
- for t_idx, tree in enumerate(trees):
- cur_node_idx = node_pos[t_idx]
- if not tree.use_guest_feat_only_predict_mode:
- continue
- rs, reach_leaf = guest_traverse_a_tree(tree, sample, cur_node_idx)
- node_pos[t_idx] = rs
- return node_pos
- """
- Hetero guest predict function
- """
- def get_dtype(max_int):
- if max_int < (2**8) / 2:
- return np.int8
- elif max_int < (2**16) / 2:
- return np.int16
- else:
- return np.int64
- def sbt_guest_predict(data_inst, transfer_var: HeteroSecureBoostTransferVariable,
- trees: List[HeteroDecisionTreeGuest], learning_rate, init_score, booster_dim,
- predict_cache=None, pred_leaf=False):
- tree_num = len(trees)
- max_depth = trees[0].max_depth
- max_int = 2 ** max_depth
- dtype = get_dtype(max_int)
- LOGGER.debug('chosen np dtype is {}'.format(dtype))
- generate_func = functools.partial(generate_leaf_pos_dict, tree_num=tree_num, np_int_type=dtype)
- node_pos_tb = data_inst.mapValues(generate_func) # record node pos
- final_leaf_pos = data_inst.mapValues(lambda x: np.zeros(tree_num, dtype=dtype) + np.nan) # record final leaf pos
- traverse_func = functools.partial(guest_traverse_trees, trees=trees)
- comm_round = 0
- while True:
- # LOGGER.info('cur predict round is {}'.format(comm_round))
- node_pos_tb = node_pos_tb.join(data_inst, traverse_func)
- node_pos_tb, final_leaf_pos = save_leaf_pos_and_mask_leaf_pos(node_pos_tb, final_leaf_pos)
- # remove sample that reaches leaves of all trees
- reach_leaf_samples = node_pos_tb.filter(lambda key, value: value['reach_leaf_node'].all())
- node_pos_tb = node_pos_tb.subtractByKey(reach_leaf_samples)
- if node_pos_tb.count() == 0:
- transfer_var.predict_stop_flag.remote(True, idx=-1, suffix=(comm_round,))
- break
- transfer_var.predict_stop_flag.remote(False, idx=-1, suffix=(comm_round,))
- transfer_var.guest_predict_data.remote(node_pos_tb, idx=-1, suffix=(comm_round,))
- host_pos_tbs = transfer_var.host_predict_data.get(idx=-1, suffix=(comm_round,))
- for host_pos_tb in host_pos_tbs:
- node_pos_tb = node_pos_tb.join(host_pos_tb, merge_predict_pos)
- comm_round += 1
- if pred_leaf: # return leaf position only
- return final_leaf_pos
- else: # get final predict scores from leaf pos
- predict_result = get_predict_scores(leaf_pos=final_leaf_pos, learning_rate=learning_rate,
- init_score=init_score, trees=trees,
- multi_class_num=booster_dim, predict_cache=predict_cache)
- return predict_result
- def mix_sbt_guest_predict(data_inst, transfer_var: HeteroSecureBoostTransferVariable,
- trees: List[HeteroDecisionTreeGuest], learning_rate, init_score, booster_dim,
- predict_cache=None, pred_leaf=False):
- LOGGER.info('running mix mode predict')
- tree_num = len(trees)
- node_pos = data_inst.mapValues(lambda x: np.zeros(tree_num, dtype=np.int64))
- # traverse local trees
- traverse_func = functools.partial(traverse_guest_local_trees, trees=trees)
- guest_leaf_pos = node_pos.join(data_inst, traverse_func)
- # get leaf node from other host parties
- host_leaf_pos_list = transfer_var.host_predict_data.get(idx=-1)
- for host_leaf_pos in host_leaf_pos_list:
- guest_leaf_pos = guest_leaf_pos.join(host_leaf_pos, merge_leaf_pos)
- if pred_leaf: # predict leaf, return leaf position only
- return guest_leaf_pos
- else:
- predict_result = get_predict_scores(leaf_pos=guest_leaf_pos, learning_rate=learning_rate,
- init_score=init_score, trees=trees,
- multi_class_num=booster_dim, predict_cache=predict_cache)
- return predict_result
- """
- Hetero host predict utils
- """
- def host_traverse_a_tree(tree: HeteroDecisionTreeHost, sample, cur_node_idx):
- nid, _ = tree.traverse_tree(predict_state=(cur_node_idx, -1), data_inst=sample,
- decoder=tree.decode, split_maskdict=tree.split_maskdict,
- missing_dir_maskdict=tree.missing_dir_maskdict, sitename=tree.sitename,
- tree_=tree.tree_node, zero_as_missing=tree.zero_as_missing,
- use_missing=tree.use_missing)
- return nid, _
- def host_traverse_trees(sample, leaf_pos, trees: List[HeteroDecisionTreeHost]):
- for t_idx, tree in enumerate(trees):
- cur_node_idx = leaf_pos['node_pos'][t_idx]
- # idx is set as -1 when a sample reaches leaf
- if cur_node_idx == -1:
- continue
- nid, _ = host_traverse_a_tree(tree, sample, cur_node_idx)
- leaf_pos['node_pos'][t_idx] = nid
- return leaf_pos
- def traverse_host_local_trees(node_pos, sample, trees: List[HeteroFastDecisionTreeHost]):
- """
- in mix mode, a sample can reach leaf directly
- """
- for i in range(len(trees)):
- tree = trees[i]
- if len(tree.tree_node) == 0: # this tree belongs to other party because it has no tree node
- continue
- leaf_id = tree.host_local_traverse_tree(sample, tree.tree_node, use_missing=tree.use_missing,
- zero_as_missing=tree.zero_as_missing)
- node_pos[i] = leaf_id
- return node_pos
- """
- Hetero host predict function
- """
- def sbt_host_predict(data_inst, transfer_var: HeteroSecureBoostTransferVariable, trees: List[HeteroDecisionTreeHost]):
- comm_round = 0
- traverse_func = functools.partial(host_traverse_trees, trees=trees)
- while True:
- LOGGER.debug('cur predict round is {}'.format(comm_round))
- stop_flag = transfer_var.predict_stop_flag.get(idx=0, suffix=(comm_round,))
- if stop_flag:
- break
- guest_node_pos = transfer_var.guest_predict_data.get(idx=0, suffix=(comm_round,))
- host_node_pos = data_inst.join(guest_node_pos, traverse_func)
- if guest_node_pos.count() != host_node_pos.count():
- raise ValueError('sample count mismatch: guest table {}, host table {}'.format(guest_node_pos.count(),
- host_node_pos.count()))
- transfer_var.host_predict_data.remote(host_node_pos, idx=-1, suffix=(comm_round,))
- comm_round += 1
- def mix_sbt_host_predict(data_inst, transfer_var: HeteroSecureBoostTransferVariable,
- trees: List[HeteroDecisionTreeHost]):
- LOGGER.info('running mix mode predict')
- tree_num = len(trees)
- node_pos = data_inst.mapValues(lambda x: np.zeros(tree_num, dtype=np.int64))
- local_traverse_func = functools.partial(traverse_host_local_trees, trees=trees)
- leaf_pos = node_pos.join(data_inst, local_traverse_func)
- transfer_var.host_predict_data.remote(leaf_pos, idx=0, role=consts.GUEST)
- """
- Fed-EINI predict func
- """
- def get_leaf_idx_map(trees):
- id_pos_map_list = []
- for tree in trees:
- array_idx = 0
- id_pos_map = {}
- for node in tree.tree_node:
- if node.is_leaf:
- id_pos_map[node.id] = array_idx
- array_idx += 1
- id_pos_map_list.append(id_pos_map)
- return id_pos_map_list
- def go_to_children_branches(data_inst, tree_node, tree, sitename: str, candidate_list: List):
- if tree_node.is_leaf:
- candidate_list.append(tree_node)
- else:
- tree_node_list = tree.tree_node
- if tree_node.sitename != sitename:
- go_to_children_branches(data_inst, tree_node_list[tree_node.left_nodeid],
- tree, sitename, candidate_list)
- go_to_children_branches(data_inst, tree_node_list[tree_node.right_nodeid],
- tree, sitename, candidate_list)
- else:
- next_layer_node_id = tree.go_next_layer(tree_node, data_inst, use_missing=tree.use_missing,
- zero_as_missing=tree.zero_as_missing, decoder=tree.decode,
- split_maskdict=tree.split_maskdict,
- missing_dir_maskdict=tree.missing_dir_maskdict,
- bin_sparse_point=None
- )
- go_to_children_branches(data_inst, tree_node_list[next_layer_node_id], tree, sitename, candidate_list)
- def generate_leaf_candidates_guest(data_inst, sitename, trees, node_pos_map_list,
- init_score, learning_rate, booster_dim):
- candidate_nodes_of_all_tree = []
- if booster_dim > 2:
- epoch_num = len(trees) // booster_dim
- else:
- epoch_num = len(trees)
- init_score = init_score / epoch_num
- score_idx = 0
- for tree, node_pos_map in zip(trees, node_pos_map_list):
- if booster_dim > 2:
- tree_init_score = init_score[score_idx]
- score_idx += 1
- if score_idx == booster_dim:
- score_idx = 0
- else:
- tree_init_score = init_score
- candidate_list = []
- go_to_children_branches(data_inst, tree.tree_node[0], tree, sitename, candidate_list)
- # check if it is mo tree:
- if len(candidate_list) < 1:
- raise ValueError('incorrect candidate list length,: {}'.format(len(candidate_list)))
- node = candidate_list[0]
- result_vec = np.zeros(len(node_pos_map))
- if isinstance(node.weight, np.ndarray):
- if len(node.weight) > 1:
- result_vec = [np.array([0 for i in range(len(node.weight))]) for i in range(len(node_pos_map))]
- for node in candidate_list:
- result_vec[node_pos_map[node.id]] = node.weight * learning_rate + tree_init_score
- candidate_nodes_of_all_tree.extend(result_vec)
- return np.array(candidate_nodes_of_all_tree)
- def EINI_guest_predict(data_inst, trees: List[HeteroDecisionTreeGuest], learning_rate, init_score, booster_dim,
- encrypt_key_length, transfer_var: HeteroSecureBoostTransferVariable,
- sitename=None, party_list=None, predict_cache=None, pred_leaf=False):
- if sitename is None:
- raise ValueError('input sitename is None, not able to run EINI predict algorithm')
- if pred_leaf:
- raise ValueError('EINI predict mode does not support leaf idx prediction')
- # EINI algorithms
- id_pos_map_list = get_leaf_idx_map(trees)
- map_func = functools.partial(generate_leaf_candidates_guest, sitename=sitename, trees=trees,
- node_pos_map_list=id_pos_map_list, init_score=init_score,
- learning_rate=learning_rate, booster_dim=booster_dim)
- position_vec = data_inst.mapValues(map_func)
- # encryption
- encrypter = PaillierEncrypt()
- encrypter.generate_key(encrypt_key_length)
- encrypter_vec_table = position_vec.mapValues(encrypter.recursive_encrypt)
- # federation part
- # send to first host party
- transfer_var.guest_predict_data.remote(encrypter_vec_table, idx=0, suffix='position_vec', role=consts.HOST)
- # get from last host party
- result_table = transfer_var.host_predict_data.get(idx=len(party_list) - 1, suffix='merge_result', role=consts.HOST)
- # decode result
- result = result_table.mapValues(encrypter.recursive_decrypt)
- # reformat
- result = result.mapValues(lambda x: np.array(x))
- if predict_cache:
- result = result.join(predict_cache, lambda v1, v2: v1 + v2)
- return result
- def generate_leaf_candidates_host(data_inst, sitename, trees, node_pos_map_list):
- candidate_nodes_of_all_tree = []
- for tree, node_pos_map in zip(trees, node_pos_map_list):
- result_vec = [0 for i in range(len(node_pos_map))]
- candidate_list = []
- go_to_children_branches(data_inst, tree.tree_node[0], tree, sitename, candidate_list)
- for node in candidate_list:
- result_vec[node_pos_map[node.id]] = 1 # create 0-1 vector
- candidate_nodes_of_all_tree.extend(result_vec)
- return np.array(candidate_nodes_of_all_tree)
- def generate_leaf_idx_dimension_map(trees, booster_dim):
- cur_dim = 0
- leaf_dim_map = {}
- leaf_idx = 0
- for tree in trees:
- for node in tree.tree_node:
- if node.is_leaf:
- leaf_dim_map[leaf_idx] = cur_dim
- leaf_idx += 1
- cur_dim += 1
- if cur_dim == booster_dim:
- cur_dim = 0
- return leaf_dim_map
- def merge_position_vec(host_vec, guest_encrypt_vec, booster_dim=1, leaf_idx_dim_map=None, random_mask=None):
- leaf_idx = -1
- rs = [0 for i in range(booster_dim)]
- for en_num, vec_value in zip(guest_encrypt_vec, host_vec):
- leaf_idx += 1
- if vec_value == 0:
- continue
- else:
- dim = leaf_idx_dim_map[leaf_idx]
- rs[dim] += en_num
- if random_mask:
- for i in range(len(rs)):
- rs[i] = rs[i] * random_mask # a pos random mask btw 1 and 2
- return rs
- def position_vec_element_wise_mul(guest_encrypt_vec, host_vec):
- new_vec = []
- for en_num, vec_value in zip(guest_encrypt_vec, host_vec):
- new_vec.append(en_num * vec_value)
- return new_vec
- def count_complexity_helper(node, node_list, host_sitename, meet_host_node):
- if node.is_leaf:
- return 1 if meet_host_node else 0
- if node.sitename == host_sitename:
- meet_host_node = True
- return count_complexity_helper(node_list[node.left_nodeid], node_list, host_sitename, meet_host_node) + \
- count_complexity_helper(node_list[node.right_nodeid], node_list, host_sitename, meet_host_node)
- def count_complexity(trees, sitename):
- tree_valid_leaves_num = []
- for tree in trees:
- valid_leaf_num = count_complexity_helper(tree.tree_node[0], tree.tree_node, sitename, False)
- if valid_leaf_num != 0:
- tree_valid_leaves_num.append(valid_leaf_num)
- complexity = 1
- for num in tree_valid_leaves_num:
- complexity *= num
- return complexity
- def EINI_host_predict(data_inst, trees: List[HeteroDecisionTreeHost], sitename, self_party_id, party_list,
- booster_dim, transfer_var: HeteroSecureBoostTransferVariable,
- complexity_check=False, random_mask=False):
- if complexity_check:
- complexity = count_complexity(trees, sitename)
- LOGGER.debug('checking EINI complexity: {}'.format(complexity))
- if complexity < consts.EINI_TREE_COMPLEXITY:
- raise ValueError('tree complexity: {}, is lower than safe '
- 'threshold, inference is not allowed.'.format(complexity))
- id_pos_map_list = get_leaf_idx_map(trees)
- map_func = functools.partial(generate_leaf_candidates_host, sitename=sitename, trees=trees,
- node_pos_map_list=id_pos_map_list)
- position_vec = data_inst.mapValues(map_func)
- booster_dim = booster_dim
- random_mask = random.SystemRandom().random() + 1 if random_mask else 0 # generate a random mask btw 1 and 2
- self_idx = party_list.index(self_party_id)
- if len(party_list) == 1:
- guest_position_vec = transfer_var.guest_predict_data.get(idx=0, suffix='position_vec')
- leaf_idx_dim_map = generate_leaf_idx_dimension_map(trees, booster_dim)
- merge_func = functools.partial(merge_position_vec, booster_dim=booster_dim,
- leaf_idx_dim_map=leaf_idx_dim_map, random_mask=random_mask)
- result_table = position_vec.join(guest_position_vec, merge_func)
- transfer_var.host_predict_data.remote(result_table, suffix='merge_result')
- else:
- # multi host case
- # if is first host party, get encrypt vec from guest, else from previous host party
- if self_party_id == party_list[0]:
- guest_position_vec = transfer_var.guest_predict_data.get(idx=0, suffix='position_vec')
- else:
- guest_position_vec = transfer_var.inter_host_data.get(idx=self_idx - 1, suffix='position_vec')
- if self_party_id == party_list[-1]:
- leaf_idx_dim_map = generate_leaf_idx_dimension_map(trees, booster_dim)
- func = functools.partial(merge_position_vec, booster_dim=booster_dim,
- leaf_idx_dim_map=leaf_idx_dim_map, random_mask=random_mask)
- result_table = position_vec.join(guest_position_vec, func)
- transfer_var.host_predict_data.remote(result_table, suffix='merge_result')
- else:
- result_table = position_vec.join(guest_position_vec, position_vec_element_wise_mul)
- transfer_var.inter_host_data.remote(result_table, idx=self_idx + 1, suffix='position_vec', role=consts.HOST)
|