serialization.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import copy
  2. import inspect
  3. from collections import OrderedDict
  4. try:
  5. from torch.nn import Sequential as tSeq
  6. from federatedml.nn.backend.torch import optim, init, nn
  7. from federatedml.nn.backend.torch import operation
  8. from federatedml.nn.backend.torch.base import Sequential, get_torch_instance
  9. from federatedml.nn.backend.torch.cust import CustModel, CustLoss
  10. from federatedml.nn.backend.torch.interactive import InteractiveLayer
  11. except ImportError:
  12. pass
  13. def recover_layer_from_dict(nn_define, nn_dict):
  14. init_param_dict = copy.deepcopy(nn_define)
  15. if 'layer' in nn_define:
  16. class_name = nn_define['layer']
  17. init_param_dict.pop('layer')
  18. elif 'op' in nn_define:
  19. class_name = nn_define['op']
  20. init_param_dict.pop('op')
  21. else:
  22. raise ValueError(
  23. 'no layer or operation info found in nn define, please check your layer config and make'
  24. 'sure they are correct for pytorch backend')
  25. if 'initializer' in init_param_dict:
  26. init_param_dict.pop('initializer')
  27. # find corresponding class
  28. if class_name == CustModel.__name__:
  29. nn_layer_class = CustModel
  30. elif class_name == InteractiveLayer.__name__:
  31. nn_layer_class = InteractiveLayer
  32. else:
  33. nn_layer_class = nn_dict[class_name]
  34. # create layer or Module
  35. if nn_layer_class == CustModel: # converto to pytorch model
  36. layer: CustModel = CustModel(module_name=init_param_dict['module_name'],
  37. class_name=init_param_dict['class_name'],
  38. **init_param_dict['param'])
  39. layer = layer.get_pytorch_model()
  40. elif nn_layer_class == InteractiveLayer:
  41. layer: InteractiveLayer = InteractiveLayer(**init_param_dict)
  42. else:
  43. layer = get_torch_instance(nn_layer_class, init_param_dict)
  44. # initialize if there are configs
  45. if 'initializer' in nn_define:
  46. if 'weight' in nn_define['initializer']:
  47. init_para = nn_define['initializer']['weight']
  48. init_func = init.str_fate_torch_init_func_map[init_para['init_func']]
  49. init_func(layer, **init_para['param'])
  50. if 'bias' in nn_define['initializer']:
  51. init_para = nn_define['initializer']['bias']
  52. init_func = init.str_fate_torch_init_func_map[init_para['init_func']]
  53. init_func(layer, init='bias', **init_para['param'])
  54. return layer, class_name
  55. def recover_sequential_from_dict(nn_define):
  56. nn_define_dict = nn_define
  57. nn_dict = dict(inspect.getmembers(nn))
  58. op_dict = dict(inspect.getmembers(operation))
  59. nn_dict.update(op_dict)
  60. class_name_list = []
  61. try:
  62. # submitted model have int prefixes, they make sure that layers are in
  63. # order
  64. add_dict = OrderedDict()
  65. keys = list(nn_define_dict.keys())
  66. keys = sorted(keys, key=lambda x: int(x.split('-')[0]))
  67. for k in keys:
  68. layer, class_name = recover_layer_from_dict(nn_define_dict[k], nn_dict)
  69. add_dict[k] = layer
  70. class_name_list.append(class_name)
  71. except BaseException:
  72. add_dict = OrderedDict()
  73. for k, v in nn_define_dict.items():
  74. layer, class_name = recover_layer_from_dict(v, nn_dict)
  75. add_dict[k] = layer
  76. class_name_list.append(class_name)
  77. if len(class_name_list) == 1 and class_name_list[0] == CustModel.__name__:
  78. # If there are only a CustModel, return the model only
  79. return list(add_dict.values())[0]
  80. else:
  81. return tSeq(add_dict)
  82. def recover_optimizer_from_dict(define_dict):
  83. opt_dict = dict(inspect.getmembers(optim))
  84. from federatedml.util import LOGGER
  85. LOGGER.debug('define dict is {}'.format(define_dict))
  86. if 'optimizer' not in define_dict:
  87. raise ValueError('please specify optimizer type in the json config')
  88. opt_class = opt_dict[define_dict['optimizer']]
  89. param_dict = copy.deepcopy(define_dict)
  90. if 'optimizer' in param_dict:
  91. param_dict.pop('optimizer')
  92. if 'config_type' in param_dict:
  93. param_dict.pop('config_type')
  94. return opt_class(**param_dict)
  95. def recover_loss_fn_from_dict(define_dict):
  96. loss_fn_dict = dict(inspect.getmembers(nn))
  97. if 'loss_fn' not in define_dict:
  98. raise ValueError('please specify loss function in the json config')
  99. param_dict = copy.deepcopy(define_dict)
  100. param_dict.pop('loss_fn')
  101. if define_dict['loss_fn'] == CustLoss.__name__:
  102. return CustLoss(loss_module_name=param_dict['loss_module_name'],
  103. class_name=param_dict['class_name'],
  104. **param_dict['param']).get_pytorch_model()
  105. else:
  106. return loss_fn_dict[define_dict['loss_fn']](**param_dict)
if __name__ == '__main__':
    # module is import-only; no standalone entry point
    pass