homo_model_convert.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. #
  2. # Copyright 2021 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import importlib
  17. import inspect
  18. import os
  19. from federatedml.util import LOGGER
  20. from .component_converter import ComponentConverterBase
  21. SKLEARN_FILENAME = "sklearn.joblib"
  22. PYTORCH_FILENAME = "pytorch.pth"
  23. TF_DIRNAME = "tensorflow_saved_model"
  24. LGB_FILENAME = "lgb.txt"
  25. def _get_component_converter(module_name: str,
  26. framework_name: str):
  27. if framework_name in ["tensorflow", "tf", "tf_keras"]:
  28. framework_name = "tf_keras"
  29. elif framework_name in ["pytorch", "torch"]:
  30. framework_name = "pytorch"
  31. elif framework_name in ["sklearn", "scikit-learn"]:
  32. framework_name = "sklearn"
  33. elif framework_name in ['lightgbm']:
  34. framework_name = 'lightgbm'
  35. package_name = "." + framework_name
  36. parent_package = importlib.import_module(package_name, __package__)
  37. parent_package_path = os.path.dirname(os.path.realpath(parent_package.__file__))
  38. for f in os.listdir(parent_package_path):
  39. if f.startswith('.') or f.startswith('_'):
  40. continue
  41. if not f.endswith('.py'):
  42. continue
  43. proto_module = importlib.import_module("." + f.rstrip('.py'), parent_package.__name__)
  44. for name, obj in inspect.getmembers(proto_module):
  45. if inspect.isclass(obj) and issubclass(obj, ComponentConverterBase):
  46. for module in obj.get_target_modules():
  47. if module.lower() == module_name.lower():
  48. return framework_name, obj()
  49. return None, None
  50. def get_default_target_framework(model_contents: dict,
  51. module_name: str):
  52. """
  53. Returns the name of a supported ML framework based on the
  54. original FATE model module name and model contents.
  55. :param model_contents: the model content of the FATE model
  56. :param module_name: The module name, typically as HomoXXXX.
  57. :return: the corresponding framework name that this model can be converted to.
  58. """
  59. framework_name = None
  60. if module_name == "HomoLR":
  61. framework_name = "sklearn"
  62. elif module_name == 'HomoNN':
  63. # in FATE-1.10 currently support pytorch only
  64. framework_name = "pytorch"
  65. # if model_contents['HomoNNModelMeta'].params.config_type == "pytorch":
  66. # framework_name = "pytorch"
  67. # else:
  68. # framework_name = "tf_keras"
  69. elif module_name.lower() == 'homosecureboost':
  70. framework_name = 'lightgbm'
  71. else:
  72. LOGGER.debug(
  73. f"Module {module_name} is not a supported homogeneous model")
  74. return framework_name
  75. def model_convert(model_contents: dict,
  76. module_name: str,
  77. framework_name=None):
  78. """Convert a Homo model component into format of a common ML framework
  79. :param model_contents: The model dict un-serialized from the model protobuf.
  80. :param module_name: The module name, typically as HomoXXXX.
  81. :param framework_name: The wanted framework, e.g. "sklearn", "pytorch", etc.
  82. If not specified, the target framework will be chosen
  83. automatically.
  84. :return: the converted framework name and a instance of the model object from
  85. the specified framework.
  86. """
  87. if not framework_name:
  88. framework_name = get_default_target_framework(
  89. model_contents, module_name)
  90. if not framework_name:
  91. return None, None
  92. target_framework, component_converter = _get_component_converter(
  93. module_name, framework_name)
  94. if not component_converter:
  95. LOGGER.warn(
  96. f"Module {module_name} cannot be converted to framework {framework_name}")
  97. return None, None
  98. LOGGER.info(
  99. f"Converting {module_name} module to a model of framework {target_framework}")
  100. return target_framework, component_converter.convert(model_contents)
  101. def _get_model_saver_loader(framework_name: str):
  102. if framework_name in ["sklearn", "scikit-learn"]:
  103. import joblib
  104. return joblib.dump, joblib.load, SKLEARN_FILENAME
  105. elif framework_name in ["pytorch", "torch"]:
  106. import torch
  107. return torch.save, torch.load, PYTORCH_FILENAME
  108. elif framework_name in ["tensorflow", "tf", "tf_keras"]:
  109. import tensorflow
  110. return tensorflow.saved_model.save, tensorflow.saved_model.load, TF_DIRNAME
  111. elif framework_name in ['lightgbm']:
  112. from federatedml.protobuf.homo_model_convert.lightgbm.gbdt import save_lgb, load_lgb
  113. return save_lgb, load_lgb, LGB_FILENAME
  114. else:
  115. raise NotImplementedError("save method for framework: {} is not implemented"
  116. .format(framework_name))
  117. def save_converted_model(model_object,
  118. framework_name: str,
  119. base_dir: str):
  120. """Save the model into target destination
  121. :param model_object: the model object
  122. :param framework_name: name of the framework of the model
  123. :param base_dir: the base directory to save the model file
  124. :return: local file/folder path
  125. """
  126. save, _, dest_filename = _get_model_saver_loader(framework_name)
  127. dest = os.path.join(base_dir, dest_filename)
  128. save(model_object, dest)
  129. LOGGER.info(f"Saved {framework_name} model to {dest}")
  130. return dest
  131. def load_converted_model(framework_name: str,
  132. base_dir: str):
  133. """Load a model from the specified directory previously used to save the converted model
  134. :param framework_name: name of the framework of the model
  135. :param base_dir: the base directory to save the model file
  136. :return: model object of the specified framework
  137. """
  138. _, load, src_filename = _get_model_saver_loader(framework_name)
  139. src = os.path.join(base_dir, src_filename)
  140. if not os.path.exists(src):
  141. raise FileNotFoundError(
  142. "expected file or folder {} doesn't exist".format(src))
  143. return load(src)