merge_hetero_models.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import copy
  2. import tempfile
  3. import json
  4. import pickle
  5. import base64
  6. from federatedml.protobuf.model_merge.merge_sbt import merge_sbt
  7. from federatedml.protobuf.model_merge.merge_hetero_lr import merge_lr
  8. from nyoka import lgb_to_pmml
  9. from sklearn2pmml import sklearn2pmml
  10. def get_pmml_str(pmml_pipeline, target_name):
  11. tmp_f = tempfile.NamedTemporaryFile()
  12. path = tmp_f.name
  13. lgb_to_pmml(pmml_pipeline, pmml_pipeline['lgb'].feature_name_, target_name, path)
  14. with open(path, 'r') as read_f:
  15. str_ = read_f.read()
  16. tmp_f.close()
  17. return str_
  18. def output_sklearn_pmml_str(pmml_pipeline, ):
  19. tmp_f = tempfile.NamedTemporaryFile()
  20. path = tmp_f.name
  21. sklearn2pmml(pmml_pipeline, path, with_repr=True)
  22. with open(path, 'r') as read_f:
  23. str_ = read_f.read()
  24. tmp_f.close()
  25. return str_
  26. def hetero_model_merge(guest_param: dict, guest_meta: dict, host_params: list, host_metas: list, model_type: str,
  27. output_format: str, target_name: str = 'y', host_rename=False, include_guest_coef=False):
  28. """
  29. Merge a hetero model
  30. :param guest_param: a json dict contains guest model param
  31. :param guest_meta: a json dict contains guest model meta
  32. :param host_params: a list contains json dicts of host params
  33. :param host_metas: a list contains json dicts of host metas
  34. :param model_type: specify the model type:
  35. secureboost, alias tree, sbt
  36. logistic_regression, alias LR
  37. :param output_format: output format of merged model, support:
  38. lightgbm, for tree models only
  39. sklearn, for linear models only
  40. pmml, for all types
  41. :param target_name: if output format is pmml, need to specify the targe(label) name
  42. :param host_rename: add suffix to secureboost host features
  43. :param include_guest_coef: default False
  44. :return: Merged Model Class
  45. """
  46. guest_param = copy.deepcopy(guest_param)
  47. guest_meta = copy.deepcopy(guest_meta)
  48. host_params = copy.deepcopy(host_params)
  49. host_metas = copy.deepcopy(host_metas)
  50. if not isinstance(model_type, str):
  51. raise ValueError('model type should be a str, but got {}'.format(model_type))
  52. if output_format.lower() not in {'lightgbm', 'lgb', 'sklearn', 'pmml'}:
  53. raise ValueError('unknown output format: {}'.format(output_format))
  54. if model_type.lower() in ['secureboost', 'tree', 'sbt']:
  55. model = merge_sbt(guest_param, guest_meta, host_params, host_metas, output_format, target_name,
  56. host_rename=host_rename)
  57. if output_format == 'pmml':
  58. return get_pmml_str(model, target_name)
  59. else:
  60. return model
  61. elif model_type.lower() in {'logistic_regression', 'lr'}:
  62. model = merge_lr(guest_param, guest_meta, host_params, host_metas, output_format, include_guest_coef)
  63. if output_format == 'pmml':
  64. return output_sklearn_pmml_str(model)
  65. else:
  66. return json.dumps(str(base64.b64encode(pickle.dumps(model)), "utf-8"))
  67. else:
  68. raise ValueError('model type should be one in ["sbt", "lr"], '
  69. 'but got unknown model type: {}'.format(model_type))