cust.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. from torch import nn
  2. import importlib
  3. from federatedml.nn.backend.torch.base import FateTorchLayer, FateTorchLoss
  4. from federatedml.nn.backend.utils.common import ML_PATH
  5. import difflib
  6. MODEL_PATH = '{}.model_zoo'.format(ML_PATH)
  7. LOSS_PATH = '{}.loss'.format(ML_PATH)
  8. def str_simi(str_a, str_b):
  9. return difflib.SequenceMatcher(None, str_a, str_b).quick_ratio()
  10. def get_class(module_name, class_name, param, base_path):
  11. if module_name.endswith('.py'):
  12. module_name = module_name.replace('.py', '')
  13. nn_modules = importlib.import_module(
  14. '{}.{}'.format(base_path, module_name))
  15. try:
  16. name_simi_list = []
  17. for k, v in nn_modules.__dict__.items():
  18. if isinstance(v, type):
  19. if issubclass(v, nn.Module) and v is not nn.Module:
  20. if v.__name__ == class_name:
  21. print(param)
  22. return v(**param)
  23. else:
  24. name_simi_list += ([(str_simi(class_name, v.__name__), v)])
  25. sort_by_simi = sorted(name_simi_list, key=lambda x: -x[0])
  26. if len(sort_by_simi) > 0:
  27. raise ValueError(
  28. 'Did not find any class in {}.py that is subclass of nn.Module and named {}. Do you mean {}?'. format(
  29. module_name, class_name, sort_by_simi[0][1].__name__))
  30. else:
  31. raise ValueError(
  32. 'Did not find any class in {}.py that is subclass of nn.Module and named {}'. format(
  33. module_name, class_name))
  34. except ValueError as e:
  35. raise e
  36. class CustModel(FateTorchLayer, nn.Module):
  37. def __init__(self, module_name, class_name, **kwargs):
  38. super(CustModel, self).__init__()
  39. assert isinstance(
  40. module_name, str), 'name must be a str, specify the module in the model_zoo'
  41. assert isinstance(
  42. class_name, str), 'class name must be a str, specify the class in the module'
  43. self.param_dict = {
  44. 'module_name': module_name,
  45. 'class_name': class_name,
  46. 'param': kwargs}
  47. self._model = None
  48. def init_model(self):
  49. if self._model is None:
  50. self._model = self.get_pytorch_model()
  51. def forward(self, x):
  52. if self._model is None:
  53. raise ValueError('model not init, call init_model() function')
  54. return self._model(x)
  55. def get_pytorch_model(self, module_path=None):
  56. if module_path is None:
  57. return get_class(
  58. self.param_dict['module_name'],
  59. self.param_dict['class_name'],
  60. self.param_dict['param'],
  61. MODEL_PATH)
  62. else:
  63. return get_class(
  64. self.param_dict['module_name'],
  65. self.param_dict['class_name'],
  66. self.param_dict['param'],
  67. module_path)
  68. def __repr__(self):
  69. return 'CustModel({})'.format(str(self.param_dict))
  70. class CustLoss(FateTorchLoss, nn.Module):
  71. def __init__(self, loss_module_name, class_name, **kwargs):
  72. super(CustLoss, self).__init__()
  73. assert isinstance(
  74. loss_module_name, str), 'loss module name must be a str, specify the module in the model_zoo'
  75. assert isinstance(
  76. class_name, str), 'class name must be a str, specify the class in the module'
  77. self.param_dict = {
  78. 'loss_module_name': loss_module_name,
  79. 'class_name': class_name,
  80. 'param': kwargs}
  81. self._loss_fn = None
  82. def init_loss_fn(self):
  83. if self._loss_fn is None:
  84. self._loss_fn = self.get_pytorch_model()
  85. def forward(self, pred, label):
  86. if self._loss_fn is None:
  87. raise ValueError('loss not init, call init_loss_fn() function')
  88. return self._loss_fn(pred, label)
  89. def get_pytorch_model(self, module_path=None):
  90. module_name: str = self.param_dict['loss_module_name']
  91. class_name: str = self.param_dict['class_name']
  92. module_param: dict = self.param_dict['param']
  93. if module_path is None:
  94. return get_class(
  95. module_name=module_name,
  96. class_name=class_name,
  97. param=module_param,
  98. base_path=LOSS_PATH)
  99. else:
  100. return get_class(
  101. module_name=module_name,
  102. class_name=class_name,
  103. param=module_param,
  104. base_path=module_path)
  105. def __repr__(self):
  106. return 'CustLoss({})'.format(str(self.param_dict))