#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numpy as np
import torch
from torch.utils.data import DataLoader

from fate_arch.computing._util import is_table
from fate_arch.session import computing_session as session
from federatedml.feature.instance import Instance
from federatedml.framework.hetero.procedure import batch_generator
from federatedml.model_base import Metric
from federatedml.model_base import MetricMeta
from federatedml.nn.backend.utils.data import add_match_id
from federatedml.nn.dataset.table import TableDataset
from federatedml.nn.hetero.base import HeteroNNBase
from federatedml.nn.hetero.model import HeteroNNGuestModel
from federatedml.optim.convergence import converge_func_factory
from federatedml.param.evaluation_param import EvaluateParam
from federatedml.param.hetero_nn_param import HeteroNNParam as NNParameter
from federatedml.protobuf.generated.hetero_nn_model_meta_pb2 import HeteroNNMeta
from federatedml.protobuf.generated.hetero_nn_model_param_pb2 import HeteroNNParam
from federatedml.statistic.data_overview import check_with_inst_id
from federatedml.util import consts, LOGGER
from federatedml.util.io_check import assert_io_num_rows_equal

MODELMETA = "HeteroNNGuestMeta"
MODELPARAM = "HeteroNNGuestParam"
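

# HeteroNNGuest is the guest-side ("active party") component of FATE's
# hetero (vertical) neural network: the guest holds the labels and the top
# model, drives the epoch loop, computes the loss from its own and the
# hosts' forward outputs, and tells the hosts whether training has converged.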
class HeteroNNGuest(HeteroNNBase):

    def __init__(self):
        super(HeteroNNGuest, self).__init__()
        self.task_type = None
        self.converge_func = None
        self.batch_generator = batch_generator.Guest()
        self.data_keys = []
        self.label_dict = {}
        self.model = None
        self.role = consts.GUEST
        self.history_loss = []
        self.input_shape = None
        self._summary_buf = {"history_loss": [],
                             "is_converged": False,
                             "best_iteration": -1}
        self.dataset_cache_dict = {}
        self.default_table_partitions = 4

    def _init_model(self, hetero_nn_param):
        super(HeteroNNGuest, self)._init_model(hetero_nn_param)
        self.task_type = hetero_nn_param.task_type
        self.converge_func = converge_func_factory(self.early_stop, self.tol)

    def _build_model(self):
        self.model = HeteroNNGuestModel(
            self.hetero_nn_param, self.component_properties, self.flowid)
        self.model.set_transfer_variable(self.transfer_variable)
        self.model.set_partition(self.default_table_partitions)
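
    # Register the metadata of the training loss curve so the loss reported
    # per epoch can be tracked and plotted (e.g. on FATE Board).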
    def _set_loss_callback_info(self):
        self.callback_meta("loss",
                           "train",
                           MetricMeta(name="train",
                                      metric_type="LOSS",
                                      extra_metas={"unit_name": "iters"}))

    @staticmethod
    def _disable_sample_weight(dataset):
        # sample weight is currently not supported
        if isinstance(dataset, TableDataset):
            dataset.with_sample_weight = False
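
    # fit() runs the guest-side training protocol: prepare the train and
    # validate datasets, build (or warm-start) the model, then iterate over
    # epochs. Each epoch accumulates a sample-weighted loss, and the
    # convergence decision is broadcast to all hosts via a transfer variable.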
    def fit(self, data_inst, validate_data=None):

        if hasattr(data_inst, 'partitions') and data_inst.partitions is not None:
            self.default_table_partitions = data_inst.partitions
            LOGGER.debug('reset default partitions to {}'.format(
                self.default_table_partitions))

        train_ds = self.prepare_dataset(data_inst, data_type='train', check_label=True)
        train_ds.train()  # set dataset to train mode
        self._disable_sample_weight(train_ds)

        if validate_data is not None:
            val_ds = self.prepare_dataset(validate_data, data_type='validate')
            val_ds.train()  # set dataset to train mode
            self._disable_sample_weight(val_ds)
        else:
            val_ds = None

        self.callback_list.on_train_begin(train_ds, val_ds)

        # collect data from table to form data loader
        if not self.component_properties.is_warm_start:
            self._build_model()
            epoch_offset = 0
        else:
            self.callback_warm_start_init_iter(self.history_iter_epoch)
            epoch_offset = self.history_iter_epoch + 1

        # set label number
        self.model.set_label_num(self.label_num)

        if len(train_ds) == 0:
            self.model.set_empty()

        self._set_loss_callback_info()

        # batch_size == -1 means full-batch training
        batch_size = len(train_ds) if self.batch_size == -1 else self.batch_size
        data_loader = DataLoader(train_ds,
                                 batch_size=batch_size,
                                 num_workers=4)

        for cur_epoch in range(epoch_offset, self.epochs + epoch_offset):
            self.iter_epoch = cur_epoch
            LOGGER.debug("cur epoch is {}".format(cur_epoch))
            self.callback_list.on_epoch_begin(cur_epoch)
            epoch_loss = 0
            acc_sample_num = 0

            for batch_idx, (batch_data, batch_label) in enumerate(data_loader):
                batch_loss = self.model.train(
                    batch_data, batch_label, cur_epoch, batch_idx)
                # the last batch may be smaller than batch_size, so weight
                # each batch loss by its actual number of samples
                if acc_sample_num + batch_size > len(train_ds):
                    batch_len = len(train_ds) - acc_sample_num
                else:
                    batch_len = batch_size
                acc_sample_num += batch_size
                epoch_loss += batch_loss * batch_len

            epoch_loss = epoch_loss / len(train_ds)
            LOGGER.debug("epoch {} loss is {}".format(cur_epoch, epoch_loss))
            self.callback_metric("loss",
                                 "train",
                                 [Metric(cur_epoch, epoch_loss)])
            self.history_loss.append(epoch_loss)

            self.callback_list.on_epoch_end(cur_epoch)
            if self.callback_variables.stop_training:
                LOGGER.debug('early stopping triggered')
                break

            if self.hetero_nn_param.selector_param.method:
                # when selective backpropagation is used, the loss
                # convergence check is disabled
                is_converge = False
            else:
                is_converge = self.converge_func.is_converge(epoch_loss)
            self._summary_buf["is_converged"] = is_converge

            self.transfer_variable.is_converge.remote(is_converge,
                                                      role=consts.HOST,
                                                      idx=-1,
                                                      suffix=(cur_epoch,))

            if is_converge:
                LOGGER.debug("training converged in epoch {}".format(cur_epoch))
                break

        self.callback_list.on_train_end()
        self.set_summary(self._get_model_summary())
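
    # predict() runs the dataset through the trained model batch by batch,
    # then converts the flat predictions back into a distributed table keyed
    # by sample id so that downstream components can consume the scores.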
    @assert_io_num_rows_equal
    def predict(self, data_inst):

        with_match_id = False
        if is_table(data_inst):
            with_match_id = check_with_inst_id(data_inst)

        ds = self.prepare_dataset(data_inst, data_type='predict')
        ds.eval()  # set dataset to eval mode
        self._disable_sample_weight(ds)
        keys = ds.get_sample_ids()

        batch_size = len(ds) if self.batch_size == -1 else self.batch_size
        dl = DataLoader(ds, batch_size=batch_size)
        preds = []
        labels = []
        for batch_data, batch_label in dl:
            batch_pred = self.model.predict(batch_data)
            preds.append(batch_pred)
            labels.append(batch_label)

        preds = np.concatenate(preds, axis=0)
        labels = torch.concat(labels, dim=0).cpu().numpy().flatten().tolist()

        id_table = [(id_, Instance(label=l)) for id_, l in zip(keys, labels)]
        if with_match_id:
            add_match_id(id_table, ds.ds)  # ds here is a wrapped shuffle dataset

        data_inst = session.parallelize(id_table,
                                        partition=self.default_table_partitions,
                                        include_key=True)

        if self.task_type == consts.REGRESSION:
            preds = preds.flatten().tolist()
            preds = [float(pred) for pred in preds]
            predict_tb = session.parallelize(zip(keys, preds), include_key=True,
                                             partition=self.default_table_partitions)
            result = self.predict_score_to_output(data_inst, predict_tb)
        else:
            if self.label_num > 2:
                # multi-class: one score per class
                preds = preds.tolist()
                preds = [list(map(float, pred)) for pred in preds]
                predict_tb = session.parallelize(zip(keys, preds), include_key=True,
                                                 partition=self.default_table_partitions)
                result = self.predict_score_to_output(
                    data_inst, predict_tb, classes=list(range(self.label_num)))
            else:
                # binary: a single positive-class score per sample
                preds = preds.flatten().tolist()
                preds = [float(pred) for pred in preds]
                predict_tb = session.parallelize(zip(keys, preds), include_key=True,
                                                 partition=self.default_table_partitions)
                threshold = self.predict_param.threshold
                result = self.predict_score_to_output(
                    data_inst, predict_tb, classes=[0, 1], threshold=threshold)

        return result
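
    # The exported model is a dict of two protobuf messages: MODELMETA maps
    # to a HeteroNNMeta (training configuration) and MODELPARAM maps to a
    # HeteroNNParam (learned weights plus training history).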
    def export_model(self):
        if self.need_cv:
            return None
        model = {MODELMETA: self._get_model_meta(),
                 MODELPARAM: self._get_model_param()}
        return model

    def load_model(self, model_dict):
        model_dict = list(model_dict["model"].values())[0]
        param = model_dict.get(MODELPARAM)
        meta = model_dict.get(MODELMETA)
        if self.hetero_nn_param is None:
            self.hetero_nn_param = NNParameter()
            self.hetero_nn_param.check()
            self.predict_param = self.hetero_nn_param.predict_param
        self._build_model()
        self._restore_model_meta(meta)
        self._restore_model_param(param)

    def _get_model_summary(self):
        self._summary_buf["history_loss"] = self.history_loss
        if self.callback_variables.validation_summary:
            self._summary_buf["validation_metrics"] = self.callback_variables.validation_summary
        return self._summary_buf
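
    # _get_model_meta captures the training configuration; _get_model_param
    # captures the learned weights, label count, best iteration and the
    # per-epoch loss history.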
    def _get_model_meta(self):
        model_meta = HeteroNNMeta()
        model_meta.task_type = self.task_type
        model_meta.module = 'HeteroNN'
        model_meta.batch_size = self.batch_size
        model_meta.epochs = self.epochs
        model_meta.early_stop = self.early_stop
        model_meta.tol = self.tol
        model_meta.hetero_nn_model_meta.CopyFrom(
            self.model.get_hetero_nn_model_meta())
        return model_meta

    def _get_model_param(self):
        model_param = HeteroNNParam()
        model_param.iter_epoch = self.iter_epoch
        model_param.hetero_nn_model_param.CopyFrom(
            self.model.get_hetero_nn_model_param())
        model_param.num_label = self.label_num
        model_param.best_iteration = self.callback_variables.best_iteration
        model_param.header.extend(self._header)
        for loss in self.history_loss:
            model_param.history_loss.append(loss)
        return model_param
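
    # Tell the downstream Evaluation component which metric suite to use,
    # based on the task type and the number of labels.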
    def get_metrics_param(self):
        if self.task_type == consts.CLASSIFICATION:
            if self.label_num == 2:
                return EvaluateParam(eval_type="binary",
                                     pos_label=1, metrics=self.metrics)
            else:
                return EvaluateParam(eval_type="multi", metrics=self.metrics)
        else:
            return EvaluateParam(eval_type="regression", metrics=self.metrics)

    def _restore_model_param(self, param):
        super(HeteroNNGuest, self)._restore_model_param(param)
        self.label_num = param.num_label
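

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; `hetero_nn_param`, `train_table` and
# `test_table` are hypothetical stand-ins). In practice this component is
# instantiated and wired up by the FATE job scheduler, which supplies the
# component properties, transfer variables and computing-session tables:
#
#   guest = HeteroNNGuest()
#   guest._init_model(hetero_nn_param)       # a checked HeteroNNParam
#   guest.fit(train_table, validate_data=None)
#   result_table = guest.predict(test_table)
# ---------------------------------------------------------------------------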