# test_tree_converter.py
  1. from federatedml.protobuf.model_migrate.converter.tree_model_converter import HeteroSBTConverter
  2. from federatedml.protobuf.generated.boosting_tree_model_param_pb2 import BoostingTreeModelParam, NodeParam, \
  3. DecisionTreeModelParam, FeatureImportanceInfo
  4. from federatedml.protobuf.generated.boosting_tree_model_meta_pb2 import BoostingTreeModelMeta
  5. from federatedml.protobuf.model_migrate.converter.tree_model_converter import HeteroSBTConverter
  6. from federatedml.protobuf.model_migrate.model_migrate import model_migration
  7. import copy
  8. host_old = [10000, 9999]
  9. host_new = [114, 514, ]
  10. guest_old = [10000]
  11. guest_new = [1919]
  12. param = BoostingTreeModelParam()
  13. fp0 = FeatureImportanceInfo()
  14. fp0.fullname = 'host_10000_0'
  15. fp0.sitename = 'host:10000'
  16. fp1 = FeatureImportanceInfo()
  17. fp1.sitename = 'host:9999'
  18. fp1.fullname = 'host_9999_1'
  19. fp2 = FeatureImportanceInfo(fullname='x0')
  20. fp2.sitename = 'guest:10000'
  21. feature_importance = [fp0, fp1, fp2]
  22. param.feature_importances.extend(feature_importance)
  23. tree_0 = DecisionTreeModelParam(tree_=[NodeParam(sitename='guest:10000'), NodeParam(sitename='guest:10000'),
  24. NodeParam(sitename='guest:10000')])
  25. tree_1 = DecisionTreeModelParam(tree_=[NodeParam(sitename='host:10000'), NodeParam(sitename='host:9999'),
  26. NodeParam(sitename='host:10000')])
  27. tree_2 = DecisionTreeModelParam(tree_=[NodeParam(sitename='host:9999'), NodeParam(sitename='guest:10000'),
  28. NodeParam(sitename='host:9999')])
  29. tree_3 = DecisionTreeModelParam()
  30. param.trees_.extend([tree_0, tree_1, tree_2, tree_3])
  31. rs = model_migration({'HelloParam': param, 'HelloMeta': {}}, 'HeteroSecureBoost', old_guest_list=guest_old,
  32. new_guest_list=guest_new, old_host_list=host_old, new_host_list=host_new, )