merge_sbt.py 3.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import json
  2. import numpy as np
  3. import lightgbm as lgb
  4. from sklearn.pipeline import Pipeline
  5. from lightgbm.sklearn import _LGBMLabelEncoder
  6. from federatedml.protobuf.homo_model_convert.lightgbm.gbdt import sbt_to_lgb
  7. from federatedml.protobuf.generated.boosting_tree_model_param_pb2 import BoostingTreeModelParam
  8. from federatedml.protobuf.generated.boosting_tree_model_meta_pb2 import BoostingTreeModelMeta
  9. from google.protobuf import json_format
  10. from federatedml.util.anonymous_generator_util import Anonymous
  11. def _merge_sbt(guest_param, host_param, host_sitename, rename_host=True):
  12. # update feature name fid mapping
  13. guest_fid_map = guest_param['featureNameFidMapping']
  14. guest_fid_map = {int(k): v for k, v in guest_fid_map.items()}
  15. host_fid_map = sorted([(int(k), v) for k, v in host_param['featureNameFidMapping'].items()], key=lambda x: x[0])
  16. guest_feat_len = len(guest_fid_map)
  17. start = guest_feat_len
  18. host_new_fid = {}
  19. for k, v in host_fid_map:
  20. guest_fid_map[start] = v if not rename_host else v + '_' + host_sitename
  21. host_new_fid[k] = start
  22. start += 1
  23. guest_param['featureNameFidMapping'] = guest_fid_map
  24. # merging trees
  25. for tree_guest, tree_host in zip(guest_param['trees'], host_param['trees']):
  26. tree_guest['splitMaskdict'].update(tree_host['splitMaskdict'])
  27. tree_guest['missingDirMaskdict'].update(tree_host['missingDirMaskdict'])
  28. for node_g, node_h in zip(tree_guest['tree'], tree_host['tree']):
  29. if str(node_h['id']) in tree_host['splitMaskdict']:
  30. node_g['fid'] = int(host_new_fid[int(node_h['fid'])])
  31. node_g['sitename'] = host_sitename
  32. node_g['bid'] = 0
  33. return guest_param
  34. def extract_host_name(host_param, idx):
  35. try:
  36. anonymous_obj = Anonymous()
  37. anonymous_dict = host_param['anonymousNameMapping']
  38. role, party_id = None, None
  39. for key in anonymous_dict:
  40. role = anonymous_obj.get_role_from_anonymous_column(key)
  41. party_id = anonymous_obj.get_party_id_from_anonymous_column(key)
  42. break
  43. if role is not None and party_id is not None:
  44. return role + '_' + party_id
  45. else:
  46. return None
  47. except Exception as e:
  48. return 'host_{}'.format(idx)
  49. def merge_sbt(guest_param: dict, guest_meta: dict, host_params: list, host_metas: list, output_format: str,
  50. target_name='y', host_rename=True):
  51. result_param = None
  52. for idx, host_param in enumerate(host_params):
  53. host_name = extract_host_name(host_param, idx)
  54. if result_param is None:
  55. result_param = _merge_sbt(guest_param, host_param, host_name, host_rename)
  56. else:
  57. result_param = _merge_sbt(result_param, host_param, host_name, host_rename)
  58. pb_param = json_format.Parse(json.dumps(result_param), BoostingTreeModelParam())
  59. pb_meta = json_format.Parse(json.dumps(guest_meta), BoostingTreeModelMeta())
  60. lgb_model = sbt_to_lgb(pb_param, pb_meta, False)
  61. if output_format in ['lgb', 'lightgbm']:
  62. return lgb_model
  63. elif output_format in ['pmml']:
  64. classes = list(map(int, pb_param.classes_))
  65. bst = lgb.Booster(model_str=lgb_model)
  66. new_clf = lgb.LGBMRegressor() if guest_meta['taskType'] == 'regression' else lgb.LGBMClassifier()
  67. new_clf._Booster = bst
  68. new_clf._n_features = len(bst.feature_name())
  69. new_clf._n_classes = len(np.unique(classes))
  70. new_clf._le = _LGBMLabelEncoder().fit(np.array(classes))
  71. new_clf.fitted_ = True
  72. new_clf._classes = new_clf._le.classes_
  73. test_pipeline = Pipeline([("lgb", new_clf)])
  74. return test_pipeline
  75. else:
  76. raise ValueError('unknown output type {}'.format(output_format))