import_hook.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. try:
  2. from federatedml.component.nn.backend.torch import nn as nn_
  3. from federatedml.component.nn.backend.torch import init as init_
  4. from federatedml.component.nn.backend.torch import optim as optim_
  5. from federatedml.component.nn.backend.torch.cust import CustModel, CustLoss
  6. from federatedml.nn.backend.torch.interactive import InteractiveLayer
  7. except ImportError:
  8. pass
  9. def monkey_patch(torch_nn, fate_torch_module):
  10. for name in fate_torch_module.__dict__.keys():
  11. if '__' in name: # skip no related variables
  12. continue
  13. if name in torch_nn.__dict__.keys():
  14. torch_nn.__dict__[name] = fate_torch_module.__dict__[name]
  15. def fate_torch_hook(torch_module_var):
  16. """
  17. This is a monkey patch function that modify torch modules to use fate_torch layers and Components
  18. :param torch_module_var:
  19. :return:
  20. """
  21. if torch_module_var.__name__ == 'torch':
  22. monkey_patch(torch_module_var.nn, nn_)
  23. monkey_patch(torch_module_var.optim, optim_)
  24. monkey_patch(torch_module_var.nn.init, init_)
  25. setattr(torch_module_var.nn, 'CustModel', CustModel)
  26. setattr(torch_module_var.nn, 'InteractiveLayer', InteractiveLayer)
  27. setattr(torch_module_var.nn, 'CustLoss', CustLoss)
  28. elif torch_module_var.__name__ == 'torch.nn':
  29. monkey_patch(torch_module_var, nn_)
  30. setattr(torch_module_var, 'CustModel', CustModel)
  31. setattr(torch_module_var.nn, 'InteractiveLayer', InteractiveLayer)
  32. setattr(torch_module_var.nn, 'CustLoss', CustLoss)
  33. elif torch_module_var.__name__ == 'torch.optim':
  34. monkey_patch(torch_module_var, optim_)
  35. elif torch_module_var.__name__ == 'torch.nn.init':
  36. monkey_patch(torch_module_var, init_)
  37. else:
  38. raise ValueError(
  39. 'this module: {} does not support fate torch hook'.format(torch_module_var))
  40. return torch_module_var