#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from fate_arch.computing.non_distributed import LocalData
from federatedml.model_base import ModelBase
from federatedml.model_selection import start_cross_validation
from federatedml.nn.backend.utils.data import load_dataset
from federatedml.nn.dataset.base import Dataset, ShuffleWrapDataset
from federatedml.param.hetero_nn_param import HeteroNNParam
from federatedml.transfer_variable.transfer_class.hetero_nn_transfer_variable import HeteroNNTransferVariable
from federatedml.util import LOGGER
from federatedml.util import consts
  26. class HeteroNNBase(ModelBase):
  27. def __init__(self):
  28. super(HeteroNNBase, self).__init__()
  29. self.tol = None
  30. self.early_stop = None
  31. self.seed = 100
  32. self.epochs = None
  33. self.batch_size = None
  34. self._header = []
  35. self.predict_param = None
  36. self.hetero_nn_param = None
  37. self.batch_generator = None
  38. self.model = None
  39. self.partition = None
  40. self.validation_freqs = None
  41. self.early_stopping_rounds = None
  42. self.metrics = []
  43. self.use_first_metric_only = False
  44. self.transfer_variable = HeteroNNTransferVariable()
  45. self.model_param = HeteroNNParam()
  46. self.mode = consts.HETERO
  47. self.selector_param = None
  48. self.floating_point_precision = None
  49. self.history_iter_epoch = 0
  50. self.iter_epoch = 0
  51. self.data_x = []
  52. self.data_y = []
  53. self.dataset_cache_dict = {}
  54. self.label_num = None
  55. # nn related param
  56. self.top_model_define = None
  57. self.bottom_model_define = None
  58. self.interactive_layer_define = None
  59. self.dataset_shuffle = True
  60. self.dataset = None
  61. self.dataset_param = None
  62. self.dataset_shuffle_seed = 100
  63. def _init_model(self, hetero_nn_param: HeteroNNParam):
  64. self.interactive_layer_lr = hetero_nn_param.interactive_layer_lr
  65. self.epochs = hetero_nn_param.epochs
  66. self.batch_size = hetero_nn_param.batch_size
  67. self.seed = hetero_nn_param.seed
  68. self.early_stop = hetero_nn_param.early_stop
  69. self.validation_freqs = hetero_nn_param.validation_freqs
  70. self.early_stopping_rounds = hetero_nn_param.early_stopping_rounds
  71. self.metrics = hetero_nn_param.metrics
  72. self.use_first_metric_only = hetero_nn_param.use_first_metric_only
  73. self.tol = hetero_nn_param.tol
  74. self.predict_param = hetero_nn_param.predict_param
  75. self.hetero_nn_param = hetero_nn_param
  76. self.selector_param = hetero_nn_param.selector_param
  77. self.floating_point_precision = hetero_nn_param.floating_point_precision
  78. # nn configs
  79. self.bottom_model_define = hetero_nn_param.bottom_nn_define
  80. self.top_model_define = hetero_nn_param.top_nn_define
  81. self.interactive_layer_define = hetero_nn_param.interactive_layer_define
  82. # dataset
  83. dataset_param = hetero_nn_param.dataset.to_dict()
  84. self.dataset = dataset_param['dataset_name']
  85. self.dataset_param = dataset_param['param']
  86. def reset_flowid(self):
  87. new_flowid = ".".join([self.flowid, "evaluate"])
  88. self.set_flowid(new_flowid)
  89. def recovery_flowid(self):
  90. new_flowid = ".".join(self.flowid.split(".", -1)[: -1])
  91. self.set_flowid(new_flowid)
  92. def _build_bottom_model(self):
  93. pass
  94. def _build_interactive_model(self):
  95. pass
  96. def _restore_model_meta(self, meta):
  97. # self.hetero_nn_param.interactive_layer_lr = meta.interactive_layer_lr
  98. self.hetero_nn_param.task_type = meta.task_type
  99. if not self.component_properties.is_warm_start:
  100. self.batch_size = meta.batch_size
  101. self.epochs = meta.epochs
  102. self.tol = meta.tol
  103. self.early_stop = meta.early_stop
  104. self.model.set_hetero_nn_model_meta(meta.hetero_nn_model_meta)
  105. def _restore_model_param(self, param):
  106. self.model.set_hetero_nn_model_param(param.hetero_nn_model_param)
  107. self._header = list(param.header)
  108. self.history_iter_epoch = param.iter_epoch
  109. self.iter_epoch = param.iter_epoch
  110. def set_partition(self, data_inst):
  111. self.partition = data_inst.partitions
  112. self.model.set_partition(self.partition)
  113. def cross_validation(self, data_instances):
  114. return start_cross_validation.run(self, data_instances)
  115. def prepare_dataset(self, data, data_type='train', check_label=False):
  116. # train input & validate input are DTables or path str
  117. if isinstance(data, LocalData):
  118. data = data.path
  119. if isinstance(data, Dataset) or isinstance(data, ShuffleWrapDataset):
  120. ds = data
  121. else:
  122. ds = load_dataset(
  123. self.dataset,
  124. data,
  125. self.dataset_param,
  126. self.dataset_cache_dict)
  127. if not ds.has_sample_ids():
  128. raise ValueError(
  129. 'Dataset has no sample id, this is not allowed in hetero-nn, please make sure'
  130. ' that you implement get_sample_ids()')
  131. if self.dataset_shuffle:
  132. ds = ShuffleWrapDataset(
  133. ds, shuffle_seed=self.dataset_shuffle_seed)
  134. if self.role == consts.GUEST:
  135. self.transfer_variable.dataset_info.remote(
  136. ds.idx_map, idx=-1, suffix=('idx_map', data_type))
  137. if self.role == consts.HOST:
  138. idx_map = self.transfer_variable.dataset_info.get(
  139. idx=0, suffix=('idx_map', data_type))
  140. assert len(idx_map) == len(ds), 'host dataset len != guest dataset len, please check your dataset,' \
  141. 'guest len {}, host len {}'.format(len(idx_map), len(ds))
  142. ds.set_shuffled_idx(idx_map)
  143. if check_label:
  144. try:
  145. all_classes = ds.get_classes()
  146. except NotImplementedError as e:
  147. raise NotImplementedError(
  148. 'get_classes() is not implemented, please implement this function'
  149. ' when you are using hetero-nn. Let it return classes in a list.'
  150. ' Please see built-in dataset(table.py for example) for reference')
  151. except BaseException as e:
  152. raise e
  153. from federatedml.util import LOGGER
  154. LOGGER.debug('all classes is {}'.format(all_classes))
  155. if self.label_num is None:
  156. if self.task_type == consts.CLASSIFICATION:
  157. self.label_num = len(all_classes)
  158. elif self.task_type == consts.REGRESSION:
  159. self.label_num = 1
  160. return ds
  161. # override function
  162. @staticmethod
  163. def set_predict_data_schema(predict_datas, schemas):
  164. if predict_datas is None:
  165. return predict_datas
  166. if isinstance(predict_datas, list):
  167. predict_data = predict_datas[0]
  168. schema = schemas[0]
  169. else:
  170. predict_data = predict_datas
  171. schema = schemas
  172. if predict_data is not None:
  173. predict_data.schema = {
  174. "header": [
  175. "label",
  176. "predict_result",
  177. "predict_score",
  178. "predict_detail",
  179. "type",
  180. ],
  181. "sid": 'id',
  182. "content_type": "predict_result"
  183. }
  184. if schema.get("match_id_name") is not None:
  185. predict_data.schema["match_id_name"] = schema.get(
  186. "match_id_name")
  187. return predict_data