top_model.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import numpy as np
  19. import torch
  20. from federatedml.nn.hetero.nn_component.torch_model import TorchNNModel
  21. from federatedml.nn.hetero.protection_enhance.coae import train_an_autoencoder_confuser, CoAE, coae_label_reformat, \
  22. CrossEntropy
  23. from federatedml.util import LOGGER
  24. class TopModel(object):
  25. def __init__(self, loss, optimizer, layer_config, coae_config, label_num):
  26. self.coae = None
  27. self.coae_config = coae_config
  28. self.label_num = label_num
  29. LOGGER.debug('label num is {}'.format(self.label_num))
  30. self._model: TorchNNModel = TorchNNModel(nn_define=layer_config, optimizer_define=optimizer,
  31. loss_fn_define=loss)
  32. self.label_reformat = None
  33. if self.coae_config:
  34. self._model.loss_fn = CrossEntropy()
  35. if self.coae_config:
  36. self.label_reformat = coae_label_reformat
  37. self.batch_size = None
  38. self.selector = None
  39. self.batch_data_cached_X = []
  40. self.batch_data_cached_y = []
  41. def set_backward_selector_strategy(self, selector):
  42. self.selector = selector
  43. def set_batch(self, batch_size):
  44. self.batch_size = batch_size
  45. def train_mode(self, mode):
  46. self._model.train_mode(mode)
  47. def train_and_get_backward_gradient(self, x, y):
  48. LOGGER.debug("top model start to forward propagation")
  49. selective_id = []
  50. input_gradient = []
  51. # transform label format
  52. if self.label_reformat:
  53. y = self.label_reformat(y, label_num=self.label_num)
  54. # train an auto-encoder confuser
  55. if self.coae_config and self.coae is None:
  56. LOGGER.debug('training coae encoder')
  57. self.coae: CoAE = train_an_autoencoder_confuser(y.shape[1], self.coae_config.epoch,
  58. self.coae_config.lambda1, self.coae_config.lambda2,
  59. self.coae_config.lr, self.coae_config.verbose)
  60. # make fake soft label
  61. if self.coae:
  62. # transform labels to fake labels
  63. y = self.coae.encode(y).detach().numpy()
  64. LOGGER.debug('fake labels are {}'.format(y))
  65. # run selector
  66. if self.selector:
  67. # when run selective bp, need to convert y to numpy format
  68. if isinstance(y, torch.Tensor):
  69. y = y.cpu().numpy()
  70. losses = self._model.get_forward_loss_from_input(x, y)
  71. loss = sum(losses) / len(losses)
  72. selective_strategy = self.selector.select_batch_sample(losses)
  73. for idx, select in enumerate(selective_strategy):
  74. if select:
  75. selective_id.append(idx)
  76. self.batch_data_cached_X.append(x[idx])
  77. self.batch_data_cached_y.append(y[idx])
  78. if len(self.batch_data_cached_X) >= self.batch_size:
  79. data = (np.array(self.batch_data_cached_X[: self.batch_size]),
  80. np.array(self.batch_data_cached_y[: self.batch_size]))
  81. input_gradient = self._model.get_input_gradients(data[0], data[1])[
  82. 0]
  83. self._model.train(data)
  84. self.batch_data_cached_X = self.batch_data_cached_X[self.batch_size:]
  85. self.batch_data_cached_y = self.batch_data_cached_y[self.batch_size:]
  86. else:
  87. input_gradient = self._model.get_input_gradients(x, y)[0]
  88. self._model.train((x, y))
  89. loss = self._model.get_loss()[0]
  90. return selective_id, input_gradient, loss
  91. def predict(self, input_data):
  92. output_data = self._model.predict(input_data)
  93. if self.coae:
  94. real_output = self.coae.decode(output_data).detach().numpy()
  95. if real_output.shape[1] == 2:
  96. real_output = real_output[::, 1].reshape((-1, 1))
  97. return real_output
  98. else:
  99. return output_data
  100. def export_coae(self):
  101. if self.coae:
  102. model_bytes = TorchNNModel.get_model_bytes(self.coae)
  103. return model_bytes
  104. else:
  105. return None
  106. def restore_coae(self, model_bytes):
  107. if model_bytes is not None and len(model_bytes) > 0:
  108. coae = TorchNNModel.recover_model_bytes(model_bytes)
  109. self.coae = coae
  110. def export_model(self):
  111. return self._model.export_model()
  112. def restore_model(self, model_bytes):
  113. self._model = self._model.restore_model(model_bytes)
  114. def __repr__(self):
  115. return 'top model contains {}'.format(self._model.__repr__())