tree_model_io.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from federatedml.param.boosting_param import DecisionTreeParam
  2. from federatedml.ensemble.basic_algorithms import HeteroFastDecisionTreeGuest, HeteroFastDecisionTreeHost, \
  3. HeteroDecisionTreeGuest, HeteroDecisionTreeHost
  4. from federatedml.util import consts
  5. def produce_hetero_tree_learner(role, tree_param: DecisionTreeParam, flow_id, data_bin, bin_split_points,
  6. bin_sparse_points, task_type, valid_features, host_party_list,
  7. runtime_idx,
  8. cipher_compress=True,
  9. mo_tree=False,
  10. class_num=1,
  11. g_h=None, encrypter=None, # guest only
  12. goss_subsample=False, complete_secure=False,
  13. max_sample_weights=1.0,
  14. objective=None,
  15. bin_num=None, # host only
  16. fast_sbt=False,
  17. tree_type=None, target_host_id=None, # fast sbt only
  18. guest_depth=2, host_depth=3 # fast sbt only
  19. ):
  20. if role == consts.GUEST:
  21. if not fast_sbt:
  22. tree = HeteroDecisionTreeGuest(tree_param)
  23. else:
  24. tree = HeteroFastDecisionTreeGuest(tree_param)
  25. tree.set_tree_work_mode(tree_type, target_host_id)
  26. tree.set_layered_depth(guest_depth, host_depth)
  27. tree.init(flowid=flow_id,
  28. data_bin=data_bin,
  29. bin_split_points=bin_split_points,
  30. bin_sparse_points=bin_sparse_points,
  31. grad_and_hess=g_h,
  32. encrypter=encrypter,
  33. task_type=task_type,
  34. valid_features=valid_features,
  35. host_party_list=host_party_list,
  36. runtime_idx=runtime_idx,
  37. goss_subsample=goss_subsample,
  38. complete_secure=complete_secure,
  39. cipher_compressing=cipher_compress,
  40. max_sample_weight=max_sample_weights,
  41. mo_tree=mo_tree,
  42. class_num=class_num,
  43. objective=objective
  44. )
  45. elif role == consts.HOST:
  46. if not fast_sbt:
  47. tree = HeteroDecisionTreeHost(tree_param)
  48. else:
  49. tree = HeteroFastDecisionTreeHost(tree_param)
  50. tree.set_tree_work_mode(tree_type, target_host_id)
  51. tree.set_layered_depth(guest_depth, host_depth)
  52. tree.set_self_host_id(runtime_idx)
  53. tree.set_host_party_idlist(host_party_list)
  54. tree.init(flowid=flow_id,
  55. valid_features=valid_features,
  56. data_bin=data_bin,
  57. bin_split_points=bin_split_points,
  58. bin_sparse_points=bin_sparse_points,
  59. runtime_idx=runtime_idx,
  60. goss_subsample=goss_subsample,
  61. complete_secure=complete_secure,
  62. cipher_compressing=cipher_compress,
  63. bin_num=bin_num,
  64. mo_tree=mo_tree
  65. )
  66. else:
  67. raise ValueError('unknown role: {}'.format(role))
  68. return tree
  69. def load_hetero_tree_learner(role, tree_param, model_meta, model_param, flow_id, runtime_idx, host_party_list=None,
  70. fast_sbt=False, tree_type=None, target_host_id=None):
  71. if role == consts.HOST:
  72. if fast_sbt:
  73. tree = HeteroFastDecisionTreeHost(tree_param)
  74. else:
  75. tree = HeteroDecisionTreeHost(tree_param)
  76. tree.load_model(model_meta, model_param)
  77. tree.set_flowid(flow_id)
  78. tree.set_runtime_idx(runtime_idx)
  79. if fast_sbt:
  80. tree.set_tree_work_mode(tree_type, target_host_id)
  81. tree.set_self_host_id(runtime_idx)
  82. elif role == consts.GUEST:
  83. if fast_sbt:
  84. tree = HeteroFastDecisionTreeGuest(tree_param)
  85. else:
  86. tree = HeteroDecisionTreeGuest(tree_param)
  87. tree.load_model(model_meta, model_param)
  88. tree.set_flowid(flow_id)
  89. tree.set_runtime_idx(runtime_idx)
  90. tree.set_host_party_idlist(host_party_list)
  91. if fast_sbt:
  92. tree.set_tree_work_mode(tree_type, target_host_id)
  93. else:
  94. raise ValueError('unknown role: {}'.format(role))
  95. return tree