tree_model_converter.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. from typing import Dict
  18. from federatedml.util import consts
  19. from federatedml.protobuf.generated.boosting_tree_model_meta_pb2 import BoostingTreeModelMeta
  20. from federatedml.protobuf.generated.boosting_tree_model_param_pb2 import BoostingTreeModelParam
  21. from federatedml.protobuf.model_migrate.converter.converter_base import AutoReplace
  22. from federatedml.protobuf.model_migrate.converter.converter_base import ProtoConverterBase
  23. class HeteroSBTConverter(ProtoConverterBase):
  24. def convert(self, param: BoostingTreeModelParam, meta: BoostingTreeModelMeta,
  25. guest_id_mapping: Dict,
  26. host_id_mapping: Dict,
  27. arbiter_id_mapping: Dict,
  28. tree_plan_delimiter='_'
  29. ):
  30. feat_importance_list = list(param.feature_importances)
  31. fid_feature_mapping = dict(param.feature_name_fid_mapping)
  32. feature_fid_mapping = {v: k for k, v in fid_feature_mapping.items()}
  33. tree_list = list(param.trees_)
  34. tree_plan = list(param.tree_plan)
  35. replacer = AutoReplace(guest_id_mapping, host_id_mapping, arbiter_id_mapping)
  36. # fp == feature importance
  37. for fp in feat_importance_list:
  38. fp.sitename = replacer.replace(fp.sitename)
  39. if fp.fullname not in feature_fid_mapping:
  40. fp.fullname = replacer.migrate_anonymous_header(fp.fullname)
  41. for tree in tree_list:
  42. tree_nodes = list(tree.tree_)
  43. for node in tree_nodes:
  44. node.sitename = replacer.replace(node.sitename)
  45. new_tree_plan = []
  46. for str_tuple in tree_plan:
  47. param.tree_plan.remove(str_tuple)
  48. tree_mode, party_id = str_tuple.split(tree_plan_delimiter)
  49. if int(party_id) != -1:
  50. new_party_id = replacer.plain_replace(party_id, role=consts.HOST)
  51. else:
  52. new_party_id = party_id
  53. new_tree_plan.append(tree_mode + tree_plan_delimiter + new_party_id)
  54. param.tree_plan.extend(new_tree_plan)
  55. return param, meta