host.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from torch.utils.data import DataLoader
  17. from federatedml.framework.hetero.procedure import batch_generator
  18. from federatedml.nn.hetero.base import HeteroNNBase
  19. from federatedml.nn.hetero.model import HeteroNNHostModel
  20. from federatedml.param.hetero_nn_param import HeteroNNParam as NNParameter
  21. from federatedml.protobuf.generated.hetero_nn_model_meta_pb2 import HeteroNNMeta
  22. from federatedml.protobuf.generated.hetero_nn_model_param_pb2 import HeteroNNParam
  23. from federatedml.util import consts, LOGGER
  24. MODELMETA = "HeteroNNHostMeta"
  25. MODELPARAM = "HeteroNNHostParam"
  26. class HeteroNNHost(HeteroNNBase):
  27. def __init__(self):
  28. super(HeteroNNHost, self).__init__()
  29. self.batch_generator = batch_generator.Host()
  30. self.model = None
  31. self.role = consts.HOST
  32. self.input_shape = None
  33. self.default_table_partitions = 4
  34. def _init_model(self, hetero_nn_param):
  35. super(HeteroNNHost, self)._init_model(hetero_nn_param)
  36. def export_model(self):
  37. if self.need_cv:
  38. return None
  39. model = {MODELMETA: self._get_model_meta(),
  40. MODELPARAM: self._get_model_param()}
  41. return model
  42. def load_model(self, model_dict):
  43. model_dict = list(model_dict["model"].values())[0]
  44. param = model_dict.get(MODELPARAM)
  45. meta = model_dict.get(MODELMETA)
  46. if self.hetero_nn_param is None:
  47. self.hetero_nn_param = NNParameter()
  48. self.hetero_nn_param.check()
  49. self.predict_param = self.hetero_nn_param.predict_param
  50. self._build_model()
  51. self._restore_model_meta(meta)
  52. self._restore_model_param(param)
  53. def _build_model(self):
  54. self.model = HeteroNNHostModel(self.hetero_nn_param, self.flowid)
  55. self.model.set_transfer_variable(self.transfer_variable)
  56. self.model.set_partition(self.default_table_partitions)
  57. def predict(self, data_inst):
  58. ds = self.prepare_dataset(data_inst, data_type='predict')
  59. batch_size = len(ds) if self.batch_size == -1 else self.batch_size
  60. for batch_data in DataLoader(ds, batch_size=batch_size):
  61. # ignore label if the dataset offers label
  62. if isinstance(batch_data, tuple) and len(batch_data) > 1:
  63. batch_data = batch_data[0]
  64. self.model.predict(batch_data)
  65. def fit(self, data_inst, validate_data=None):
  66. if hasattr(
  67. data_inst,
  68. 'partitions') and data_inst.partitions is not None:
  69. self.default_table_partitions = data_inst.partitions
  70. LOGGER.debug(
  71. 'reset default partitions is {}'.format(
  72. self.default_table_partitions))
  73. train_ds = self.prepare_dataset(data_inst, data_type='train')
  74. if validate_data is not None:
  75. val_ds = self.prepare_dataset(validate_data, data_type='validate')
  76. else:
  77. val_ds = None
  78. self.callback_list.on_train_begin(train_ds, val_ds)
  79. if not self.component_properties.is_warm_start:
  80. self._build_model()
  81. epoch_offset = 0
  82. else:
  83. self.callback_warm_start_init_iter(self.history_iter_epoch)
  84. epoch_offset = self.history_iter_epoch + 1
  85. batch_size = len(train_ds) if self.batch_size == - \
  86. 1 else self.batch_size
  87. for cur_epoch in range(epoch_offset, epoch_offset + self.epochs):
  88. self.iter_epoch = cur_epoch
  89. for batch_idx, batch_data in enumerate(
  90. DataLoader(train_ds, batch_size=batch_size)):
  91. self.model.train(batch_data, cur_epoch, batch_idx)
  92. self.callback_list.on_epoch_end(cur_epoch)
  93. if self.callback_variables.stop_training:
  94. LOGGER.debug('early stopping triggered')
  95. break
  96. is_converge = self.transfer_variable.is_converge.get(
  97. idx=0, suffix=(cur_epoch,))
  98. if is_converge:
  99. LOGGER.debug(
  100. "Training process is converged in epoch {}".format(cur_epoch))
  101. break
  102. self.callback_list.on_train_end()
  103. def _get_model_meta(self):
  104. model_meta = HeteroNNMeta()
  105. model_meta.batch_size = self.batch_size
  106. model_meta.hetero_nn_model_meta.CopyFrom(
  107. self.model.get_hetero_nn_model_meta())
  108. model_meta.module = 'HeteroNN'
  109. return model_meta
  110. def _get_model_param(self):
  111. model_param = HeteroNNParam()
  112. model_param.iter_epoch = self.iter_epoch
  113. model_param.header.extend(self._header)
  114. model_param.hetero_nn_model_param.CopyFrom(
  115. self.model.get_hetero_nn_model_param())
  116. model_param.best_iteration = self.callback_variables.best_iteration
  117. return model_param