import importlib

from torch import nn

from federatedml.nn.backend.torch.base import FateTorchLayer
from federatedml.nn.backend.utils.common import ML_PATH

# custom modules are resolved relative to the model_zoo package
PATH = '{}.model_zoo'.format(ML_PATH)


class CustModel(FateTorchLayer, nn.Module):
    """Wrapper that lazily imports and instantiates a user-defined nn.Module from the model_zoo."""

    def __init__(self, module_name, class_name, **kwargs):
        super(CustModel, self).__init__()
        assert isinstance(
            module_name, str), 'module name must be a str, specify the module in the model_zoo'
        assert isinstance(
            class_name, str), 'class name must be a str, specify the class in the module'
        # keep the construction arguments so the wrapped model can be rebuilt later
        self.param_dict = {
            'module_name': module_name,
            'class_name': class_name,
            'param': kwargs}
        self._model = None

    def init_model(self):
        # instantiate the wrapped model on first use
        if self._model is None:
            self._model = self.get_pytorch_model()

    def forward(self, x):
        if self._model is None:
            raise ValueError('model not initialized, call init_model() first')
        return self._model(x)

    def get_pytorch_model(self):
        # import the target module from the model_zoo package and return an
        # instance of the requested nn.Module subclass, built from the stored kwargs
        module_name: str = self.param_dict['module_name']
        class_name = self.param_dict['class_name']
        module_param: dict = self.param_dict['param']
        if module_name.endswith('.py'):
            module_name = module_name.replace('.py', '')
        nn_modules = importlib.import_module('{}.{}'.format(PATH, module_name))
        for v in nn_modules.__dict__.values():
            if isinstance(v, type) and issubclass(v, nn.Module) \
                    and v is not nn.Module and v.__name__ == class_name:
                return v(**module_param)
        raise ValueError(
            'Did not find any class in {}.py that is a pytorch nn.Module and named {}'.format(
                module_name, class_name))

    def __repr__(self):
        return 'CustModel({})'.format(str(self.param_dict))
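
A minimal usage sketch follows. It assumes a hypothetical file my_model.py placed under the model_zoo package that defines an nn.Module subclass named MyNet accepting in_dim and out_dim keyword arguments; these names are illustrative and not part of the source above.

import torch

# Hypothetical example: model_zoo/my_model.py is assumed to define
#     class MyNet(nn.Module):
#         def __init__(self, in_dim, out_dim): ...
cust = CustModel(module_name='my_model', class_name='MyNet', in_dim=10, out_dim=2)
cust.init_model()             # imports my_model and instantiates MyNet(in_dim=10, out_dim=2)
y = cust(torch.randn(4, 10))  # forward() delegates to the wrapped MyNet instance
print(cust)                   # CustModel({'module_name': 'my_model', 'class_name': 'MyNet', ...})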