# homo_boosting.py

from abc import ABC
import abc
import numpy as np
from federatedml.ensemble.boosting.boosting import Boosting
from federatedml.feature.homo_feature_binning.homo_split_points import HomoFeatureBinningClient, \
    HomoFeatureBinningServer
from federatedml.util.classify_label_checker import ClassifyLabelChecker, RegressionLabelChecker
from federatedml.util import consts
from federatedml.util.homo_label_encoder import HomoLabelEncoderClient, HomoLabelEncoderArbiter
from federatedml.transfer_variable.transfer_class.homo_boosting_transfer_variable import HomoBoostingTransferVariable
from typing import List
from federatedml.feature.fate_element_type import NoneType
from federatedml.util import LOGGER
from federatedml.optim.convergence import converge_func_factory
from federatedml.param.boosting_param import HomoSecureBoostParam
from federatedml.model_base import Metric
from federatedml.model_base import MetricMeta
from federatedml.util.io_check import assert_io_num_rows_equal
from federatedml.feature.homo_feature_binning import recursive_query_binning
from federatedml.param.feature_binning_param import HomoFeatureBinningParam
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
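

# Thin wrappers over the secure aggregation framework. Client and server
# sides match messages through the shared 'homo_sbt' communication suffix.
# The arbiter-side wrapper below aggregates the clients' local losses into
# a global loss and broadcasts the convergence decision back to them.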
class HomoBoostArbiterAggregator(object):

    def __init__(self):
        self.aggregator = SecureAggregatorServer(communicate_match_suffix='homo_sbt')

    def aggregate_loss(self, suffix):
        global_loss = self.aggregator.aggregate_loss(suffix)
        return global_loss

    def broadcast_converge_status(self, func, loss, suffix):
        is_converged = func(*loss)
        self.aggregator.broadcast_converge_status(is_converged, suffix=suffix)
        return is_converged
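

# Client-side counterpart: each party's loss contribution is weighted by its
# local sample number (aggregate_weight) during secure aggregation.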
class HomoBoostClientAggregator(object):

    def __init__(self, sample_num):
        self.aggregator = SecureAggregatorClient(
            communicate_match_suffix='homo_sbt', aggregate_weight=sample_num)

    def send_local_loss(self, loss, suffix):
        self.aggregator.send_loss(loss, suffix)

    def get_converge_status(self, suffix):
        return self.aggregator.get_converge_status(suffix)
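

# Client (guest/host) side of homo boosting: bins features with federated
# binning, aligns labels across parties, then fits one booster per class
# dimension each round and reports the local loss for aggregation.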
class HomoBoostingClient(Boosting, ABC):

    def __init__(self):
        super(HomoBoostingClient, self).__init__()
        self.transfer_inst = HomoBoostingTransferVariable()
        self.model_param = HomoSecureBoostParam()
        self.aggregator = None
        self.binning_obj = None
        self.mode = consts.HOMO
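
    # Compute global split points with the recursive-query binning protocol,
    # then convert raw feature values into bin indices.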
    def federated_binning(self, data_instance):
        binning_param = HomoFeatureBinningParam(method=consts.RECURSIVE_QUERY, bin_num=self.bin_num,
                                                error=self.binning_error)

        if self.use_missing:
            self.binning_obj = recursive_query_binning.Client(params=binning_param, abnormal_list=[NoneType()],
                                                              role=self.role)
            LOGGER.debug('use missing')
        else:
            self.binning_obj = recursive_query_binning.Client(params=binning_param, role=self.role)

        self.binning_obj.fit_split_points(data_instance)
        return self.binning_obj.convert_feature_to_bin(data_instance)

    def check_label(self, data_inst) -> List[int]:
        LOGGER.debug('checking labels')
        classes_ = None
        if self.task_type == consts.CLASSIFICATION:
            num_classes, classes_ = ClassifyLabelChecker.validate_label(data_inst)
        else:
            RegressionLabelChecker.validate_label(data_inst)
        return classes_

    @staticmethod
    def check_label_starts_from_zero(aligned_labels):
        """
        In the current version, labels must start from 0 and be consecutive integers.
        """
        if aligned_labels[0] != 0:
            raise ValueError('labels should start from 0')
        for prev, aft in zip(aligned_labels[:-1], aligned_labels[1:]):
            if prev + 1 != aft:
                raise ValueError('labels should be a sequence of consecutive integers, '
                                 'but got {} and {}'.format(prev, aft))
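
    # The two sync helpers below push client metadata (feature number,
    # start/end boosting round) to the arbiter.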
    def sync_feature_num(self):
        self.transfer_inst.feature_number.remote(self.feature_num, role=consts.ARBITER, idx=-1, suffix=('feat_num',))

    def sync_start_round_and_end_round(self):
        self.transfer_inst.start_and_end_round.remote((self.start_round, self.boosting_round),
                                                      role=consts.ARBITER, idx=-1)

    def data_preprocess(self, data_inst):
        # transform to sparse and binning
        data_inst = self.data_alignment(data_inst)
        self.data_bin, self.bin_split_points, self.bin_sparse_points = self.federated_binning(data_inst)
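
    # Client training flow: bin data, sync metadata with the arbiter, align
    # labels across parties, then per round fit one booster per class
    # dimension, send the local loss for secure aggregation, and optionally
    # early-stop on the arbiter's convergence broadcast.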
    def fit(self, data_inst, validate_data=None):

        # init federation obj
        self.aggregator = HomoBoostClientAggregator(sample_num=data_inst.count())

        # binning
        self.data_preprocess(data_inst)
        self.data_inst = data_inst

        # fid mapping and warm start check
        if not self.is_warm_start:
            self.feature_name_fid_mapping = self.gen_feature_fid_mapping(data_inst.schema)
        else:
            self.feat_name_check(data_inst, self.feature_name_fid_mapping)

        # set feature_num
        self.feature_num = self.bin_split_points.shape[0]

        # sync feature num
        self.sync_feature_num()

        # initialize validation strategy
        self.callback_list.on_train_begin(data_inst, validate_data)

        # check labels
        local_classes = self.check_label(self.data_bin)

        # set start round
        self.start_round = len(self.boosting_model_list) // self.booster_dim

        # sync label classes and set y
        if self.task_type == consts.CLASSIFICATION:
            aligned_label, new_label_mapping = HomoLabelEncoderClient().label_alignment(local_classes)
            if self.is_warm_start:
                assert set(aligned_label) == set(self.classes_), 'warm start label alignment failed, differences: {}'. \
                    format(set(aligned_label).symmetric_difference(set(self.classes_)))
            self.classes_ = aligned_label
            self.check_label_starts_from_zero(self.classes_)
            # set labels
            self.num_classes = len(new_label_mapping)
            LOGGER.info('aligned labels are {}, num_classes is {}'.format(aligned_label, self.num_classes))
            self.y = self.data_bin.mapValues(lambda instance: new_label_mapping[instance.label])
            # set tree dimension
            self.booster_dim = self.num_classes if self.num_classes > 2 else 1
        else:
            self.y = self.data_bin.mapValues(lambda instance: instance.label)

        # set loss function
        self.loss = self.get_loss_function()

        # set y_hat; if warm start, predict current samples
        if self.is_warm_start:
            self.y_hat = self.predict(data_inst, ret_format='raw')
            self.boosting_round += self.start_round
            self.callback_warm_start_init_iter(self.start_round)
        else:
            if self.task_type == consts.REGRESSION:
                self.init_score = np.array([0])  # make sure that every local model has the same init score
                self.y_hat = self.y.mapValues(lambda x: np.array([0]))
            else:
                self.y_hat, self.init_score = self.get_init_score(self.y, self.num_classes)

        # sync start round and end round
        self.sync_start_round_and_end_round()

        self.preprocess()

        LOGGER.info('begin to fit a boosting tree')
        for epoch_idx in range(self.start_round, self.boosting_round):
            LOGGER.info('cur epoch idx is {}'.format(epoch_idx))
            self.callback_list.on_epoch_begin(epoch_idx)

            for class_idx in range(self.booster_dim):
                # fit a booster
                model = self.fit_a_learner(epoch_idx, class_idx)
                booster_meta, booster_param = model.get_model()
                if booster_meta is not None and booster_param is not None:
                    self.booster_meta = booster_meta
                    self.boosting_model_list.append(booster_param)

                # update predict score
                cur_sample_weights = model.get_sample_weights()
                self.y_hat = self.get_new_predict_score(self.y_hat, cur_sample_weights, dim=class_idx)

            local_loss = self.compute_loss(self.y_hat, self.y)
            self.aggregator.send_local_loss(local_loss, suffix=(epoch_idx,))

            validation_strategy = self.callback_list.get_validation_strategy()
            if validation_strategy:
                validation_strategy.set_precomputed_train_scores(self.score_to_predict_result(data_inst, self.y_hat))
            self.callback_list.on_epoch_end(epoch_idx)

            # check stop flag if n_iter_no_change is True
            if self.n_iter_no_change:
                should_stop = self.aggregator.get_converge_status(suffix=(epoch_idx,))
                if should_stop:
                    LOGGER.info('n_iter_no_change stop triggered')
                    break

        self.postprocess()
        self.callback_list.on_train_end()
        self.set_summary(self.generate_summary())

    @assert_io_num_rows_equal
    def predict(self, data_inst):
        # predict is implemented in homo_secureboost
        raise NotImplementedError('predict func is not implemented')

    @abc.abstractmethod
    def fit_a_learner(self, epoch_idx: int, booster_dim: int):
        raise NotImplementedError()

    @abc.abstractmethod
    def load_learner(self, model_meta, model_param, epoch_idx, booster_idx):
        raise NotImplementedError()
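

# Arbiter side of homo boosting: coordinates binning and label alignment,
# aggregates per-round losses, and broadcasts the convergence decision.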
class HomoBoostingArbiter(Boosting, ABC):

    def __init__(self):
        super(HomoBoostingArbiter, self).__init__()
        self.transfer_inst = HomoBoostingTransferVariable()
        self.check_convergence_func = None
        self.aggregator = None
        self.binning_obj = None

    def federated_binning(self):
        binning_param = HomoFeatureBinningParam(method=consts.RECURSIVE_QUERY, bin_num=self.bin_num,
                                                error=self.binning_error)

        if self.use_missing:
            self.binning_obj = recursive_query_binning.Server(binning_param, abnormal_list=[NoneType()])
        else:
            self.binning_obj = recursive_query_binning.Server(binning_param, abnormal_list=[])

        self.binning_obj.fit_split_points(None)
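
    # All clients must report the same feature number; a mismatch means the
    # parties' data schemas are misaligned.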
    def sync_feature_num(self):
        feature_num_list = self.transfer_inst.feature_number.get(idx=-1, suffix=('feat_num',))
        for num in feature_num_list[1:]:
            assert feature_num_list[0] == num
        return feature_num_list[0]

    def sync_start_round_and_end_round(self):
        r_list = self.transfer_inst.start_and_end_round.get(-1)
        LOGGER.info('get start/end round from clients: {}'.format(r_list))
        self.start_round, self.boosting_round = r_list[0]

    def check_label(self):
        pass
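
    # Arbiter training flow: mirrors the client loop but only coordinates;
    # it works on securely aggregated losses and decides convergence.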
    def fit(self, data_inst, validate_data=None):

        # init aggregator and binning obj
        self.aggregator = HomoBoostArbiterAggregator()
        self.federated_binning()

        # initializing
        self.feature_num = self.sync_feature_num()

        if self.task_type == consts.CLASSIFICATION:
            label_mapping = HomoLabelEncoderArbiter().label_alignment()
            LOGGER.info('label mapping is {}'.format(label_mapping))
            self.booster_dim = len(label_mapping) if len(label_mapping) > 2 else 1

        if self.n_iter_no_change:
            self.check_convergence_func = converge_func_factory("diff", self.tol)

        # sync start round and end round
        self.sync_start_round_and_end_round()

        LOGGER.info('begin to fit a boosting tree')
        self.preprocess()
        for epoch_idx in range(self.start_round, self.boosting_round):
            LOGGER.info('cur epoch idx is {}'.format(epoch_idx))

            for class_idx in range(self.booster_dim):
                model = self.fit_a_learner(epoch_idx, class_idx)

            global_loss = self.aggregator.aggregate_loss(suffix=(epoch_idx,))
            self.history_loss.append(global_loss)
            LOGGER.debug('cur epoch global loss is {}'.format(global_loss))

            self.callback_metric("loss",
                                 "train",
                                 [Metric(epoch_idx, global_loss)])

            if self.n_iter_no_change:
                should_stop = self.aggregator.broadcast_converge_status(self.check_convergence, (global_loss,),
                                                                        suffix=(epoch_idx,))
                LOGGER.debug('stop flag sent')
                if should_stop:
                    break

        self.callback_meta("loss",
                           "train",
                           MetricMeta(name="train",
                                      metric_type="LOSS",
                                      extra_metas={"Best": min(self.history_loss)}))
        self.postprocess()
        self.callback_list.on_train_end()
        self.set_summary(self.generate_summary())

    def predict(self, data_inst=None):
        LOGGER.debug('arbiter skips prediction')

    @abc.abstractmethod
    def fit_a_learner(self, epoch_idx: int, booster_dim: int):
        raise NotImplementedError()

    @abc.abstractmethod
    def load_learner(self, model_meta, model_param, epoch_idx, booster_idx):
        raise NotImplementedError()
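

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): what a concrete
# client subclass has to provide. The class name below is hypothetical; in
# FATE this role is played by the homo secureboost implementation.
#
# class ExampleHomoBoostingClient(HomoBoostingClient):
#
#     def fit_a_learner(self, epoch_idx: int, booster_dim: int):
#         # fit and return one booster for this round/class dimension; the
#         # returned object must expose get_model() and get_sample_weights(),
#         # as used in HomoBoostingClient.fit
#         ...
#
#     def load_learner(self, model_meta, model_param, epoch_idx, booster_idx):
#         # rebuild a fitted booster from its serialized meta/param,
#         # e.g. for prediction
#         ...
# ---------------------------------------------------------------------------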