123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167 |
- #
- # Copyright 2021 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import importlib
- import inspect
- import os
- from federatedml.util import LOGGER
- from .component_converter import ComponentConverterBase
# File / folder names (relative to a base directory) under which a converted
# model of each framework is written and later read back.
SKLEARN_FILENAME = "sklearn.joblib"         # scikit-learn, via joblib
PYTORCH_FILENAME = "pytorch.pth"            # pytorch
TF_DIRNAME = "tensorflow_saved_model"       # tensorflow SavedModel (a folder)
LGB_FILENAME = "lgb.txt"                    # lightgbm
def _get_component_converter(module_name: str,
                             framework_name: str):
    """Find a converter that turns ``module_name`` into a ``framework_name`` model.

    ``framework_name`` is first normalised from its aliases to the name of a
    sub-package of this package ("tf_keras", "pytorch", "sklearn" or
    "lightgbm"); any other value is used unchanged. Every ``*.py`` file in that
    sub-package whose name does not start with "." or "_" is imported, in
    sorted order, and scanned for subclasses of ``ComponentConverterBase``.
    The first class whose ``get_target_modules()`` contains ``module_name``
    (compared case-insensitively) is instantiated and returned.

    :param module_name: The FATE module name, typically as HomoXXXX.
    :param framework_name: The wanted framework or an alias of it,
                           e.g. "tf", "torch", "scikit-learn".
    :return: a ``(normalised_framework_name, converter_instance)`` tuple, or
             ``(None, None)`` if no converter supports the module.
    """
    if framework_name in ["tensorflow", "tf", "tf_keras"]:
        framework_name = "tf_keras"
    elif framework_name in ["pytorch", "torch"]:
        framework_name = "pytorch"
    elif framework_name in ["sklearn", "scikit-learn"]:
        framework_name = "sklearn"
    elif framework_name in ["lightgbm"]:
        framework_name = "lightgbm"
    # An unrecognised name is imported as-is; importlib raises
    # ModuleNotFoundError if there is no sub-package with that name.
    parent_package = importlib.import_module("." + framework_name, __package__)
    parent_package_path = os.path.dirname(os.path.realpath(parent_package.__file__))

    target_module = module_name.lower()
    # os.listdir order is arbitrary; sort so the winning converter is stable
    # when several of them declare the same target module.
    for file_name in sorted(os.listdir(parent_package_path)):
        if file_name.startswith((".", "_")):
            continue
        # Use splitext rather than rstrip('.py'): rstrip removes any trailing
        # '.', 'p' or 'y' characters, e.g. "happy.py" -> "ha".
        stem, ext = os.path.splitext(file_name)
        if ext != ".py":
            continue
        proto_module = importlib.import_module("." + stem, parent_package.__name__)
        for _, obj in inspect.getmembers(proto_module, inspect.isclass):
            if not issubclass(obj, ComponentConverterBase):
                continue
            if any(module.lower() == target_module
                   for module in obj.get_target_modules()):
                return framework_name, obj()
    return None, None
def get_default_target_framework(model_contents: dict,
                                 module_name: str):
    """
    Returns the name of a supported ML framework based on the
    original FATE model module name and model contents.

    The module name is matched case-insensitively, consistent with how
    component converters are looked up by module name.

    :param model_contents: the model content of the FATE model. Currently
                           unused: HomoNN is always mapped to pytorch, so no
                           pytorch / tf_keras choice is made from its meta.
    :param module_name: The module name, typically as HomoXXXX.
    :return: the corresponding framework name that this model can be converted
             to, or None if the module is not a supported homogeneous model.
    """
    default_frameworks = {
        "homolr": "sklearn",
        # In FATE-1.10 only pytorch is supported as the HomoNN target.
        "homonn": "pytorch",
        "homosecureboost": "lightgbm",
    }
    framework_name = default_frameworks.get(module_name.lower() if module_name else "")
    if framework_name is None:
        LOGGER.debug(
            f"Module {module_name} is not a supported homogeneous model")
    return framework_name
def model_convert(model_contents: dict,
                  module_name: str,
                  framework_name=None):
    """Convert a Homo model component into the format of a common ML framework.

    :param model_contents: The model dict un-serialized from the model protobuf.
    :param module_name: The module name, typically as HomoXXXX.
    :param framework_name: The wanted framework, e.g. "sklearn", "pytorch", etc.
                           If not specified, the target framework is chosen
                           automatically by get_default_target_framework.
    :return: a ``(framework_name, model)`` tuple with the normalised framework
             name and an instance of the model object from that framework, or
             ``(None, None)`` if no target framework or converter is available.
    """
    if not framework_name:
        framework_name = get_default_target_framework(model_contents, module_name)
        if not framework_name:
            return None, None

    target_framework, component_converter = _get_component_converter(
        module_name, framework_name)
    # The lookup returns (None, None) when nothing matches; test for None
    # explicitly rather than relying on the converter object's truthiness.
    if component_converter is None:
        # Assumes LOGGER is a stdlib logging.Logger, where warn() is only a
        # deprecated alias of warning().
        LOGGER.warning(
            f"Module {module_name} cannot be converted to framework {framework_name}")
        return None, None

    LOGGER.info(
        f"Converting {module_name} module to a model of framework {target_framework}")
    return target_framework, component_converter.convert(model_contents)
- def _get_model_saver_loader(framework_name: str):
- if framework_name in ["sklearn", "scikit-learn"]:
- import joblib
- return joblib.dump, joblib.load, SKLEARN_FILENAME
- elif framework_name in ["pytorch", "torch"]:
- import torch
- return torch.save, torch.load, PYTORCH_FILENAME
- elif framework_name in ["tensorflow", "tf", "tf_keras"]:
- import tensorflow
- return tensorflow.saved_model.save, tensorflow.saved_model.load, TF_DIRNAME
- elif framework_name in ['lightgbm']:
- from federatedml.protobuf.homo_model_convert.lightgbm.gbdt import save_lgb, load_lgb
- return save_lgb, load_lgb, LGB_FILENAME
- else:
- raise NotImplementedError("save method for framework: {} is not implemented"
- .format(framework_name))
def save_converted_model(model_object,
                         framework_name: str,
                         base_dir: str):
    """Write a converted model into ``base_dir``.

    :param model_object: the model object produced by the conversion
    :param framework_name: name (or alias) of the model's framework; selects
                           both the save function and the file/folder name
    :param base_dir: directory under which the model file/folder is written
                     (this function does not create it)
    :return: path of the written file/folder
    :raises NotImplementedError: if the framework is not supported
    """
    save_fn, _unused_load, file_name = _get_model_saver_loader(framework_name)
    target_path = os.path.join(base_dir, file_name)
    save_fn(model_object, target_path)
    LOGGER.info(f"Saved {framework_name} model to {target_path}")
    return target_path
def load_converted_model(framework_name: str,
                         base_dir: str):
    """Read back a model stored under ``base_dir`` by save_converted_model.

    :param framework_name: name (or alias) of the model's framework; selects
                           both the load function and the file/folder name
    :param base_dir: directory that holds the saved model file/folder
    :return: model object of the specified framework
    :raises NotImplementedError: if the framework is not supported
    :raises FileNotFoundError: if the expected file/folder is missing
    """
    _, load_fn, file_name = _get_model_saver_loader(framework_name)
    source_path = os.path.join(base_dir, file_name)
    if os.path.exists(source_path):
        return load_fn(source_path)
    raise FileNotFoundError(
        "expected file or folder {} doesn't exist".format(source_path))
|