# torch_model.py

import numpy as np
import tempfile
from federatedml.util import LOGGER

try:  # torch may not be installed; other modules should remain importable
    import torch
    import torch as t
    import copy
    from types import SimpleNamespace
    from torch import autograd
    from federatedml.nn.backend.torch import serialization as s
    from federatedml.nn.backend.torch.base import FateTorchOptimizer
    from federatedml.nn.backend.torch.nn import CrossEntropyLoss
    from federatedml.nn.backend.torch import optim
except ImportError:
    pass


def backward_loss(z, backward_error):
    return t.sum(z * backward_error)
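
# Note on the surrogate loss above: for L = sum(z * e) with e treated as
# a constant, dL/dz = e, so calling L.backward() injects the error tensor
# received from the remote party into the local graph exactly as if it
# were the upstream gradient of z. A tiny (comment-only) check:
#   z = t.ones(2, requires_grad=True)
#   backward_loss(z, t.tensor([3., 5.])).backward()
#   z.grad  # -> tensor([3., 5.])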


class TorchNNModel(object):

    def __init__(self, nn_define: dict, optimizer_define: dict = None,
                 loss_fn_define: dict = None, cuda=False):
        self.cuda = cuda
        self.double_model = False
        if self.cuda and not t.cuda.is_available():
            raise ValueError(
                'this machine does not support cuda, cuda.is_available() is False')
        self.optimizer_define = optimizer_define
        self.nn_define = nn_define
        self.loss_fn_define = loss_fn_define
        self.loss_history = []
        self.model, self.opt_inst, self.loss_fn = self.init(
            self.nn_define, self.optimizer_define, self.loss_fn_define)
        self.fw_cached = None

    def to_tensor(self, x: np.ndarray):
        if isinstance(x, np.ndarray):
            x = t.from_numpy(x)
        if self.cuda:
            return x.cuda()
        else:
            return x

    def label_convert(self, y, loss_fn):
        # pytorch CrossEntropyLoss requires a 1-D int64 target tensor
        if isinstance(loss_fn, CrossEntropyLoss):
            return t.Tensor(y).flatten().type(t.int64)  # accepts 1-D arrays
        else:
            return t.Tensor(y).type(t.float)
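
    # Example (comment only): with CrossEntropyLoss,
    #   label_convert(np.array([[0.], [1.], [2.]]), ce_loss)
    # yields tensor([0, 1, 2]), the flat int64 target layout that CE
    # expects; any other loss keeps float targets in their given shape.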

    def init(self, nn_define: dict, optimizer_define: dict = None,
             loss_fn_define: dict = None):
        model = s.recover_sequential_from_dict(nn_define)
        if self.cuda:
            model = model.cuda()
        if optimizer_define is None:  # default optimizer
            optimizer = optim.SGD(lr=0.01)
        else:
            optimizer: FateTorchOptimizer = s.recover_optimizer_from_dict(
                optimizer_define)
        opt_inst = optimizer.to_torch_instance(model.parameters())
        if loss_fn_define is None:
            loss_fn = backward_loss
        else:
            loss_fn = s.recover_loss_fn_from_dict(loss_fn_define)
        if self.double_model:
            # cast the local variable: self.model is not assigned yet when
            # init() runs inside __init__
            model.type(t.float64)
        return model, opt_inst, loss_fn

    def print_parameters(self):
        LOGGER.debug(
            'model parameter is {}'.format(list(self.model.parameters())))

    def __repr__(self):
        return self.model.__repr__() + '\n' + self.opt_inst.__repr__() + \
            '\n' + str(self.loss_fn)

    def train_mode(self, mode):
        self.model.train(mode)

    def train(self, data_x_and_y):
        x, y = data_x_and_y  # this is a tuple
        self.opt_inst.zero_grad()
        yt = self.to_tensor(y)
        xt = self.to_tensor(x)
        out = self.model(xt)
        loss = self.loss_fn(out, yt)
        loss.backward()
        loss_val = loss.cpu().detach().numpy()
        self.loss_history.append(loss_val)
        self.opt_inst.step()
        return loss_val

    def forward(self, x):
        # caches the grad-enabled output tensor; this function is
        # specifically for the bottom model
        x = self.to_tensor(x)
        out = self.model(x)
        if self.fw_cached is not None:
            raise ValueError('fw_cached should be None before forward')
        self.fw_cached = out
        return out.cpu().detach().numpy()

    def backward(self, error):
        # backward pass; this function is specifically for the bottom model
        self.opt_inst.zero_grad()
        error = self.to_tensor(error)
        loss = self.loss_fn(self.fw_cached, error)
        loss.backward()
        self.fw_cached = None
        self.opt_inst.step()
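
    # Hedged usage sketch (comments only) for the split/hetero setting:
    #   out = bottom.forward(batch_x)   # caches the grad-enabled output
    #   ... send `out` to the top model, receive `err` (dL/d out) back ...
    #   bottom.backward(err)            # replays err through fw_cached
    # The cache is cleared in backward(), so forward() and backward()
    # must alternate strictly one-to-one.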

    def predict(self, x):
        with torch.no_grad():
            return self.model(self.to_tensor(x)).cpu().detach().numpy()

    def get_forward_loss_from_input(self, x, y, reduction='none'):
        with torch.no_grad():
            default_reduction = self.loss_fn.reduction
            self.loss_fn.reduction = reduction
            yt = self.to_tensor(y)
            xt = self.to_tensor(x)
            loss = self.loss_fn(self.model(xt), yt)
            self.loss_fn.reduction = default_reduction
            return list(map(float, loss.cpu().detach().numpy()))

    def get_input_gradients(self, x, y):
        yt = self.to_tensor(y)
        xt = self.to_tensor(x).requires_grad_(True)
        fw = self.model(xt)
        loss = self.loss_fn(fw, yt)
        grad = autograd.grad(loss, xt)
        return [grad[0].cpu().detach().numpy()]
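
    # Note: get_input_gradients returns dLoss/dInput rather than parameter
    # gradients; in a hetero setup this is presumably what gets sent back
    # so the party holding the features can continue backpropagation.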

    def get_loss(self):
        return [self.loss_history[-1]]

    @staticmethod
    def get_model_bytes(model):
        with tempfile.TemporaryFile() as f:
            torch.save(model, f)
            f.seek(0)
            return f.read()

    @staticmethod
    def recover_model_bytes(model_bytes):
        with tempfile.TemporaryFile() as f:
            f.write(model_bytes)
            f.seek(0)
            model = torch.load(f)
        return model

    @staticmethod
    def get_model_save_dict(model: t.nn.Module, model_define,
                            optimizer: t.optim.Optimizer, optimizer_define,
                            loss_define):
        with tempfile.TemporaryFile() as f:
            save_dict = {
                'nn_define': model_define,
                'model': model.state_dict(),
                'optimizer_define': optimizer_define,
                'optimizer': optimizer.state_dict(),
                'loss_define': loss_define
            }
            torch.save(save_dict, f)
            f.seek(0)
            return f.read()

    @staticmethod
    def recover_model_save_dict(model_bytes):
        with tempfile.TemporaryFile() as f:
            f.write(model_bytes)
            f.seek(0)
            save_dict = torch.load(f)
        return save_dict

    def restore_model(self, model_bytes):
        save_dict = self.recover_model_save_dict(model_bytes)
        self.nn_define = save_dict['nn_define']
        opt_define = save_dict['optimizer_define']
        # the optimizer state is only restored when the saved optimizer
        # definition matches the current one
        if opt_define == self.optimizer_define:
            opt_inst: t.optim.Optimizer = self.opt_inst
            opt_inst.load_state_dict(save_dict['optimizer'])
        # load the model state dict
        self.model.load_state_dict(save_dict['model'])
        return self

    def export_model(self):
        return self.get_model_save_dict(
            self.model,
            self.nn_define,
            self.opt_inst,
            self.optimizer_define,
            self.loss_fn_define)
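

# Hedged usage sketch: a minimal round trip through the byte
# serialization helpers. It uses a plain torch module instead of the
# FATE nn_define dict format, since the static methods accept any torch
# model; it runs only where torch is installed, and torch.load here
# relies on the pickle-based default of the torch version this module
# targets.
if __name__ == '__main__':
    demo_model = t.nn.Linear(4, 2)
    model_bytes = TorchNNModel.get_model_bytes(demo_model)
    restored = TorchNNModel.recover_model_bytes(model_bytes)
    assert isinstance(restored, t.nn.Linear)
    print(restored)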