#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from abc import ABC
import abc
from federatedml.ensemble.boosting import Boosting
from federatedml.param.boosting_param import HeteroBoostingParam
from federatedml.secureprotol import PaillierEncrypt, IpclPaillierEncrypt
from federatedml.util import consts
from federatedml.feature.binning.quantile_binning import QuantileBinning
from federatedml.util.classify_label_checker import ClassifyLabelChecker
from federatedml.util.classify_label_checker import RegressionLabelChecker
from federatedml.util import LOGGER
from federatedml.model_base import Metric
from federatedml.model_base import MetricMeta
from federatedml.transfer_variable.transfer_class.hetero_boosting_transfer_variable import \
    HeteroBoostingTransferVariable
from federatedml.util.io_check import assert_io_num_rows_equal
from federatedml.statistic.data_overview import get_anonymous_header


class HeteroBoosting(Boosting, ABC):
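    """Base class shared by the guest and host sides of hetero boosting.

    Holds the state common to both parties: the homomorphic encrypter, the
    quantile binning class, the boosting parameters and the transfer
    variables used for guest/host communication.
    """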

    def __init__(self):
        super(HeteroBoosting, self).__init__()
        self.encrypter = None
        self.early_stopping_rounds = None
        self.binning_class = QuantileBinning
        self.model_param = HeteroBoostingParam()
        self.transfer_variable = HeteroBoostingTransferVariable()
        self.mode = consts.HETERO

    def _init_model(self, param: HeteroBoostingParam):
        LOGGER.debug('in hetero boosting, objective param is {}'.format(param.objective_param.objective))
        super(HeteroBoosting, self)._init_model(param)
        self.encrypt_param = param.encrypt_param
        self.early_stopping_rounds = param.early_stopping_rounds
        self.use_first_metric_only = param.use_first_metric_only


class HeteroBoostingGuest(HeteroBoosting, ABC):
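    """Guest-side hetero boosting.

    The guest owns the labels, generates the encryption key pair, drives the
    training loop and broadcasts the booster dimension and stop flags to the
    hosts.
    """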

    def __init__(self):
        super(HeteroBoostingGuest, self).__init__()

    def _init_model(self, param):
        super(HeteroBoostingGuest, self)._init_model(param)

    def generate_encrypter(self):
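        """Create the homomorphic encrypter (Paillier, or IPCL-accelerated
        Paillier) and generate its key pair from the configured key length."""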
        LOGGER.info("generate encrypter")
        if self.encrypt_param.method.lower() == consts.PAILLIER.lower():
            self.encrypter = PaillierEncrypt()
            self.encrypter.generate_key(self.encrypt_param.key_length)
        elif self.encrypt_param.method.lower() == consts.PAILLIER_IPCL.lower():
            self.encrypter = IpclPaillierEncrypt()
            self.encrypter.generate_key(self.encrypt_param.key_length)
        else:
            raise NotImplementedError("unknown encrypt type {}".format(self.encrypt_param.method.lower()))

    def check_label(self):
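        """Validate labels and derive the booster dimension.

        For classification, collect the distinct labels; a label set that is
        not already {0, ..., num_classes - 1} is remapped onto that range,
        and booster_dim becomes num_classes for multiclass tasks. For
        regression, labels are only validated.

        Returns:
            (classes_, num_classes, booster_dim)
        """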
        LOGGER.info("check label")
        classes_ = []
        num_classes, booster_dim = 1, 1
        if self.task_type == consts.CLASSIFICATION:
            num_classes, classes_ = ClassifyLabelChecker.validate_label(self.data_bin)
            if num_classes > 2:
                booster_dim = num_classes

            range_from_zero = True
            for _class in classes_:
                try:
                    if 0 <= _class < len(classes_) and isinstance(_class, int):
                        continue
                    else:
                        range_from_zero = False
                        break
                except BaseException:
                    range_from_zero = False

            classes_ = sorted(classes_)
            if not range_from_zero:
                class_mapping = dict(zip(classes_, range(num_classes)))
                self.y = self.y.mapValues(lambda _class: class_mapping[_class])
        else:
            RegressionLabelChecker.validate_label(self.data_bin)

        return classes_, num_classes, booster_dim

    def sync_booster_dim(self):
        LOGGER.info("sync booster_dim to host")
        self.transfer_variable.booster_dim.remote(self.booster_dim,
                                                  role=consts.HOST,
                                                  idx=-1)

    def sync_stop_flag(self, stop_flag, num_round):
        LOGGER.info("sync stop flag to host, boosting_core round is {}".format(num_round))
        self.transfer_variable.stop_flag.remote(stop_flag,
                                                role=consts.HOST,
                                                idx=-1,
                                                suffix=(num_round,))

    def sync_predict_round(self, predict_round):
        LOGGER.info("sync predict start round {}".format(predict_round))
        self.transfer_variable.predict_start_round.remote(predict_round, role=consts.HOST, idx=-1)

    def prepare_warm_start(self, data_inst, classes):
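        """Resume training from a loaded model: use the loaded model's raw
        predict scores as the initial y_hat, shift the round counters, and
        check that the current labels and feature ids match the previous
        model."""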
        # adjust parameters for warm start
        warm_start_y_hat = self.predict(data_inst, ret_format='raw')
        self.y_hat = warm_start_y_hat
        self.start_round = len(self.boosting_model_list) // self.booster_dim
        self.boosting_round += self.start_round

        # check classes
        assert set(classes).issubset(set(self.classes_)), \
            'warm start label alignment failed: cur labels {}, previous model labels {}'.format(
                classes, self.classes_)

        # check fid
        self.feat_name_check(data_inst, self.feature_name_fid_mapping)
        self.callback_warm_start_init_iter(self.start_round)

    def fit(self, data_inst, validate_data=None):
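        """Train the guest-side boosting model.

        Each epoch fits one booster per class dimension, folds the new
        sample scores into y_hat, computes and logs the training loss,
        runs validation callbacks, and syncs the convergence flag to the
        hosts for early stopping.
        """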
        LOGGER.info('begin to fit a hetero boosting model, model is {}'.format(self.model_name))
        self.start_round = 0
        self.on_training = True
        self.data_inst = data_inst
        to_process_data_inst = self.data_and_header_alignment(data_inst) if self.is_warm_start else data_inst
        self.data_bin, self.bin_split_points, self.bin_sparse_points = self.prepare_data(to_process_data_inst)
        self.y = self.get_label(self.data_bin)

        if not self.is_warm_start:
            self.feature_name_fid_mapping = self.gen_feature_fid_mapping(data_inst.schema)
            self.classes_, self.num_classes, self.booster_dim = self.check_label()
            self.loss = self.get_loss_function()
            self.y_hat, self.init_score = self.get_init_score(self.y, self.num_classes)
        else:
            classes_, num_classes, booster_dim = self.check_label()
            self.prepare_warm_start(data_inst, classes_)

        LOGGER.info('class index is {}'.format(self.classes_))
        self.sync_booster_dim()
        self.generate_encrypter()

        self.callback_list.on_train_begin(data_inst, validate_data)
        self.callback_meta("loss",
                           "train",
                           MetricMeta(name="train",
                                      metric_type="LOSS",
                                      extra_metas={"unit_name": "iters"}))

        self.preprocess()

        for epoch_idx in range(self.start_round, self.boosting_round):

            LOGGER.info('cur epoch idx is {}'.format(epoch_idx))
            self.callback_list.on_epoch_begin(epoch_idx)

            for class_idx in range(self.booster_dim):
                # fit a booster
                model = self.fit_a_learner(epoch_idx, class_idx)
                booster_meta, booster_param = model.get_model()
                if booster_meta is not None and booster_param is not None:
                    self.booster_meta = booster_meta
                    self.boosting_model_list.append(booster_param)

                # update predict score
                cur_sample_weights = model.get_sample_weights()
                self.y_hat = self.get_new_predict_score(self.y_hat, cur_sample_weights, dim=class_idx)

            # compute loss
            loss = self.compute_loss(self.y_hat, self.y)
            self.history_loss.append(loss)
            LOGGER.info("round {} loss is {}".format(epoch_idx, loss))
            self.callback_metric("loss",
                                 "train",
                                 [Metric(epoch_idx, loss)])

            # check validation
            validation_strategy = self.callback_list.get_validation_strategy()
            if validation_strategy:
                validation_strategy.set_precomputed_train_scores(self.score_to_predict_result(data_inst, self.y_hat))

            self.callback_list.on_epoch_end(epoch_idx)

            should_stop = False
            if self.n_iter_no_change and self.check_convergence(loss):
                should_stop = True
                self.is_converged = True
            self.sync_stop_flag(self.is_converged, epoch_idx)
            if self.stop_training or should_stop:
                break

        self.postprocess()
        self.callback_list.on_train_end()
        self.callback_meta("loss",
                           "train",
                           MetricMeta(name="train",
                                      metric_type="LOSS",
                                      extra_metas={"Best": min(self.history_loss)}))

        # get summary
        self.set_summary(self.generate_summary())

    @assert_io_num_rows_equal
    def predict(self, data_inst):
        # predict is implemented in hetero_secureboost
        raise NotImplementedError('predict func is not implemented')

    @abc.abstractmethod
    def fit_a_learner(self, epoch_idx: int, booster_dim: int):
        raise NotImplementedError()

    @abc.abstractmethod
    def load_learner(self, model_meta, model_param, epoch_idx, booster_idx):
        raise NotImplementedError()

    @abc.abstractmethod
    def get_model_meta(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def get_model_param(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def set_model_meta(self, model_meta):
        raise NotImplementedError()

    @abc.abstractmethod
    def set_model_param(self, model_param):
        raise NotImplementedError()


class HeteroBoostingHost(HeteroBoosting, ABC):
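    """Host-side hetero boosting.

    The host holds features but no labels; it follows the booster dimension,
    stop flags and predict start round broadcast by the guest.
    """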

    def __init__(self):
        super(HeteroBoostingHost, self).__init__()

    def _init_model(self, param):
        super(HeteroBoostingHost, self)._init_model(param)

    def sync_booster_dim(self):
        LOGGER.info("sync booster dim from guest")
        self.booster_dim = self.transfer_variable.booster_dim.get(idx=0)
        LOGGER.info("booster dim is %d" % self.booster_dim)

    def sync_stop_flag(self, num_round):
        LOGGER.info("sync stop flag from guest, boosting_core round is {}".format(num_round))
        stop_flag = self.transfer_variable.stop_flag.get(idx=0,
                                                         suffix=(num_round,))
        return stop_flag

    def sync_predict_start_round(self):
        return self.transfer_variable.predict_start_round.get(idx=0)

    def prepare_warm_start(self, data_inst):
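        """Resume host-side training from a loaded model: replay prediction to
        rebuild local predict state, check feature ids, and shift the round
        counters so training continues from the loaded model's last round."""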
        self.predict(data_inst)
        self.feat_name_check(data_inst, self.feature_name_fid_mapping)
        # compute the resumed round before reporting it to the callback;
        # the guest side likewise computes start_round first
        self.start_round = len(self.boosting_model_list) // self.booster_dim
        self.boosting_round += self.start_round
        self.callback_warm_start_init_iter(self.start_round)

    def set_anonymous_header(self, data_inst):
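        """Build the mapping from real feature names to anonymous feature
        names if it has not been set yet."""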
        if not self.anonymous_header:
            self.anonymous_header = {v: k for k, v in zip(get_anonymous_header(data_inst),
                                                          data_inst.schema['header'])}

    def fit(self, data_inst, validate_data=None):
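        """Train the host-side boosting model.

        Each epoch fits one booster per class dimension under the guest's
        coordination; training stops when the guest's stop flag signals
        convergence.
        """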
        LOGGER.info('begin to fit a hetero boosting model, model is {}'.format(self.model_name))
        self.start_round = 0
        self.on_training = True
        to_process_data_inst = self.data_and_header_alignment(data_inst) if self.is_warm_start else data_inst
        self.data_bin, self.bin_split_points, self.bin_sparse_points = self.prepare_data(to_process_data_inst)
        self.set_anonymous_header(to_process_data_inst)

        if self.is_warm_start:
            self.prepare_warm_start(data_inst)
        else:
            self.feature_name_fid_mapping = self.gen_feature_fid_mapping(data_inst.schema)

        self.sync_booster_dim()
        self.callback_list.on_train_begin(data_inst, validate_data)

        self.preprocess()

        for epoch_idx in range(self.start_round, self.boosting_round):

            LOGGER.info('cur epoch idx is {}'.format(epoch_idx))
            self.callback_list.on_epoch_begin(epoch_idx)

            for class_idx in range(self.booster_dim):
                # fit a booster
                model = self.fit_a_learner(epoch_idx, class_idx)  # need to implement
                booster_meta, booster_param = model.get_model()
                if booster_meta is not None and booster_param is not None:
                    self.booster_meta = booster_meta
                    self.boosting_model_list.append(booster_param)

            validation_strategy = self.callback_list.get_validation_strategy()
            if validation_strategy:
                validation_strategy.set_precomputed_train_scores(None)

            self.callback_list.on_epoch_end(epoch_idx)

            should_stop = self.sync_stop_flag(epoch_idx)
            self.is_converged = should_stop
            if should_stop or self.stop_training:
                break

        self.postprocess()
        self.callback_list.on_train_end()
        self.set_summary(self.generate_summary())

    def lazy_predict(self, data_inst):
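        """Replay boosters from the guest-specified start round so the host
        side contributes its splits to prediction; final scores are produced
        on the guest side."""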
        LOGGER.info('running host lazy prediction')
        data_inst = self.data_alignment(data_inst)
        init_score = self.init_score
        self.predict_y_hat = data_inst.mapValues(lambda v: init_score)

        rounds = len(self.boosting_model_list) // self.booster_dim
        predict_start_round = self.sync_predict_start_round()

        for idx in range(predict_start_round, rounds):
            for booster_idx in range(self.booster_dim):
                model = self.load_learner(self.booster_meta,
                                          self.boosting_model_list[idx * self.booster_dim + booster_idx],
                                          idx, booster_idx)
                model.predict(data_inst)

        LOGGER.debug('lazy prediction finished')

    def predict(self, data_inst):
        LOGGER.info('using default lazy prediction')
        self.lazy_predict(data_inst)

    @abc.abstractmethod
    def load_learner(self, model_meta, model_param, epoch_idx, booster_idx):
        raise NotImplementedError()

    @abc.abstractmethod
    def fit_a_learner(self, epoch_idx: int, booster_dim: int):
        raise NotImplementedError()

    @abc.abstractmethod
    def get_model_meta(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def get_model_param(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def set_model_meta(self, model_meta):
        raise NotImplementedError()

    @abc.abstractmethod
    def set_model_param(self, model_param):
        raise NotImplementedError()
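

# ---------------------------------------------------------------------------
# A minimal sketch (hypothetical names) of how a concrete subclass fills in
# the abstract contract above; the real implementations live in
# hetero_secureboost. fit_a_learner must return an object exposing
# get_model() -> (meta, param) and, on the guest side, get_sample_weights().
#
#     class MyHeteroBoostingGuest(HeteroBoostingGuest):
#
#         def fit_a_learner(self, epoch_idx, booster_dim):
#             # build and fit one booster for this (epoch, class-dim) slot
#             ...
#
#         def load_learner(self, model_meta, model_param, epoch_idx, booster_idx):
#             # rebuild a fitted booster from its serialized meta/param
#             ...
#
#         def get_model_meta(self): ...
#         def get_model_param(self): ...
#         def set_model_meta(self, model_meta): ...
#         def set_model_param(self, model_param): ...
# ---------------------------------------------------------------------------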