cust_model.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import importlib
  2. from torch import nn
  3. from federatedml.nn.backend.torch.base import FateTorchLayer
  4. from federatedml.nn.backend.utils.common import ML_PATH
  5. PATH = '{}.model_zoo'.format(ML_PATH)
  6. class CustModel(FateTorchLayer, nn.Module):
  7. def __init__(self, module_name, class_name, **kwargs):
  8. super(CustModel, self).__init__()
  9. assert isinstance(
  10. module_name, str), 'name must be a str, specify the module in the model_zoo'
  11. assert isinstance(
  12. class_name, str), 'class name must be a str, specify the class in the module'
  13. self.param_dict = {
  14. 'module_name': module_name,
  15. 'class_name': class_name,
  16. 'param': kwargs}
  17. self._model = None
  18. def init_model(self):
  19. if self._model is None:
  20. self._model = self.get_pytorch_model()
  21. def forward(self, x):
  22. if self._model is None:
  23. raise ValueError('model not init, call init_model() function')
  24. return self._model(x)
  25. def get_pytorch_model(self):
  26. module_name: str = self.param_dict['module_name']
  27. class_name = self.param_dict['class_name']
  28. module_param: dict = self.param_dict['param']
  29. if module_name.endswith('.py'):
  30. module_name = module_name.replace('.py', '')
  31. nn_modules = importlib.import_module('{}.{}'.format(PATH, module_name))
  32. try:
  33. for k, v in nn_modules.__dict__.items():
  34. if isinstance(v, type):
  35. if issubclass(
  36. v, nn.Module) and v is not nn.Module and v.__name__ == class_name:
  37. return v(**module_param)
  38. raise ValueError(
  39. 'Did not find any class in {}.py that is pytorch nn.Module and named {}'. format(
  40. module_name, class_name))
  41. except ValueError as e:
  42. raise e
  43. def __repr__(self):
  44. return 'CustModel({})'.format(str(self.param_dict))