hetero_nn_param.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. #
  4. # Copyright 2019 The FATE Authors. All Rights Reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. #
  18. import copy
  19. import collections
  20. from types import SimpleNamespace
  21. from pipeline.param.base_param import BaseParam
  22. from pipeline.param.callback_param import CallbackParam
  23. from pipeline.param.cross_validation_param import CrossValidationParam
  24. from pipeline.param.encrypt_param import EncryptParam
  25. from pipeline.param.encrypted_mode_calculation_param import EncryptedModeCalculatorParam
  26. from pipeline.param.predict_param import PredictParam
  27. from pipeline.param import consts
  28. class DatasetParam(BaseParam):
  29. def __init__(self, dataset_name=None, **kwargs):
  30. super(DatasetParam, self).__init__()
  31. self.dataset_name = dataset_name
  32. self.param = kwargs
  33. def check(self):
  34. if self.dataset_name is not None:
  35. self.check_string(self.dataset_name, 'dataset_name')
  36. def to_dict(self):
  37. ret = {'dataset_name': self.dataset_name, 'param': self.param}
  38. return ret
  39. class SelectorParam(object):
  40. """
  41. Parameters
  42. ----------
  43. method: None or str
  44. back propagation select method, accept "relative" only, default: None
  45. selective_size: int
  46. deque size to use, store the most recent selective_size historical loss, default: 1024
  47. beta: int
  48. sample whose selective probability >= power(np.random, beta) will be selected
  49. min_prob: Numeric
  50. selective probability is max(min_prob, rank_rate)
  51. """
  52. def __init__(self, method=None, beta=1, selective_size=consts.SELECTIVE_SIZE, min_prob=0, random_state=None):
  53. self.method = method
  54. self.selective_size = selective_size
  55. self.beta = beta
  56. self.min_prob = min_prob
  57. self.random_state = random_state
  58. def check(self):
  59. if self.method is not None and self.method not in ["relative"]:
  60. raise ValueError('selective method should be None be "relative"')
  61. if not isinstance(self.selective_size, int) or self.selective_size <= 0:
  62. raise ValueError("selective size should be a positive integer")
  63. if not isinstance(self.beta, int):
  64. raise ValueError("beta should be integer")
  65. if not isinstance(self.min_prob, (float, int)):
  66. raise ValueError("min_prob should be numeric")
  67. class CoAEConfuserParam(BaseParam):
  68. """
  69. A label protect mechanism proposed in paper: "Batch Label Inference and Replacement Attacks in Black-Boxed Vertical Federated Learning"
  70. paper link: https://arxiv.org/abs/2112.05409
  71. Convert true labels to fake soft labels by using an auto-encoder.
  72. Args:
  73. enable: boolean
  74. run CoAE or not
  75. epoch: None or int
  76. auto-encoder training epochs
  77. lr: float
  78. auto-encoder learning rate
  79. lambda1: float
  80. parameter to control the difference between true labels and fake soft labels. Larger the parameter,
  81. autoencoder will give more attention to making true labels and fake soft label different.
  82. lambda2: float
  83. parameter to control entropy loss, see original paper for details
  84. verbose: boolean
  85. print loss log while training auto encoder
  86. """
  87. def __init__(self, enable=False, epoch=50, lr=0.001, lambda1=1.0, lambda2=2.0, verbose=False):
  88. super(CoAEConfuserParam, self).__init__()
  89. self.enable = enable
  90. self.epoch = epoch
  91. self.lr = lr
  92. self.lambda1 = lambda1
  93. self.lambda2 = lambda2
  94. self.verbose = verbose
  95. def check(self):
  96. self.check_boolean(self.enable, 'enable')
  97. if not isinstance(self.epoch, int) or self.epoch <= 0:
  98. raise ValueError("epoch should be a positive integer")
  99. if not isinstance(self.lr, float):
  100. raise ValueError('lr should be a float number')
  101. if not isinstance(self.lambda1, float):
  102. raise ValueError('lambda1 should be a float number')
  103. if not isinstance(self.lambda2, float):
  104. raise ValueError('lambda2 should be a float number')
  105. self.check_boolean(self.verbose, 'verbose')
  106. class HeteroNNParam(BaseParam):
  107. """
  108. Parameters used for Hetero Neural Network.
  109. Parameters
  110. ----------
  111. task_type: str, task type of hetero nn model, one of 'classification', 'regression'.
  112. bottom_nn_define: a dict represents the structure of bottom neural network.
  113. interactive_layer_define: a dict represents the structure of interactive layer.
  114. interactive_layer_lr: float, the learning rate of interactive layer.
  115. top_nn_define: a dict represents the structure of top neural network.
  116. optimizer: optimizer method, accept following types:
  117. 1. a string, one of "Adadelta", "Adagrad", "Adam", "Adamax", "Nadam", "RMSprop", "SGD"
  118. 2. a dict, with a required key-value pair keyed by "optimizer",
  119. with optional key-value pairs such as learning rate.
  120. defaults to "SGD".
  121. loss: str, a string to define loss function used
  122. epochs: int, the maximum iteration for aggregation in training.
  123. batch_size : int, batch size when updating model.
  124. -1 means use all data in a batch. i.e. Not to use mini-batch strategy.
  125. defaults to -1.
  126. early_stop : str, accept 'diff' only in this version, default: 'diff'
  127. Method used to judge converge or not.
  128. a) diff: Use difference of loss between two iterations to judge whether converge.
  129. floating_point_precision: None or integer, if not None, means use floating_point_precision-bit to speed up calculation,
  130. e.g.: convert an x to round(x * 2**floating_point_precision) during Paillier operation, divide
  131. the result by 2**floating_point_precision in the end.
  132. callback_param: CallbackParam object
  133. """
  134. def __init__(self,
  135. task_type='classification',
  136. bottom_nn_define=None,
  137. top_nn_define=None,
  138. config_type='pytorch',
  139. interactive_layer_define=None,
  140. interactive_layer_lr=0.9,
  141. optimizer='SGD',
  142. loss=None,
  143. epochs=100,
  144. batch_size=-1,
  145. early_stop="diff",
  146. tol=1e-5,
  147. encrypt_param=EncryptParam(),
  148. encrypted_mode_calculator_param=EncryptedModeCalculatorParam(),
  149. predict_param=PredictParam(),
  150. cv_param=CrossValidationParam(),
  151. validation_freqs=None,
  152. early_stopping_rounds=None,
  153. metrics=None,
  154. use_first_metric_only=True,
  155. selector_param=SelectorParam(),
  156. floating_point_precision=23,
  157. callback_param=CallbackParam(),
  158. coae_param=CoAEConfuserParam(),
  159. dataset=DatasetParam()
  160. ):
  161. super(HeteroNNParam, self).__init__()
  162. self.task_type = task_type
  163. self.bottom_nn_define = bottom_nn_define
  164. self.interactive_layer_define = interactive_layer_define
  165. self.interactive_layer_lr = interactive_layer_lr
  166. self.top_nn_define = top_nn_define
  167. self.batch_size = batch_size
  168. self.epochs = epochs
  169. self.early_stop = early_stop
  170. self.tol = tol
  171. self.optimizer = optimizer
  172. self.loss = loss
  173. self.validation_freqs = validation_freqs
  174. self.early_stopping_rounds = early_stopping_rounds
  175. self.metrics = metrics or []
  176. self.use_first_metric_only = use_first_metric_only
  177. self.encrypt_param = copy.deepcopy(encrypt_param)
  178. self.encrypted_model_calculator_param = encrypted_mode_calculator_param
  179. self.predict_param = copy.deepcopy(predict_param)
  180. self.cv_param = copy.deepcopy(cv_param)
  181. self.selector_param = selector_param
  182. self.floating_point_precision = floating_point_precision
  183. self.callback_param = copy.deepcopy(callback_param)
  184. self.coae_param = coae_param
  185. self.dataset = dataset
  186. self.config_type = 'pytorch' # pytorch only
  187. def check(self):
  188. assert isinstance(self.dataset, DatasetParam), 'dataset must be a DatasetParam()'
  189. self.dataset.check()
  190. if self.task_type not in ["classification", "regression"]:
  191. raise ValueError("config_type should be classification or regression")
  192. if not isinstance(self.tol, (int, float)):
  193. raise ValueError("tol should be numeric")
  194. if not isinstance(self.epochs, int) or self.epochs <= 0:
  195. raise ValueError("epochs should be a positive integer")
  196. if self.bottom_nn_define and not isinstance(self.bottom_nn_define, dict):
  197. raise ValueError("bottom_nn_define should be a dict defining the structure of neural network")
  198. if self.top_nn_define and not isinstance(self.top_nn_define, dict):
  199. raise ValueError("top_nn_define should be a dict defining the structure of neural network")
  200. if self.interactive_layer_define is not None and not isinstance(self.interactive_layer_define, dict):
  201. raise ValueError(
  202. "the interactive_layer_define should be a dict defining the structure of interactive layer")
  203. if self.batch_size != -1:
  204. if not isinstance(self.batch_size, int) \
  205. or self.batch_size < consts.MIN_BATCH_SIZE:
  206. raise ValueError(
  207. " {} not supported, should be larger than 10 or -1 represent for all data".format(self.batch_size))
  208. if self.early_stop != "diff":
  209. raise ValueError("early stop should be diff in this version")
  210. if self.metrics is not None and not isinstance(self.metrics, list):
  211. raise ValueError("metrics should be a list")
  212. if self.floating_point_precision is not None and \
  213. (not isinstance(self.floating_point_precision, int) or
  214. self.floating_point_precision < 0 or self.floating_point_precision > 63):
  215. raise ValueError("floating point precision should be null or a integer between 0 and 63")
  216. self.encrypt_param.check()
  217. self.encrypted_model_calculator_param.check()
  218. self.predict_param.check()
  219. self.selector_param.check()
  220. self.coae_param.check()
  221. descr = "hetero nn param's "
  222. for p in ["early_stopping_rounds", "validation_freqs",
  223. "use_first_metric_only"]:
  224. if self._deprecated_params_set.get(p):
  225. if "callback_param" in self.get_user_feeded():
  226. raise ValueError(f"{p} and callback param should not be set simultaneously,"
  227. f"{self._deprecated_params_set}, {self.get_user_feeded()}")
  228. else:
  229. self.callback_param.callbacks = ["PerformanceEvaluate"]
  230. break
  231. if self._warn_to_deprecate_param("validation_freqs", descr, "callback_param's 'validation_freqs'"):
  232. self.callback_param.validation_freqs = self.validation_freqs
  233. if self._warn_to_deprecate_param("early_stopping_rounds", descr, "callback_param's 'early_stopping_rounds'"):
  234. self.callback_param.early_stopping_rounds = self.early_stopping_rounds
  235. if self._warn_to_deprecate_param("metrics", descr, "callback_param's 'metrics'"):
  236. if self.metrics:
  237. self.callback_param.metrics = self.metrics
  238. if self._warn_to_deprecate_param("use_first_metric_only", descr, "callback_param's 'use_first_metric_only'"):
  239. self.callback_param.use_first_metric_only = self.use_first_metric_only