bottom_model.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import torch as t
  19. import numpy as np
  20. from federatedml.util import LOGGER
  21. from federatedml.nn.hetero.nn_component.torch_model import TorchNNModel
  22. class BottomModel(object):
  23. def __init__(self, optimizer, layer_config):
  24. self._model: TorchNNModel = TorchNNModel(nn_define=layer_config, optimizer_define=optimizer,
  25. loss_fn_define=None)
  26. self.do_backward_select_strategy = False
  27. self.x = []
  28. self.x_cached = []
  29. self.batch_size = None
  30. def set_backward_select_strategy(self):
  31. self.do_backward_select_strategy = True
  32. def set_batch(self, batch_size):
  33. self.batch_size = batch_size
  34. def train_mode(self, mode):
  35. self._model.train_mode(mode)
  36. def forward(self, x):
  37. LOGGER.debug("bottom model start to forward propagation")
  38. self.x = x
  39. if self.do_backward_select_strategy:
  40. if (not isinstance(x, np.ndarray) and not isinstance(x, t.Tensor)):
  41. raise ValueError(
  42. 'When using selective bp, data from dataset must be a ndarray or a torch tensor, but got {}'.format(
  43. type(x)))
  44. if self.do_backward_select_strategy:
  45. output_data = self._model.predict(x)
  46. else:
  47. output_data = self._model.forward(x)
  48. return output_data
  49. def backward(self, x, error, selective_ids):
  50. LOGGER.debug("bottom model start to backward propagation")
  51. if self.do_backward_select_strategy:
  52. if selective_ids:
  53. if len(self.x_cached) == 0:
  54. self.x_cached = self.x[selective_ids]
  55. else:
  56. self.x_cached = np.vstack(
  57. (self.x_cached, self.x[selective_ids]))
  58. if len(error) == 0:
  59. return
  60. x = self.x_cached[: self.batch_size]
  61. self.x_cached = self.x_cached[self.batch_size:]
  62. self._model.train((x, error))
  63. else:
  64. self._model.backward(error)
  65. LOGGER.debug('bottom model update parameters:')
  66. def predict(self, x):
  67. return self._model.predict(x)
  68. def export_model(self):
  69. return self._model.export_model()
  70. def restore_model(self, model_bytes):
  71. self._model = self._model.restore_model(model_bytes)
  72. def __repr__(self):
  73. return 'bottom model contains {}'.format(self._model.__repr__())