homo_secureboost_client.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. import copy
  2. import functools
  3. import numpy as np
  4. from typing import List
  5. from operator import itemgetter
  6. from federatedml.util import LOGGER
  7. from federatedml.util import consts
  8. from federatedml.statistic.data_overview import with_weight
  9. from federatedml.feature.sparse_vector import SparseVector
  10. from federatedml.feature.fate_element_type import NoneType
  11. from federatedml.ensemble import HeteroSecureBoostingTreeGuest
  12. from federatedml.util.io_check import assert_io_num_rows_equal
  13. from federatedml.param.boosting_param import HomoSecureBoostParam
  14. from federatedml.ensemble.boosting.homo_boosting import HomoBoostingClient
  15. from federatedml.protobuf.generated.boosting_tree_model_meta_pb2 import QuantileMeta
  16. from federatedml.protobuf.generated.boosting_tree_model_meta_pb2 import ObjectiveMeta
  17. from federatedml.protobuf.generated.boosting_tree_model_meta_pb2 import BoostingTreeModelMeta
  18. from federatedml.protobuf.generated.boosting_tree_model_param_pb2 import FeatureImportanceInfo
  19. from federatedml.protobuf.generated.boosting_tree_model_param_pb2 import BoostingTreeModelParam
  20. from federatedml.ensemble.basic_algorithms.decision_tree.tree_core.feature_importance import FeatureImportance
  21. from federatedml.ensemble.basic_algorithms.decision_tree.homo.homo_decision_tree_client import HomoDecisionTreeClient
# Module-level alias: reuse the hetero guest's static helper that converts the
# fid-keyed importance dict into feature-name-keyed output (see generate_summary).
make_readable_feature_importance = HeteroSecureBoostingTreeGuest.make_readable_feature_importance
class HomoSecureBoostingTreeClient(HomoBoostingClient):
    """Client side of homo (horizontal) SecureBoost; extends HomoBoostingClient."""

    def __init__(self):
        super(HomoSecureBoostingTreeClient, self).__init__()
        self.model_name = 'HomoSecureBoost'
        self.tree_param = None  # decision tree param
        self.use_missing = False
        self.zero_as_missing = False
        # epoch whose grad/hess are currently cached in self.grad_and_hess
        self.cur_epoch_idx = -1
        self.grad_and_hess = None
        # accumulated importance: fid -> FeatureImportance
        self.feature_importance_ = {}
        self.model_param = HomoSecureBoostParam()
        # memory back end
        self.backend = consts.DISTRIBUTED_BACKEND
        # local numpy caches, only populated by the memory backend
        self.bin_arr, self.sample_id_arr = None, None
        # mo tree
        self.multi_mode = consts.SINGLE_OUTPUT
  39. def _init_model(self, boosting_param: HomoSecureBoostParam):
  40. super(HomoSecureBoostingTreeClient, self)._init_model(boosting_param)
  41. self.use_missing = boosting_param.use_missing
  42. self.zero_as_missing = boosting_param.zero_as_missing
  43. self.tree_param = boosting_param.tree_param
  44. self.backend = boosting_param.backend
  45. self.multi_mode = boosting_param.multi_mode
  46. if self.use_missing:
  47. self.tree_param.use_missing = self.use_missing
  48. self.tree_param.zero_as_missing = self.zero_as_missing
  49. def get_valid_features(self, epoch_idx, b_idx):
  50. valid_feature = self.transfer_inst.valid_features.get(idx=0, suffix=('valid_features', epoch_idx, b_idx))
  51. return valid_feature
  52. def process_sample_weights(self, grad_and_hess, data_with_sample_weight=None):
  53. # add sample weights to gradient and hessian
  54. if data_with_sample_weight is not None:
  55. if with_weight(data_with_sample_weight):
  56. LOGGER.info('weighted sample detected, multiply g/h by weights')
  57. grad_and_hess = grad_and_hess.join(data_with_sample_weight,
  58. lambda v1, v2: (v1[0] * v2.weight, v1[1] * v2.weight))
  59. return grad_and_hess
  60. def compute_local_grad_and_hess(self, y_hat, data_with_sample_weight):
  61. loss_method = self.loss
  62. if self.task_type == consts.CLASSIFICATION:
  63. grad_and_hess = self.y.join(y_hat, lambda y, f_val:
  64. (loss_method.compute_grad(y, loss_method.predict(f_val)),
  65. loss_method.compute_hess(y, loss_method.predict(f_val))))
  66. else:
  67. grad_and_hess = self.y.join(y_hat, lambda y, f_val:
  68. (loss_method.compute_grad(y, f_val),
  69. loss_method.compute_hess(y, f_val)))
  70. grad_and_hess = self.process_sample_weights(grad_and_hess, data_with_sample_weight)
  71. return grad_and_hess
  72. @staticmethod
  73. def get_subtree_grad_and_hess(g_h, t_idx: int):
  74. """
  75. grad and hess of sub tree
  76. """
  77. LOGGER.info("get grad and hess of tree {}".format(t_idx))
  78. grad_and_hess_subtree = g_h.mapValues(
  79. lambda grad_and_hess: (grad_and_hess[0][t_idx], grad_and_hess[1][t_idx]))
  80. return grad_and_hess_subtree
  81. def update_feature_importance(self, tree_feature_importance):
  82. for fid in tree_feature_importance:
  83. if fid not in self.feature_importance_:
  84. self.feature_importance_[fid] = tree_feature_importance[fid]
  85. else:
  86. self.feature_importance_[fid] += tree_feature_importance[fid]
  87. """
  88. Functions for memory backends
  89. """
  90. @staticmethod
  91. def _handle_zero_as_missing(inst, feat_num, missing_bin_idx):
  92. """
  93. This for use_missing + zero_as_missing case
  94. """
  95. sparse_vec = inst.features.sparse_vec
  96. arr = np.zeros(feat_num, dtype=np.uint8) + missing_bin_idx
  97. for k, v in sparse_vec.items():
  98. if v != NoneType():
  99. arr[k] = v
  100. inst.features = arr
  101. return inst
  102. @staticmethod
  103. def _map_missing_bin(inst, bin_index):
  104. arr_bin = copy.deepcopy(inst.features)
  105. arr_bin[arr_bin == NoneType()] = bin_index
  106. inst.features = arr_bin
  107. return inst
  108. @staticmethod
  109. def _fill_nan(inst):
  110. arr = copy.deepcopy(inst.features)
  111. nan_index = np.isnan(arr)
  112. arr = arr.astype(np.object)
  113. arr[nan_index] = NoneType()
  114. inst.features = arr
  115. return inst
  116. @staticmethod
  117. def _sparse_recover(inst, feat_num):
  118. arr = np.zeros(feat_num)
  119. for k, v in inst.features.sparse_vec.items():
  120. arr[k] = v
  121. inst.features = arr
  122. return inst
    def data_preporcess(self, data_inst):
        """
        override parent function
        """
        # Sparse alignment is required by the distributed backend, and by the
        # memory backend only when both use_missing and zero_as_missing are set.
        need_transform_to_sparse = self.backend == consts.DISTRIBUTED_BACKEND or \
            (self.backend == consts.MEMORY_BACKEND and self.use_missing and self.zero_as_missing)
        # mapValues drops schema, so keep a copy to restore afterwards
        backup_schema = copy.deepcopy(data_inst.schema)
        if self.backend == consts.MEMORY_BACKEND:
            # memory backend only support dense format input
            data_example = data_inst.take(1)[0][1]
            if isinstance(data_example.features, SparseVector):
                recover_func = functools.partial(self._sparse_recover, feat_num=len(data_inst.schema['header']))
                data_inst = data_inst.mapValues(recover_func)
                data_inst.schema = backup_schema
        if need_transform_to_sparse:
            data_inst = self.data_alignment(data_inst)
        elif self.use_missing:
            # fill nan
            data_inst = data_inst.mapValues(self._fill_nan)
            data_inst.schema = backup_schema
        self.data_bin, self.bin_split_points, self.bin_sparse_points = self.federated_binning(data_inst)
        if self.backend == consts.MEMORY_BACKEND:
            # bin index == self.bin_num is reserved for missing values
            if self.use_missing and self.zero_as_missing:
                feat_num = len(self.bin_split_points)
                func = functools.partial(self._handle_zero_as_missing, feat_num=feat_num, missing_bin_idx=self.bin_num)
                self.data_bin = self.data_bin.mapValues(func)
            elif self.use_missing:  # use missing only
                missing_bin_index = self.bin_num
                func = functools.partial(self._map_missing_bin, bin_index=missing_bin_index)
                self.data_bin = self.data_bin.mapValues(func)
            # cache binned data locally as numpy arrays for memory_fit
            self._collect_data_arr(self.data_bin)
  154. def _collect_data_arr(self, bin_arr_table):
  155. bin_arr = []
  156. id_list = []
  157. for id_, inst in bin_arr_table.collect():
  158. bin_arr.append(inst.features)
  159. id_list.append(id_)
  160. self.bin_arr = np.asfortranarray(np.stack(bin_arr, axis=0).astype(np.uint8))
  161. self.sample_id_arr = np.array(id_list)
  162. def preprocess(self):
  163. if self.multi_mode == consts.MULTI_OUTPUT:
  164. self.booster_dim = 1
  165. LOGGER.debug('multi mode tree dim reset to 1')
    def fit_a_learner(self, epoch_idx: int, booster_dim: int):
        """Train one decision tree for the given boosting round / class dimension.

        :param epoch_idx: current boosting round
        :param booster_dim: output dimension this tree belongs to (used as subtree
            index of grad/hess in single-output mode)
        :return: the fitted HomoDecisionTreeClient
        """
        valid_features = self.get_valid_features(epoch_idx, booster_dim)
        LOGGER.debug('valid features are {}'.format(valid_features))
        if self.cur_epoch_idx != epoch_idx:
            # update g/h every epoch
            self.grad_and_hess = self.compute_local_grad_and_hess(self.y_hat, self.data_inst)
            self.cur_epoch_idx = epoch_idx
        if self.multi_mode == consts.MULTI_OUTPUT:
            # multi-output: one tree consumes grad/hess of all dimensions at once
            g_h = self.grad_and_hess
        else:
            g_h = self.get_subtree_grad_and_hess(self.grad_and_hess, booster_dim)
        flow_id = self.generate_flowid(epoch_idx, booster_dim)
        new_tree = HomoDecisionTreeClient(
            self.tree_param,
            self.data_bin,
            self.bin_split_points,
            self.bin_sparse_points,
            g_h,
            valid_feature=valid_features,
            epoch_idx=epoch_idx,
            role=self.role,
            flow_id=flow_id,
            tree_idx=booster_dim,
            mode='train')
        if self.backend == consts.DISTRIBUTED_BACKEND:
            new_tree.fit()
        elif self.backend == consts.MEMORY_BACKEND:
            # memory backend needed variable: hand over the locally cached arrays
            LOGGER.debug('running memory fit')
            new_tree.arr_bin_data = self.bin_arr
            new_tree.bin_num = self.bin_num
            new_tree.sample_id_arr = self.sample_id_arr
            new_tree.memory_fit()
        self.update_feature_importance(new_tree.get_feature_importance())
        return new_tree
  201. @staticmethod
  202. def predict_helper(data, tree_list: List[HomoDecisionTreeClient], init_score, zero_as_missing, use_missing,
  203. learning_rate, class_num=1):
  204. weight_list = []
  205. for tree in tree_list:
  206. weight = tree.traverse_tree(data, tree.tree_node, use_missing=use_missing, zero_as_missing=zero_as_missing)
  207. weight_list.append(weight)
  208. weights = np.array(weight_list)
  209. if class_num > 2:
  210. weights = weights.reshape((-1, class_num))
  211. return np.sum(weights * learning_rate, axis=0) + init_score
  212. else:
  213. return np.sum(weights * learning_rate, axis=0) + init_score
  214. def fast_homo_tree_predict(self, data_inst, ret_format='std'):
  215. assert ret_format in ['std', 'raw'], 'illegal ret format'
  216. LOGGER.info('running fast homo tree predict')
  217. to_predict_data = self.data_and_header_alignment(data_inst)
  218. tree_list = []
  219. rounds = len(self.boosting_model_list) // self.booster_dim
  220. for idx in range(0, rounds):
  221. for booster_idx in range(self.booster_dim):
  222. model = self.load_learner(self.booster_meta,
  223. self.boosting_model_list[idx * self.booster_dim + booster_idx],
  224. idx, booster_idx)
  225. tree_list.append(model)
  226. func = functools.partial(self.predict_helper, tree_list=tree_list, init_score=self.init_score,
  227. zero_as_missing=self.zero_as_missing, use_missing=self.use_missing,
  228. learning_rate=self.learning_rate, class_num=self.booster_dim)
  229. predict_rs = to_predict_data.mapValues(func)
  230. if ret_format == 'std':
  231. return self.score_to_predict_result(data_inst, predict_rs)
  232. elif ret_format == 'raw':
  233. return predict_rs
  234. else:
  235. raise ValueError('illegal ret format')
    @assert_io_num_rows_equal
    def predict(self, data_inst, ret_format='std'):
        # homo trees can be evaluated entirely locally, so delegate to the fast path
        return self.fast_homo_tree_predict(data_inst, ret_format=ret_format)
  239. def generate_summary(self) -> dict:
  240. summary = {'feature_importance': make_readable_feature_importance(self.feature_name_fid_mapping,
  241. self.feature_importance_),
  242. 'validation_metrics': self.callback_variables.validation_summary}
  243. return summary
  244. def load_learner(self, model_meta, model_param, epoch_idx, booster_idx):
  245. tree_inst = HomoDecisionTreeClient(tree_param=self.tree_param, mode='predict')
  246. tree_inst.load_model(model_meta=model_meta, model_param=model_param)
  247. return tree_inst
  248. def load_feature_importance(self, feat_importance_param):
  249. param = list(feat_importance_param)
  250. rs_dict = {}
  251. for fp in param:
  252. key = fp.fid
  253. importance = FeatureImportance()
  254. importance.from_protobuf(fp)
  255. rs_dict[key] = importance
  256. self.feature_importance_ = rs_dict
  257. LOGGER.debug('load feature importance": {}'.format(self.feature_importance_))
    def set_model_param(self, model_param):
        """Restore fitted state from a BoostingTreeModelParam protobuf."""
        self.boosting_model_list = list(model_param.trees_)
        self.init_score = np.array(list(model_param.init_score))
        self.classes_ = list(map(int, model_param.classes_))
        self.booster_dim = model_param.tree_dim
        self.num_classes = model_param.num_classes
        self.feature_name_fid_mapping.update(model_param.feature_name_fid_mapping)
        self.load_feature_importance(model_param.feature_importances)
        # initialize loss function
        self.loss = self.get_loss_function()
    def set_model_meta(self, model_meta):
        """Restore training configuration from a BoostingTreeModelMeta protobuf."""
        if not self.is_warm_start:
            # warm start keeps the newly configured round count instead of the stored one
            self.boosting_round = model_meta.num_trees
        self.n_iter_no_change = model_meta.n_iter_no_change
        self.tol = model_meta.tol
        self.bin_num = model_meta.quantile_meta.bin_num
        self.learning_rate = model_meta.learning_rate
        self.booster_meta = model_meta.tree_meta
        self.objective_param.objective = model_meta.objective_meta.objective
        self.objective_param.params = list(model_meta.objective_meta.param)
        self.task_type = model_meta.task_type
    def get_model_param(self):
        """Export fitted state as a BoostingTreeModelParam protobuf.

        :return: (param protobuf name, BoostingTreeModelParam)
        """
        model_param = BoostingTreeModelParam()
        model_param.tree_num = len(list(self.boosting_model_list))
        model_param.tree_dim = self.booster_dim
        model_param.trees_.extend(self.boosting_model_list)
        model_param.init_score.extend(self.init_score)
        model_param.classes_.extend(map(str, self.classes_))
        model_param.num_classes = self.num_classes
        model_param.best_iteration = -1
        model_param.model_name = consts.HOMO_SBT
        # feature importance serialized in descending order of importance
        feature_importance = list(self.feature_importance_.items())
        feature_importance = sorted(feature_importance, key=itemgetter(1), reverse=True)
        feature_importance_param = []
        for fid, importance in feature_importance:
            feature_importance_param.append(FeatureImportanceInfo(fid=fid,
                                                                  fullname=self.feature_name_fid_mapping[fid],
                                                                  sitename=self.role,
                                                                  importance=importance.importance,
                                                                  importance2=importance.importance_2,
                                                                  main=importance.main_type
                                                                  ))
        model_param.feature_importances.extend(feature_importance_param)
        model_param.feature_name_fid_mapping.update(self.feature_name_fid_mapping)
        param_name = "HomoSecureBoostingTreeGuestParam"
        return param_name, model_param
    def get_model_meta(self):
        """Export training configuration as a BoostingTreeModelMeta protobuf.

        :return: (meta protobuf name, BoostingTreeModelMeta)
        """
        model_meta = BoostingTreeModelMeta()
        model_meta.tree_meta.CopyFrom(self.booster_meta)
        model_meta.learning_rate = self.learning_rate
        model_meta.num_trees = self.boosting_round
        model_meta.quantile_meta.CopyFrom(QuantileMeta(bin_num=self.bin_num))
        model_meta.objective_meta.CopyFrom(ObjectiveMeta(objective=self.objective_param.objective,
                                                         param=self.objective_param.params))
        model_meta.task_type = self.task_type
        model_meta.n_iter_no_change = self.n_iter_no_change
        model_meta.tol = self.tol
        model_meta.use_missing = self.use_missing
        model_meta.zero_as_missing = self.zero_as_missing
        model_meta.module = 'HomoSecureBoost'
        meta_name = "HomoSecureBoostingTreeGuestMeta"
        return meta_name, model_meta