123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- from federatedml.param.boosting_param import DecisionTreeParam
- from federatedml.ensemble.basic_algorithms import HeteroFastDecisionTreeGuest, HeteroFastDecisionTreeHost, \
- HeteroDecisionTreeGuest, HeteroDecisionTreeHost
- from federatedml.util import consts
def produce_hetero_tree_learner(role, tree_param: DecisionTreeParam, flow_id, data_bin, bin_split_points,
                                bin_sparse_points, task_type, valid_features, host_party_list,
                                runtime_idx,
                                cipher_compress=True,
                                mo_tree=False,
                                class_num=1,
                                g_h=None, encrypter=None,  # guest only
                                goss_subsample=False, complete_secure=False,
                                max_sample_weights=1.0,
                                objective=None,
                                bin_num=None,  # host only
                                fast_sbt=False,
                                tree_type=None, target_host_id=None,  # fast sbt only
                                guest_depth=2, host_depth=3  # fast sbt only
                                ):
    """Create and initialise a hetero decision-tree learner for one party.

    The learner class is chosen from ``role`` (guest or host) and
    ``fast_sbt`` (fast variant when truthy). Fast-SBT-only settings are
    applied to the fast variants before ``init`` is called.

    Args:
        role: Compared against ``consts.GUEST`` and ``consts.HOST``.
        tree_param: Passed to the learner's constructor.
        flow_id: Forwarded to ``init`` as ``flowid``.
        data_bin, bin_split_points, bin_sparse_points, valid_features,
        runtime_idx, goss_subsample, complete_secure, mo_tree:
            Forwarded unchanged to ``init`` for both roles.
        cipher_compress: Forwarded to ``init`` as ``cipher_compressing``.
        task_type, host_party_list, class_num, g_h, encrypter, objective:
            Forwarded to the guest learner's ``init`` only (``g_h`` as
            ``grad_and_hess``). ``host_party_list`` is additionally handed
            to a fast host learner via ``set_host_party_idlist``.
        max_sample_weights: Forwarded to the guest learner's ``init`` as
            ``max_sample_weight``.
        bin_num: Forwarded to the host learner's ``init`` only.
        fast_sbt: When truthy, the fast tree classes are used.
        tree_type, target_host_id: Passed to ``set_tree_work_mode`` on the
            fast learners; unused otherwise.
        guest_depth, host_depth: Passed to ``set_layered_depth`` on the
            fast learners; unused otherwise.

    Returns:
        The constructed learner after ``init`` has been called on it.

    Raises:
        ValueError: If ``role`` matches neither guest nor host.
    """
    # Keyword arguments that both the guest and the host pass to ``init``.
    init_kwargs = dict(
        flowid=flow_id,
        data_bin=data_bin,
        bin_split_points=bin_split_points,
        bin_sparse_points=bin_sparse_points,
        valid_features=valid_features,
        runtime_idx=runtime_idx,
        goss_subsample=goss_subsample,
        complete_secure=complete_secure,
        cipher_compressing=cipher_compress,
        mo_tree=mo_tree,
    )

    if role == consts.GUEST:
        learner_cls = HeteroFastDecisionTreeGuest if fast_sbt else HeteroDecisionTreeGuest
        tree = learner_cls(tree_param)
        if fast_sbt:
            tree.set_tree_work_mode(tree_type, target_host_id)
            tree.set_layered_depth(guest_depth, host_depth)

        # Guest-side extras on top of the shared arguments.
        init_kwargs.update(
            grad_and_hess=g_h,
            encrypter=encrypter,
            task_type=task_type,
            host_party_list=host_party_list,
            max_sample_weight=max_sample_weights,
            class_num=class_num,
            objective=objective,
        )
    elif role == consts.HOST:
        learner_cls = HeteroFastDecisionTreeHost if fast_sbt else HeteroDecisionTreeHost
        tree = learner_cls(tree_param)
        if fast_sbt:
            tree.set_tree_work_mode(tree_type, target_host_id)
            tree.set_layered_depth(guest_depth, host_depth)
            tree.set_self_host_id(runtime_idx)
            tree.set_host_party_idlist(host_party_list)

        # Host-side extra on top of the shared arguments.
        init_kwargs.update(bin_num=bin_num)
    else:
        raise ValueError('unknown role: {}'.format(role))

    tree.init(**init_kwargs)
    return tree
def load_hetero_tree_learner(role, tree_param, model_meta, model_param, flow_id, runtime_idx, host_party_list=None,
                             fast_sbt=False, tree_type=None, target_host_id=None):
    """Rebuild a hetero decision-tree learner from a saved model.

    The learner class is chosen from ``role`` and ``fast_sbt``. The model
    is then loaded into it, and the flow id, runtime index and the
    role-specific settings are applied.

    Args:
        role: Compared against ``consts.HOST`` and ``consts.GUEST``.
        tree_param: Passed to the learner's constructor.
        model_meta, model_param: Passed to the learner's ``load_model``.
        flow_id: Passed to ``set_flowid``.
        runtime_idx: Passed to ``set_runtime_idx``; for a fast host learner
            it is also passed to ``set_self_host_id``.
        host_party_list: Passed to ``set_host_party_idlist`` on the guest
            learner; unused for the host.
        fast_sbt: When truthy, the fast tree classes are used.
        tree_type, target_host_id: Passed to ``set_tree_work_mode`` on the
            fast learners; unused otherwise.

    Returns:
        The restored learner.

    Raises:
        ValueError: If ``role`` matches neither host nor guest.
    """
    # Resolve the concrete class first; unknown roles fail before any
    # learner is constructed.
    if role == consts.HOST:
        is_host = True
        learner_cls = HeteroFastDecisionTreeHost if fast_sbt else HeteroDecisionTreeHost
    elif role == consts.GUEST:
        is_host = False
        learner_cls = HeteroFastDecisionTreeGuest if fast_sbt else HeteroDecisionTreeGuest
    else:
        raise ValueError('unknown role: {}'.format(role))

    # Restore steps common to both roles.
    tree = learner_cls(tree_param)
    tree.load_model(model_meta, model_param)
    tree.set_flowid(flow_id)
    tree.set_runtime_idx(runtime_idx)

    # The guest records the host party list whether or not fast SBT is used.
    if not is_host:
        tree.set_host_party_idlist(host_party_list)

    if fast_sbt:
        tree.set_tree_work_mode(tree_type, target_host_id)
        if is_host:
            tree.set_self_host_id(runtime_idx)

    return tree
|