gbdt.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. import numpy as np
  2. import lightgbm as lgb
  3. from ..component_converter import ComponentConverterBase
  4. from federatedml.protobuf.generated.boosting_tree_model_meta_pb2 import BoostingTreeModelMeta
  5. from federatedml.protobuf.generated.boosting_tree_model_param_pb2 import BoostingTreeModelParam, \
  6. DecisionTreeModelParam, NodeParam
  7. from federatedml.util import consts
  8. from federatedml.util import LOGGER
"""
We only keep the necessary variable to make sure that lightgbm can run predict function on the converted model
"""

# Fake per-feature "feature_infos" entry: SBT does not keep per-feature value
# ranges, so every feature gets the placeholder range [0:1]. The trailing
# space acts as the separator when the string is repeated.
FAKE_FEATURE_INFO_STR = '[0:1] '

# Sentinel lines expected by the LightGBM text-model parser.
END_OF_TREE = 'end of trees'
END_OF_PARA = 'end of parameters'

# Blank-line separator between sections of the lgb model text file.
SPLIT = '\n\n'

# Header section of a LightGBM v3 text model; filled by parse_header().
HEADER_TEMPLATE = """tree
version=v3
num_class={}
num_tree_per_iteration={}
label_index={}
max_feature_idx={}
objective={}
feature_names={}
feature_infos={}
"""

# One "Tree=N" section per boosting tree; filled by parse_a_tree().
TREE_TEMPLATE = """Tree={}
num_leaves={}
num_cat={}
split_feature={}
threshold={}
decision_type={}
left_child={}
right_child={}
leaf_value={}
internal_value={}
shrinkage={}
"""

# Trailing "parameters:" section; filled by parse_parameter().
PARA_TEMPLATE = """parameters:
[boosting: gbdt]
[objective: {}]
[num_iterations: {}]
[learning_rate: {}]
[max_depth: {}]
[max_bin: {}]
[use_missing: {}]
[zero_as_missing: {}]
[num_class: {}]
[lambda_l1: {}]
[lambda_l2: {}]
[min_data_in_leaf: {}]
[min_gain_to_split: {}]
"""

# Task type -> objective string used in the model header (lgb format).
LGB_OBJECTIVE = {
    consts.BINARY: "binary sigmoid:1",
    consts.REGRESSION: "regression",
    consts.MULTY: 'multiclass num_class:{}'
}

# Task type -> objective string used in the parameters section.
PARA_OBJECTIVE = {
    consts.BINARY: "binary",
    consts.REGRESSION: "regression",
    consts.MULTY: 'multiclass'
}
  63. def get_decision_type(node: NodeParam, use_missing, zero_as_missing):
  64. # 00 0 0
  65. # Nan,0 or None default left or right? cat feature or not?
  66. default_type = 0 # 0000 None, default right, not cat feat
  67. if not use_missing:
  68. return default_type
  69. if node.missing_dir == -1:
  70. default_type = default_type | 2 # 0010
  71. if zero_as_missing:
  72. default_type = default_type | 4 # 0100 0
  73. else:
  74. default_type = default_type | 8 # 1000 np.Nan
  75. return default_type
  76. def get_lgb_objective(task_type, num_classes, ret_dict, need_multi_format=True):
  77. if task_type == consts.CLASSIFICATION:
  78. if num_classes == 1:
  79. objective = ret_dict[consts.BINARY]
  80. else:
  81. objective = ret_dict[consts.MULTY].format(num_classes) if need_multi_format else ret_dict[consts.MULTY]
  82. else:
  83. objective = ret_dict[consts.REGRESSION]
  84. return objective
  85. def list_to_str(l_):
  86. return str(l_).replace('[', '').replace(']', '').replace(',', '')
  87. def parse_header(param: BoostingTreeModelParam, meta: BoostingTreeModelMeta):
  88. # generated header of lgb str model file
  89. # binary/regression num class is 1 in lgb
  90. num_classes = len(param.classes_) if len(param.classes_) > 2 else 1
  91. objective = get_lgb_objective(meta.task_type, num_classes, LGB_OBJECTIVE, need_multi_format=True)
  92. num_tree_per_iteration = param.tree_dim
  93. label_index = 0 # by default
  94. max_feature_idx = len(param.feature_name_fid_mapping) - 1
  95. feature_names = ''
  96. for name in [param.feature_name_fid_mapping[i] for i in range(max_feature_idx + 1)]:
  97. if ' ' in name: # space is not allowed
  98. name = name.replace(' ', '-')
  99. feature_names += name + ' '
  100. feature_names = feature_names[:-1]
  101. feature_info = FAKE_FEATURE_INFO_STR * (max_feature_idx + 1) # need to make fake feature info
  102. feature_info = feature_info[:-1]
  103. result_str = HEADER_TEMPLATE.format(num_classes, num_tree_per_iteration, label_index, max_feature_idx,
  104. objective, feature_names, feature_info)
  105. return result_str
  106. def internal_count_computer(cur_id, tree_node, leaf_count, internal_count):
  107. if cur_id in leaf_count:
  108. return leaf_count[cur_id]
  109. left_count = internal_count_computer(tree_node[cur_id].left_nodeid, tree_node, leaf_count, internal_count)
  110. right_count = internal_count_computer(tree_node[cur_id].right_nodeid, tree_node, leaf_count, internal_count)
  111. internal_count[cur_id] = left_count + right_count
  112. return internal_count[cur_id]
  113. def compute_internal_count(tree_param: DecisionTreeModelParam):
  114. root = tree_param.tree_[0]
  115. internal_count = {}
  116. leaf_count = tree_param.leaf_count
  117. root_count = internal_count_computer(root.id, tree_param.tree_, leaf_count, internal_count)
  118. if root.id not in internal_count:
  119. internal_count[root_count] = root_count
  120. return internal_count
  121. def update_leaf_count(param):
  122. # in homo sbt, sometimes a leaf covers no sample, so need to add 1 to leaf count
  123. tmp = {}
  124. for i in param.leaf_count:
  125. tmp[i] = param.leaf_count[i]
  126. for i in tmp:
  127. if tmp[i] == 0:
  128. param.leaf_count[i] += 1
def parse_a_tree(
        param: DecisionTreeModelParam,
        tree_idx: int,
        use_missing=False,
        zero_as_missing=False,
        learning_rate=0.1,
        init_score=None):
    """Convert one SBT decision tree into a LightGBM "Tree=N" text section.

    Internal nodes keep their lgb index order; leaves are encoded as
    negative indices (-1, -2, ...) in the child arrays, as lgb expects.
    Leaf values are node.weight * learning_rate (plus init_score when
    given, used for the first tree of regression models).

    NOTE(review): for hetero models this mutates node.bid/node.missing_dir
    on the protobuf nodes in place while decoding the mask dicts.
    """
    split_feature = []
    split_threshold = []
    decision_type = []
    internal_weight = []
    leaf_weight = []
    left, right = [], []
    leaf_idx = -1  # lgb leaf ids count down from -1
    lgb_node_idx = 0  # lgb internal-node ids count up from 0
    sbt_lgb_node_map = {}  # sbt node id -> lgb internal-node index
    is_leaf = []
    leaf_count = []
    internal_count, internal_count_dict = [], {}
    # leaf_count is only populated for homo models; counts are optional
    has_count_info = len(param.leaf_count) != 0
    # compute internal count
    if has_count_info:
        update_leaf_count(param)
        internal_count_dict = compute_internal_count(param)  # get internal count from leaf count
    # mark leaf nodes and get sbt-lgb node mapping
    for node in param.tree_:
        is_leaf.append(node.is_leaf)
        if not node.is_leaf:
            sbt_lgb_node_map[node.id] = lgb_node_idx
            lgb_node_idx += 1
    for cur_idx, node in enumerate(param.tree_):
        if not node.is_leaf:
            split_feature.append(node.fid)
            # if is hetero model need to decode split point and missing dir
            if param.split_maskdict and param.missing_dir_maskdict is not None:
                node.bid = param.split_maskdict[node.id]
                node.missing_dir = param.missing_dir_maskdict[node.id]
            # extract split point and weight
            split_threshold.append(node.bid)
            internal_weight.append(node.weight)
            # add internal count
            if has_count_info:
                internal_count.append(internal_count_dict[node.id])
            if is_leaf[node.left_nodeid]:  # generate lgb leaf idx
                left.append(leaf_idx)
                if has_count_info:
                    leaf_count.append(param.leaf_count[node.left_nodeid])
                leaf_idx -= 1
            else:
                left.append(sbt_lgb_node_map[node.left_nodeid])
            if is_leaf[node.right_nodeid]:  # generate lgb leaf idx
                right.append(leaf_idx)
                if has_count_info:
                    leaf_count.append(param.leaf_count[node.right_nodeid])
                leaf_idx -= 1
            else:
                right.append(sbt_lgb_node_map[node.right_nodeid])
            # get lgb decision type
            decision_type.append(get_decision_type(node, use_missing, zero_as_missing))
        else:
            # regression model need to add init score
            if init_score is not None:
                score = node.weight * learning_rate + init_score
            else:
                # leaf value is node.weight * learning_rate in lgb
                score = node.weight * learning_rate
            leaf_weight.append(score)
    leaves_num = len(leaf_weight)
    num_cat = 0  # SBT has no categorical splits
    # to string
    result_str = TREE_TEMPLATE.format(tree_idx, leaves_num, num_cat, list_to_str(split_feature),
                                      list_to_str(split_threshold), list_to_str(decision_type),
                                      list_to_str(left), list_to_str(right), list_to_str(leaf_weight),
                                      list_to_str(internal_weight), learning_rate)
    # count lines are optional in the lgb format; emit only when available
    if len(internal_count) != 0:
        result_str += 'internal_count={}\n'.format(list_to_str(internal_count))
    if len(leaf_count) != 0:
        result_str += 'leaf_count={}\n'.format(list_to_str(leaf_count))
    return result_str
  208. def parse_feature_importance(param):
  209. feat_importance_str = "feature_importances:\n"
  210. mapping = param.feature_name_fid_mapping
  211. for impt in param.feature_importances:
  212. impt_val = impt.importance
  213. try:
  214. if impt.main == 'split':
  215. impt_val = int(impt_val)
  216. except BaseException:
  217. LOGGER.warning("old version protobuf contains no filed 'main'")
  218. feat_importance_str += '{}={}\n'.format(mapping[impt.fid], impt_val)
  219. return feat_importance_str
  220. def parse_parameter(param, meta):
  221. """
  222. we only keep parameters offered by SBT
  223. """
  224. tree_meta = meta.tree_meta
  225. num_classes = 1 if meta.task_type == consts.CLASSIFICATION and param.num_classes < 3 else param.num_classes
  226. objective = get_lgb_objective(meta.task_type, num_classes, PARA_OBJECTIVE, need_multi_format=False)
  227. rs = PARA_TEMPLATE.format(objective, meta.num_trees, meta.learning_rate, tree_meta.max_depth,
  228. meta.quantile_meta.bin_num, meta.tree_meta.use_missing + 0,
  229. meta.tree_meta.zero_as_missing + 0,
  230. num_classes, tree_meta.criterion_meta.criterion_param[0],
  231. tree_meta.criterion_meta.criterion_param[1],
  232. tree_meta.min_leaf_node,
  233. tree_meta.min_impurity_split
  234. )
  235. return rs
  236. def sbt_to_lgb(model_param: BoostingTreeModelParam,
  237. model_meta: BoostingTreeModelMeta,
  238. load_feature_importance=True):
  239. """
  240. Transform sbt model to lgb model
  241. """
  242. result = ''
  243. # parse header
  244. header_str = parse_header(model_param, model_meta)
  245. use_missing = model_meta.tree_meta.use_missing
  246. zero_as_missing = model_meta.tree_meta.zero_as_missing
  247. learning_rate = model_meta.learning_rate
  248. tree_str_list = []
  249. # parse tree
  250. for idx, param in enumerate(model_param.trees_):
  251. if idx == 0 and model_meta.task_type == consts.REGRESSION: # regression task has init score
  252. init_score = model_param.init_score[0]
  253. else:
  254. init_score = 0
  255. tree_str_list.append(parse_a_tree(param, idx, use_missing, zero_as_missing, learning_rate, init_score))
  256. # add header and tree str to result
  257. result += header_str + '\n'
  258. for s in tree_str_list:
  259. result += s
  260. result += SPLIT
  261. result += END_OF_TREE
  262. # handle feature importance
  263. if load_feature_importance:
  264. feat_importance_str = parse_feature_importance(model_param)
  265. result += SPLIT + feat_importance_str
  266. # parameters
  267. para_str = parse_parameter(model_param, model_meta)
  268. result += '\n' + para_str + '\n' + END_OF_PARA + '\n'
  269. result += '\npandas_categorical:[]\n'
  270. return result
  271. def save_lgb(model: lgb.Booster, path):
  272. model_str = model.model_to_string()
  273. f = open(path, 'w')
  274. f.write(model_str)
  275. f.close()
  276. def load_lgb(path):
  277. f = open(path, 'r')
  278. model_str = f.read()
  279. f.close()
  280. lgb_model = lgb.Booster(model_str=model_str)
  281. return lgb_model
class HomoSBTComponentConverter(ComponentConverterBase):
    """Converts a trained HomoSecureboost model into a native lgb.Booster."""

    @staticmethod
    def get_target_modules():
        # component names this converter is registered for
        return ['HomoSecureboost']

    def convert(self, model_dict):
        """Build an lgb.Booster from the guest param/meta protobufs in *model_dict*."""
        param_obj = model_dict["HomoSecureBoostingTreeGuestParam"]
        meta_obj = model_dict["HomoSecureBoostingTreeGuestMeta"]
        # serialize to lgb's text-model format, then let lightgbm parse it back
        lgb_model_str = sbt_to_lgb(param_obj, meta_obj)
        lgb_model = lgb.Booster(model_str=lgb_model_str)
        return lgb_model