hetero_nn.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from pipeline.component.component_base import FateComponent
  17. from pipeline.component.nn.models.sequantial import Sequential
  18. from pipeline.component.nn.backend.torch.interactive import InteractiveLayer
  19. from pipeline.interface import Input
  20. from pipeline.interface import Output
  21. from pipeline.utils.tools import extract_explicit_parameter
  22. from pipeline.component.nn.interface import DatasetParam
  class HeteroNN(FateComponent):
      """Pipeline component describing a FATE Hetero Neural Network job.

      The component collects a bottom model, an interactive layer and a top model
      (each held in a ``Sequential`` container) plus an optimizer and a loss, and
      stores their serialised forms when :meth:`compile` is called. Role-specific
      copies come from :meth:`get_party_instance`, which tags each copy with its
      role so that guest-only pieces (interactive layer, top model) are rejected
      on host instances.
      """

      @extract_explicit_parameter
      def __init__(self, task_type="classification", epochs=None, batch_size=-1, early_stop="diff",
                   tol=1e-5, encrypt_param=None, predict_param=None, cv_param=None, interactive_layer_lr=0.1,
                   validation_freqs=None, early_stopping_rounds=None, use_first_metric_only=None,
                   floating_point_precision=23, selector_param=None, seed=100,
                   dataset: DatasetParam = DatasetParam(dataset_name='table'), **kwargs
                   ):
          """
          Parameters used for Hetero Neural Network.

          NOTE(review): the body reads its values only from
          ``kwargs["explict_parameters"]``, which ``extract_explicit_parameter``
          fills in. The named arguments below therefore appear to take effect only
          when passed by keyword. Their signature defaults are presumably applied
          elsewhere, not by this method. Confirm against the decorator before
          relying on positional arguments.

          Parameters
          ----------
          task_type: str, task type of hetero nn model, one of 'classification', 'regression'.
          epochs: int, the maximum iteration for aggregation in training.
          batch_size: int, batch size when updating model.
              -1 means use all data in a batch, i.e. do not use the mini-batch strategy.
              Defaults to -1.
          early_stop: str, accept 'diff' only in this version, default: 'diff'.
              Method used to judge convergence:
              a) diff: use the difference of loss between two iterations to judge whether it converges.
          tol: float, tolerance value for early stop.
          encrypt_param: dict, see federatedml/param/encrypt_param.
          predict_param: prediction parameters; stored on the component as-is
              (presumably federatedml/param/predict_param - confirm).
          cv_param: cross-validation parameters; stored on the component as-is
              (presumably federatedml/param/cross_validation_param - confirm).
          interactive_layer_lr: float, the learning rate of the interactive layer.
          validation_freqs: None or positive integer or container object in python.
              Whether to do validation in the training process.
              If None, no validation is done during training;
              if a positive integer, data is validated every validation_freqs epochs;
              if a container object, data is validated when the epoch belongs to the container,
              e.g. validation_freqs = [10, 15] validates when epoch equals 10 and 15.
              Default: None
          early_stopping_rounds: integer larger than 0.
              Training stops if one metric of one validation data
              doesn't improve in the last early_stopping_rounds rounds.
              validation_freqs must be set; early stopping is checked at every validation epoch.
          use_first_metric_only: presumably whether early stopping only considers the
              first metric - confirm against the backend param definition.
          floating_point_precision: None or integer. If not None, use floating_point_precision-bit
              to speed up calculation, e.g. convert x to round(x * 2**floating_point_precision)
              during Paillier operations, and divide the result by 2**floating_point_precision in the end.
          selector_param: selector parameters; stored on the component as-is.
          seed: int, default 100; presumably the random seed - confirm.
          dataset: DatasetParam, interface defining the dataset param. It is validated
              with ``check()`` and stored in its ``to_dict()`` form (see also add_dataset).
          **kwargs: further keyword parameters, e.g. callback_param (dict, CallbackParam,
              see federatedml/param/callback_param); every explicitly passed keyword is
              set as an attribute of the component.
          """
          explicit_parameters = kwargs["explict_parameters"]
          # compile() fills these later; register them as None up front, presumably so
          # FateComponent treats them as component parameters.
          for generated_key in ("optimizer", "bottom_nn_define", "top_nn_define",
                                "interactive_layer_define", "loss"):
              explicit_parameters[generated_key] = None
          FateComponent.__init__(self, **explicit_parameters)

          # "name" has already been handed to FateComponent.__init__; don't re-set it below.
          explicit_parameters.pop("name", None)
          for param_key, param_value in explicit_parameters.items():
              setattr(self, param_key, param_value)

          self.input = Input(self.name, data_type="multi")
          self.output = Output(self.name, data_type='single')
          self._module_name = "HeteroNN"
          # outputs of compile(); start empty
          self.optimizer = None
          self.bottom_nn_define = None
          self.top_nn_define = None
          self.interactive_layer_define = None

          # model holders, filled by add_bottom_model / set_interactive_layer / add_top_model
          self._bottom_nn_model = Sequential()
          self._interactive_layer = Sequential()
          self._top_nn_model = Sequential()

          # role of this instance: 'common', 'guest' or 'host' (changed by set_role)
          self._role = 'common'

          if hasattr(self, 'dataset'):
              assert isinstance(
                  self.dataset, DatasetParam), 'dataset must be a DatasetParam class'
              self.dataset.check()
              # keep the serialised (to_dict) form, as add_dataset does
              self.dataset = self.dataset.to_dict()

      def set_role(self, role):
          """Record the role ('common', 'guest' or 'host') this instance represents."""
          self._role = role

      def get_party_instance(self, role="guest", party_id=None) -> 'Component':
          """Return the instance for ``role``/``party_id`` and tag it with ``role``.

          The role tag is what lets set_interactive_layer / add_top_model reject
          host instances.
          """
          inst = super().get_party_instance(role, party_id)
          inst.set_role(role)
          return inst

      def add_dataset(self, dataset_param: DatasetParam):
          """Validate ``dataset_param`` and register its to_dict() form as the "dataset" parameter.

          Raises AssertionError if ``dataset_param`` is not a DatasetParam.
          """
          assert isinstance(
              dataset_param, DatasetParam), 'dataset must be a DatasetParam class'
          dataset_param.check()
          self.dataset = dataset_param.to_dict()
          self._component_parameter_keywords.add("dataset")
          self._component_param["dataset"] = self.dataset

      def add_bottom_model(self, model):
          """Append ``model`` to this instance's bottom model (allowed for every role)."""
          # The holder is missing on instances restored from __getstate__ (which drops it).
          if not hasattr(self, "_bottom_nn_model"):
              self._bottom_nn_model = Sequential()
          self._bottom_nn_model.add(model)

      def set_interactive_layer(self, layer):
          """Append ``layer`` (an InteractiveLayer) to the interactive-layer container.

          Raises RuntimeError on instances whose role is not 'common' or 'guest',
          and AssertionError if ``layer`` is not an InteractiveLayer.
          """
          if self._role not in ('common', 'guest'):
              raise RuntimeError(
                  'You can only set interactive layer in "common" or "guest" hetero nn component')
          # The holder is missing on instances restored from __getstate__ (which drops it).
          if not hasattr(self, "_interactive_layer"):
              self._interactive_layer = Sequential()
          assert isinstance(layer, InteractiveLayer), (
              'You need to add an interactive layer instance, \n'
              'you can access InteractiveLayer by:\n'
              't.nn.InteractiveLayer after fate_torch_hook(t)\n'
              'or from pipeline.component.nn.backend.torch.interactive '
              'import InteractiveLayer')
          self._interactive_layer.add(layer)

      def add_top_model(self, model):
          """Append ``model`` to the top model; raises RuntimeError on host instances."""
          if self._role == 'host':
              raise RuntimeError('top model is not allow to set on host model')
          # The holder is missing on instances restored from __getstate__ (which drops it).
          if not hasattr(self, "_top_nn_model"):
              self._top_nn_model = Sequential()
          self._top_nn_model.add(model)

      def _set_optimizer(self, opt):
          """Store ``opt.to_dict()`` as the optimizer configuration."""
          assert hasattr(
              opt, 'to_dict'), 'opt does not have function to_dict(), remember to call fate_torch_hook(t)'
          self.optimizer = opt.to_dict()

      def _set_loss(self, loss):
          """Store ``loss.to_dict()`` as the loss configuration."""
          assert hasattr(
              loss, 'to_dict'), 'loss does not have function to_dict(), remember to call fate_torch_hook(t)'
          self.loss = loss.to_dict()

      def compile(self, optimizer, loss):
          """Attach ``optimizer``/``loss`` and serialise the network definitions.

          Order: store optimizer and loss configs; build this instance's bottom/top
          definitions; build those of every party instance; finally build this
          instance's interactive layer definition.
          """
          self._set_optimizer(optimizer)
          self._set_loss(loss)
          self._compile_common_network_config()
          self._compile_role_network_config()
          self._compile_interactive_layer()

      def _compile_interactive_layer(self):
          """Serialise a non-empty interactive layer into ``interactive_layer_define``."""
          if hasattr(self, "_interactive_layer") and not self._interactive_layer.is_empty():
              self.interactive_layer_define = self._interactive_layer.get_network_config()
              self._component_param["interactive_layer_define"] = self.interactive_layer_define

      def _compile_common_network_config(self):
          """Serialise non-empty bottom/top models into ``bottom_nn_define``/``top_nn_define``."""
          if hasattr(self, "_bottom_nn_model") and not self._bottom_nn_model.is_empty():
              self.bottom_nn_define = self._bottom_nn_model.get_network_config()
              self._component_param["bottom_nn_define"] = self.bottom_nn_define
          if hasattr(self, "_top_nn_model") and not self._top_nn_model.is_empty():
              self.top_nn_define = self._top_nn_model.get_network_config()
              self._component_param["top_nn_define"] = self.top_nn_define

      def _compile_role_network_config(self):
          """Compile the network and interactive-layer definitions of every party instance."""
          for role_info in self._get_all_party_instance().values():
              for party_inst in role_info["party"].values():
                  # Party slots may hold None; skip them, as the model getters below do.
                  if party_inst is None:
                      continue
                  party_inst._compile_common_network_config()
                  party_inst._compile_interactive_layer()

      def _collect_party_models(self, get_model):
          """Apply ``get_model`` to every non-None party instance and gather non-None results.

          Returns a dict keyed by party key, or None when nothing was found. Keys do
          not include the role, so equal party keys under different roles collapse
          (the later role wins).
          """
          models = {}
          for role_info in self._get_all_party_instance().values():
              for party, party_inst in role_info["party"].items():
                  if party_inst is None:
                      continue
                  model = get_model(party_inst)
                  if model is not None:
                      models[party] = model
          return models or None

      def get_bottom_model(self):
          """Return this instance's bottom model, or else the party instances' bottom models.

          If this instance holds a non-empty bottom model, its ``get_model()`` is
          returned. Otherwise returns a dict {party key: model} built from the party
          instances, or None when none of them has a bottom model.
          """
          if hasattr(self, "_bottom_nn_model") and not self._bottom_nn_model.is_empty():
              return self._bottom_nn_model.get_model()
          return self._collect_party_models(lambda inst: inst.get_bottom_model())

      def get_top_model(self):
          """Return this instance's top model, or else the party instances' top models.

          If this instance holds a non-empty top model, its ``get_model()`` is
          returned. Otherwise returns a dict {party key: model} built from the party
          instances, or None when none of them has a top model.
          """
          if hasattr(self, "_top_nn_model") and not self._top_nn_model.is_empty():
              return self._top_nn_model.get_model()
          return self._collect_party_models(lambda inst: inst.get_top_model())

      def __getstate__(self):
          """Return a shallow copy of ``__dict__`` without the three Sequential model holders.

          Instances restored from this state (pickle/copy) start without the holders.
          The add_*/set_* methods recreate them on demand. NOTE(review): presumably
          this keeps role/party copies from inheriting the common models - confirm.
          """
          state = dict(self.__dict__)
          for holder in ("_bottom_nn_model", "_interactive_layer", "_top_nn_model"):
              state.pop(holder, None)
          return state