123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- from torch import nn
- import importlib
- from federatedml.nn.backend.torch.base import FateTorchLayer, FateTorchLoss
- from federatedml.nn.backend.utils.common import ML_PATH
- import difflib
# Package prefixes that get_class imports user modules from:
# custom models live under <ML_PATH>.model_zoo, custom losses under <ML_PATH>.loss.
MODEL_PATH = f'{ML_PATH}.model_zoo'
LOSS_PATH = f'{ML_PATH}.loss'
def str_simi(str_a, str_b):
    """Return a quick similarity score between two strings.

    The score is ``difflib.SequenceMatcher.quick_ratio()``: a float in
    [0, 1] that is an upper bound on the full ``ratio()``. Higher means
    more similar. It is used only to rank "did you mean" suggestions.
    """
    matcher = difflib.SequenceMatcher(isjunk=None, a=str_a, b=str_b)
    return matcher.quick_ratio()
def get_class(module_name, class_name, param, base_path):
    """Import ``<base_path>.<module_name>`` and instantiate a named nn.Module subclass.

    Args:
        module_name: Module to import, relative to ``base_path``. A trailing
            ``.py`` suffix is accepted and stripped.
        class_name: Name of the ``nn.Module`` subclass to instantiate.
        param: Keyword arguments for the class constructor; ``None`` is
            treated as "no arguments".
        base_path: Dotted package path that ``module_name`` is resolved against.

    Returns:
        A new instance of the matching class, built with ``**param``.

    Raises:
        ValueError: If the module namespace contains no ``nn.Module`` subclass
            with that name. When other subclasses are present, the message
            suggests the one whose name is most similar to ``class_name``.
    """
    # Strip only a trailing ``.py``: str.replace would also mangle dotted
    # module paths such as ``vision.pyramid.py`` (-> ``visionramid``).
    if module_name.endswith('.py'):
        module_name = module_name[:-len('.py')]
    nn_modules = importlib.import_module('{}.{}'.format(base_path, module_name))

    # Every nn.Module subclass visible in the module namespace (including
    # classes the module merely imports); nn.Module itself is excluded.
    candidates = [
        obj for obj in vars(nn_modules).values()
        if isinstance(obj, type) and issubclass(obj, nn.Module) and obj is not nn.Module
    ]

    for cls in candidates:
        if cls.__name__ == class_name:
            return cls(**(param or {}))

    if candidates:
        # Similarity is only needed for the error message. max() returns the
        # first of equally-similar names, like a stable descending sort.
        best = max(candidates, key=lambda c: str_simi(class_name, c.__name__))
        raise ValueError(
            'Did not find any class in {}.py that is subclass of nn.Module and named {}. Do you mean {}?'.format(
                module_name, class_name, best.__name__))
    raise ValueError(
        'Did not find any class in {}.py that is subclass of nn.Module and named {}'.format(
            module_name, class_name))
class CustModel(FateTorchLayer, nn.Module):
    """Lazily-built wrapper around a user-defined model class.

    The constructor only records which class to build and with which keyword
    arguments. Call ``init_model()`` to instantiate it before ``forward``.
    """

    def __init__(self, module_name, class_name, **kwargs):
        """Record the model specification.

        Args:
            module_name: Module under ``MODEL_PATH`` (the model zoo) holding
                the class; a trailing ``.py`` is allowed.
            class_name: Name of the ``nn.Module`` subclass in that module.
            **kwargs: Keyword arguments later passed to that class's constructor.

        Raises:
            AssertionError: If ``module_name`` or ``class_name`` is not a str
                (checked with ``assert``, so skipped under ``python -O``).
        """
        super(CustModel, self).__init__()
        assert isinstance(
            module_name, str), 'name must be a str, specify the module in the model_zoo'
        assert isinstance(
            class_name, str), 'class name must be a str, specify the class in the module'
        self.param_dict = {
            'module_name': module_name,
            'class_name': class_name,
            'param': kwargs}
        self._model = None

    def init_model(self):
        """Build the wrapped model on first call; subsequent calls are no-ops."""
        if self._model is None:
            self._model = self.get_pytorch_model()

    def forward(self, x, *args, **kwargs):
        """Run the wrapped model.

        Any extra positional or keyword arguments are passed through, so
        models whose ``forward`` takes more than one input can be wrapped.

        Raises:
            ValueError: If ``init_model()`` has not been called yet.
        """
        if self._model is None:
            raise ValueError('model not init, call init_model() function')
        return self._model(x, *args, **kwargs)

    def get_pytorch_model(self, module_path=None):
        """Instantiate a fresh model from the recorded specification.

        Args:
            module_path: Package to resolve ``module_name`` against; defaults
                to ``MODEL_PATH`` when None.
        """
        base_path = MODEL_PATH if module_path is None else module_path
        return get_class(
            self.param_dict['module_name'],
            self.param_dict['class_name'],
            self.param_dict['param'],
            base_path)

    def __repr__(self):
        return 'CustModel({})'.format(str(self.param_dict))
class CustLoss(FateTorchLoss, nn.Module):
    """Lazily-built wrapper around a user-defined loss class.

    The constructor only records which loss class to build and with which
    keyword arguments. Call ``init_loss_fn()`` to instantiate it before ``forward``.
    """

    def __init__(self, loss_module_name, class_name, **kwargs):
        """Record the loss specification.

        Args:
            loss_module_name: Module under ``LOSS_PATH`` holding the loss
                class; a trailing ``.py`` is allowed.
            class_name: Name of the ``nn.Module`` subclass in that module.
            **kwargs: Keyword arguments later passed to that class's constructor.

        Raises:
            AssertionError: If ``loss_module_name`` or ``class_name`` is not a
                str (checked with ``assert``, so skipped under ``python -O``).
        """
        super(CustLoss, self).__init__()
        # Loss modules are resolved under LOSS_PATH (<ML_PATH>.loss), not the
        # model zoo, so the message points users at the loss package.
        assert isinstance(
            loss_module_name, str), 'loss module name must be a str, specify the module in the loss package'
        assert isinstance(
            class_name, str), 'class name must be a str, specify the class in the module'
        self.param_dict = {
            'loss_module_name': loss_module_name,
            'class_name': class_name,
            'param': kwargs}
        self._loss_fn = None

    def init_loss_fn(self):
        """Build the wrapped loss on first call; subsequent calls are no-ops."""
        if self._loss_fn is None:
            self._loss_fn = self.get_pytorch_model()

    def forward(self, pred, label):
        """Compute the wrapped loss for ``pred`` against ``label``.

        Raises:
            ValueError: If ``init_loss_fn()`` has not been called yet.
        """
        if self._loss_fn is None:
            raise ValueError('loss not init, call init_loss_fn() function')
        return self._loss_fn(pred, label)

    def get_pytorch_model(self, module_path=None):
        """Instantiate a fresh loss object from the recorded specification.

        Args:
            module_path: Package to resolve ``loss_module_name`` against;
                defaults to ``LOSS_PATH`` when None.
        """
        base_path = LOSS_PATH if module_path is None else module_path
        return get_class(
            module_name=self.param_dict['loss_module_name'],
            class_name=self.param_dict['class_name'],
            param=self.param_dict['param'],
            base_path=base_path)

    def __repr__(self):
        return 'CustLoss({})'.format(str(self.param_dict))
|