homo_secureboost_arbiter.py

import numpy as np
from numpy import random
from federatedml.util import LOGGER
from federatedml.util import consts
from federatedml.ensemble.boosting.homo_boosting import HomoBoostingArbiter
from federatedml.param.boosting_param import HomoSecureBoostParam
from federatedml.ensemble.basic_algorithms.decision_tree.homo.homo_decision_tree_arbiter import HomoDecisionTreeArbiter


class HomoSecureBoostingTreeArbiter(HomoBoostingArbiter):

    def __init__(self):
        super(HomoSecureBoostingTreeArbiter, self).__init__()
        self.model_name = 'HomoSecureBoost'
        self.tree_param = None  # decision tree param
        self.use_missing = False
        self.zero_as_missing = False
        self.cur_epoch_idx = -1
        self.grad_and_hess = None
        self.feature_importances_ = {}
        self.model_param = HomoSecureBoostParam()
        self.multi_mode = consts.SINGLE_OUTPUT

    def _init_model(self, boosting_param: HomoSecureBoostParam):
        super(HomoSecureBoostingTreeArbiter, self)._init_model(boosting_param)
        self.use_missing = boosting_param.use_missing
        self.zero_as_missing = boosting_param.zero_as_missing
        self.tree_param = boosting_param.tree_param
        self.multi_mode = boosting_param.multi_mode
        if self.use_missing:
            # propagate missing-value handling down to the tree parameters
            self.tree_param.use_missing = self.use_missing
            self.tree_param.zero_as_missing = self.zero_as_missing

    def send_valid_features(self, valid_features, epoch_idx, b_idx):
        # broadcast the sampled feature mask to all parties for this epoch/booster
        self.transfer_inst.valid_features.remote(valid_features, idx=-1,
                                                 suffix=('valid_features', epoch_idx, b_idx))

    def sample_valid_features(self):
        # randomly pick a subset of feature indices and return a boolean mask
        LOGGER.info("sample valid features")
        chosen_feature = random.choice(range(0, self.feature_num),
                                       max(1, int(self.subsample_feature_rate * self.feature_num)),
                                       replace=False)
        valid_features = [False for i in range(self.feature_num)]
        for fid in chosen_feature:
            valid_features[fid] = True
        return valid_features

    def preprocess(self):
        # multi-output mode fits a single multi-output tree per boosting round
        if self.multi_mode == consts.MULTI_OUTPUT:
            self.booster_dim = 1

    def fit_a_learner(self, epoch_idx: int, booster_dim: int):
        # sample features, notify clients, then coordinate fitting one decision tree
        valid_feature = self.sample_valid_features()
        self.send_valid_features(valid_feature, epoch_idx, booster_dim)
        flow_id = self.generate_flowid(epoch_idx, booster_dim)
        new_tree = HomoDecisionTreeArbiter(self.tree_param, valid_feature=valid_feature, epoch_idx=epoch_idx,
                                           flow_id=flow_id, tree_idx=booster_dim)
        new_tree.fit()
        return new_tree

    def generate_summary(self) -> dict:
        summary = {'loss_history': self.history_loss}
        return summary

    # homo tree arbiter doesn't save model
    def get_cur_model(self):
        return None

    def load_learner(self, model_meta, model_param, epoch_idx, booster_idx):
        pass

    def set_model_param(self, model_param):
        pass

    def set_model_meta(self, model_meta):
        pass

    def get_model_param(self):
        pass

    def get_model_meta(self):
        pass
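

# --- Illustrative sketch (not part of the FATE source above) ---
# A minimal, standalone reproduction of the feature-subsampling logic used by
# sample_valid_features(), assuming the feature count and subsample rate are
# passed in directly instead of read from the arbiter's state. It reuses the
# numpy import at the top of the file.

def sample_feature_mask(feature_num: int, subsample_feature_rate: float) -> list:
    """Return a boolean mask marking a random subset of feature indices as valid.

    At least one feature is always kept, mirroring the max(1, ...) guard in
    HomoSecureBoostingTreeArbiter.sample_valid_features.
    """
    n_chosen = max(1, int(subsample_feature_rate * feature_num))
    chosen = np.random.choice(range(feature_num), n_chosen, replace=False)
    mask = [False] * feature_num
    for fid in chosen:
        mask[fid] = True
    return mask


# Example: with 10 features and a 0.3 subsample rate, 3 features are marked valid.
# print(sample_feature_mask(10, 0.3))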