base.py

import copy
import json

import torch as t
from torch.nn import Sequential as tSequential

from federatedml.nn.backend.torch.operation import OpBase

class FateTorchLayer(object):
    """
    Mixin for wrapped torch layers. Concrete wrappers also inherit from a
    torch.nn.Module subclass, which is why t.nn.Module.__init__ is called
    explicitly here: it ensures the module machinery exists before the
    extra attributes below are assigned.
    """

    def __init__(self):
        t.nn.Module.__init__(self)
        self.param_dict = dict()
        self.initializer = {'weight': None, 'bias': None}
        self.optimizer = None

    def to_dict(self):
        # serialize the layer: constructor params, class name, initializers
        ret_dict = copy.deepcopy(self.param_dict)
        ret_dict['layer'] = type(self).__name__
        ret_dict['initializer'] = {}
        if self.initializer['weight']:
            ret_dict['initializer']['weight'] = self.initializer['weight']
        if self.initializer['bias']:
            ret_dict['initializer']['bias'] = self.initializer['bias']
        return ret_dict

    def add_optimizer(self, opt):
        self.optimizer = opt
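

# --- Illustration (not part of the original file) ---
# A minimal sketch of how this mixin is meant to be combined with a real
# torch layer: the torch parent supplies forward(), while FateTorchLayer
# records the constructor arguments in param_dict so the layer can be
# serialized by to_dict(). `_ExampleLinear` is a hypothetical name used
# for illustration only; the real wrappers live elsewhere in this backend.
class _ExampleLinear(t.nn.Linear, FateTorchLayer):

    def __init__(self, in_features, out_features, bias=True):
        FateTorchLayer.__init__(self)
        self.param_dict.update(
            {'in_features': in_features, 'out_features': out_features, 'bias': bias})
        t.nn.Linear.__init__(self, in_features, out_features, bias=bias)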

class FateTorchLoss(object):

    def __init__(self):
        self.param_dict = {}

    def to_dict(self):
        ret_dict = copy.deepcopy(self.param_dict)
        ret_dict['loss_fn'] = type(self).__name__
        return ret_dict

class FateTorchOptimizer(object):

    def __init__(self):
        self.param_dict = dict()
        self.torch_class = None

    def to_dict(self):
        ret_dict = copy.deepcopy(self.param_dict)
        ret_dict['optimizer'] = type(self).__name__
        ret_dict['config_type'] = 'pytorch'
        return ret_dict

    def check_params(self, params):
        if isinstance(params, (FateTorchLayer, Sequential)):
            params.add_optimizer(self)
            params = params.parameters()

        l_param = list(params)
        if len(l_param) == 0:
            # return a placeholder parameter for the case where the model
            # exposes no trainable parameters (e.g. a custom model only)
            return [t.nn.Parameter(t.Tensor([0]))]
        return l_param

    def register_optimizer(self, input_):
        if input_ is None:
            return
        if isinstance(input_, (FateTorchLayer, Sequential)):
            input_.add_optimizer(self)

    def to_torch_instance(self, parameters):
        # build the real torch optimizer from the recorded hyper-parameters
        return self.torch_class(parameters, **self.param_dict)
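

# --- Illustration (not part of the original file) ---
# A hypothetical optimizer wrapper sketched along the same lines: it stores
# its hyper-parameters in param_dict and points torch_class at the real
# torch.optim class, so to_torch_instance() can construct the actual
# optimizer later. `_ExampleSGD` is an assumed name for illustration.
class _ExampleSGD(FateTorchOptimizer):

    def __init__(self, params=None, lr=0.01):
        FateTorchOptimizer.__init__(self)
        self.param_dict['lr'] = lr
        self.torch_class = t.optim.SGD
        # attach this optimizer to the layer / Sequential, if one is given
        self.register_optimizer(params)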

class Sequential(tSequential):

    def to_dict(self):
        """
        Get the structure of the current sequential as a dict,
        keyed by '<index>-<module name>'.
        """
        rs = {}
        for idx, k in enumerate(self._modules):
            ordered_name = str(idx) + '-' + k
            rs[ordered_name] = self._modules[k].to_dict()
        return rs

    def to_json(self):
        return json.dumps(self.to_dict(), indent=4)

    def add_optimizer(self, opt):
        setattr(self, 'optimizer', opt)

    def add(self, layer):
        if isinstance(layer, Sequential):
            self._modules = layer._modules
            # copy the optimizer along with the modules
            if hasattr(layer, 'optimizer'):
                setattr(self, 'optimizer', layer.optimizer)
        elif isinstance(layer, FateTorchLayer):
            self.add_module(str(len(self)), layer)
            # adopt the layer's optimizer if we do not have one yet
            if not hasattr(self, 'optimizer') and hasattr(layer, 'optimizer'):
                setattr(self, 'optimizer', layer.optimizer)
        else:
            raise ValueError(
                'unknown input layer type {}, this type is not supported'.format(
                    type(layer)))

    @staticmethod
    def get_loss_config(loss: FateTorchLoss):
        return loss.to_dict()

    def get_optimizer_config(self, optimizer=None):
        if hasattr(self, 'optimizer'):
            return self.optimizer.to_dict()
        else:
            return optimizer.to_dict()

    def get_network_config(self):
        return self.to_dict()
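

# Usage sketch (assumption, using the hypothetical example classes above):
#
#     model = Sequential()
#     model.add(_ExampleLinear(10, 1))
#     opt = _ExampleSGD(model, lr=0.05)
#     model.to_json()               # JSON structure: {"0-0": {"in_features": 10, ...}}
#     model.get_optimizer_config()  # {'lr': 0.05, 'optimizer': '_ExampleSGD',
#                                   #  'config_type': 'pytorch'}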

def get_torch_instance(fate_torch_nn_class: FateTorchLayer, param):
    # instantiate the underlying torch class of a wrapped layer:
    # OpBase subclasses are constructed directly; otherwise the first
    # torch.nn.Module parent found is instantiated with the given params
    parent_torch_class = fate_torch_nn_class.__bases__
    if issubclass(fate_torch_nn_class, OpBase):
        return fate_torch_nn_class(**param)
    for cls in parent_torch_class:
        if issubclass(cls, t.nn.Module):
            return cls(**param)
    return None
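

if __name__ == '__main__':
    # Demonstration sketch, assuming the hypothetical example classes above.
    # get_torch_instance() recovers the plain torch layer from a wrapper
    # class: _ExampleLinear is not an OpBase subclass, so its first
    # torch.nn.Module parent (t.nn.Linear) is instantiated instead.
    layer = get_torch_instance(
        _ExampleLinear, {'in_features': 10, 'out_features': 1})
    print(type(layer))  # <class 'torch.nn.modules.linear.Linear'>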