
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections.abc
import copy
from types import SimpleNamespace

from federatedml.param.base_param import BaseParam, deprecated_param
from federatedml.param.callback_param import CallbackParam
from federatedml.param.encrypt_param import EncryptParam
from federatedml.param.encrypted_mode_calculation_param import EncryptedModeCalculatorParam
from federatedml.param.intersect_param import IntersectParam
from federatedml.param.predict_param import PredictParam
from federatedml.util import consts

deprecated_param_list = ["validation_freqs", "metrics"]


@deprecated_param(*deprecated_param_list)
class FTLParam(BaseParam):

    def __init__(self, alpha=1, tol=0.000001,
                 n_iter_no_change=False, validation_freqs=None,
                 optimizer={'optimizer': 'Adam', 'learning_rate': 0.01},
                 nn_define={}, epochs=1, intersect_param=IntersectParam(consts.RSA),
                 config_type='keras', batch_size=-1,
                 encrypt_param=EncryptParam(),
                 encrypted_mode_calculator_param=EncryptedModeCalculatorParam(mode="confusion_opt"),
                 predict_param=PredictParam(), mode='plain', communication_efficient=False,
                 local_round=5, callback_param=CallbackParam()):
  38. """
  39. Parameters
  40. ----------
  41. alpha : float
  42. a loss coefficient defined in paper, it defines the importance of alignment loss
  43. tol : float
  44. loss tolerance
  45. n_iter_no_change : bool
  46. check loss convergence or not
  47. validation_freqs : None or positive integer or container object in python
  48. Do validation in training process or Not.
  49. if equals None, will not do validation in train process;
  50. if equals positive integer, will validate data every validation_freqs epochs passes;
  51. if container object in python, will validate data if epochs belong to this container.
  52. e.g. validation_freqs = [10, 15], will validate data when epoch equals to 10 and 15.
  53. The default value is None, 1 is suggested. You can set it to a number larger than 1 in order to
  54. speed up training by skipping validation rounds. When it is larger than 1, a number which is
  55. divisible by "epochs" is recommended, otherwise, you will miss the validation scores
  56. of last training epoch.
  57. optimizer : str or dict
  58. optimizer method, accept following types:
  59. 1. a string, one of "Adadelta", "Adagrad", "Adam", "Adamax", "Nadam", "RMSprop", "SGD"
  60. 2. a dict, with a required key-value pair keyed by "optimizer",
  61. with optional key-value pairs such as learning rate.
  62. defaults to "SGD"
  63. nn_define : dict
  64. a dict represents the structure of neural network, it can be output by tf-keras
  65. epochs : int
  66. epochs num
  67. intersect_param
  68. define the intersect method
  69. config_type : {'tf-keras'}
  70. config type
  71. batch_size : int
  72. batch size when computing transformed feature embedding, -1 use full data.
  73. encrypte_param
  74. encrypted param
  75. encrypted_mode_calculator_param
  76. encrypted mode calculator param:
  77. predict_param
  78. predict param
  79. mode: {"plain", "encrypted"}
  80. plain: will not use any encrypt algorithms, data exchanged in plaintext
  81. encrypted: use paillier to encrypt gradients
  82. communication_efficient: bool
  83. will use communication efficient or not. when communication efficient is enabled, FTL model will
  84. update gradients by several local rounds using intermediate data
  85. local_round: int
  86. local update round when using communication efficient
  87. """
        super(FTLParam, self).__init__()
        self.alpha = alpha
        self.tol = tol
        self.n_iter_no_change = n_iter_no_change
        self.validation_freqs = validation_freqs
        # deep-copy the mutable defaults so instances do not share state
        self.optimizer = copy.deepcopy(optimizer)
        self.nn_define = copy.deepcopy(nn_define)
        self.epochs = epochs
        self.intersect_param = copy.deepcopy(intersect_param)
        self.config_type = config_type
        self.batch_size = batch_size
        self.encrypted_mode_calculator_param = copy.deepcopy(encrypted_mode_calculator_param)
        self.encrypt_param = copy.deepcopy(encrypt_param)
        self.predict_param = copy.deepcopy(predict_param)
        self.mode = mode
        self.communication_efficient = communication_efficient
        self.local_round = local_round
        self.callback_param = copy.deepcopy(callback_param)

    def check(self):
        self.intersect_param.check()
        self.encrypt_param.check()
        self.encrypted_mode_calculator_param.check()
        self.optimizer = self._parse_optimizer(self.optimizer)

        supported_config_type = ["keras"]
        if self.config_type not in supported_config_type:
            raise ValueError(f"config_type should be one of {supported_config_type}")

        if not isinstance(self.tol, (int, float)):
            raise ValueError("tol should be numeric")

        if not isinstance(self.epochs, int) or self.epochs <= 0:
            raise ValueError("epochs should be a positive integer")

        if self.nn_define and not isinstance(self.nn_define, dict):
            raise ValueError("nn_define should be a dict defining the structure of the neural network")

        if self.batch_size != -1:
            if not isinstance(self.batch_size, int) or self.batch_size < consts.MIN_BATCH_SIZE:
                raise ValueError(
                    "batch_size {} not supported, should be no less than {}, or -1 to use all data".format(
                        self.batch_size, consts.MIN_BATCH_SIZE))

        for p in deprecated_param_list:
            if self._deprecated_params_set.get(p):
                if "callback_param" in self.get_user_feeded():
                    raise ValueError(f"{p} and callback_param should not be set simultaneously, "
                                     f"{self._deprecated_params_set}, {self.get_user_feeded()}")
                else:
                    self.callback_param.callbacks = ["PerformanceEvaluate"]
                break

        descr = "ftl's"
        if self._warn_to_deprecate_param("validation_freqs", descr, "callback_param's 'validation_freqs'"):
            self.callback_param.validation_freqs = self.validation_freqs

        if self._warn_to_deprecate_param("metrics", descr, "callback_param's 'metrics'"):
            self.callback_param.metrics = self.metrics

        if self.validation_freqs is None:
            pass
        elif isinstance(self.validation_freqs, int):
            if self.validation_freqs < 1:
                raise ValueError("validation_freqs should be larger than 0 when it is an integer")
        elif not isinstance(self.validation_freqs, collections.abc.Container):
            raise ValueError("validation_freqs should be None, a positive integer, or a container")

        assert isinstance(self.communication_efficient, bool), 'communication_efficient must be a boolean'
        assert self.mode in ['encrypted', 'plain'], \
            'mode options: encrypted or plain, but {} was given'.format(self.mode)

        self.check_positive_integer(self.epochs, 'epochs')
        self.check_positive_number(self.alpha, 'alpha')
        self.check_positive_integer(self.local_round, 'local_round')

    @staticmethod
    def _parse_optimizer(opt):
        """
        Examples:

        1. "optimizer": "SGD"
        2. "optimizer": {
               "optimizer": "SGD",
               "learning_rate": 0.05
           }
        """
        kwargs = {}
        if isinstance(opt, str):
            return SimpleNamespace(optimizer=opt, kwargs=kwargs)
        elif isinstance(opt, dict):
            optimizer = opt.get("optimizer", None)
            if not optimizer:
                raise ValueError(f"optimizer config: {opt} invalid")
            # every key other than "optimizer" is passed through as a kwarg
            kwargs = {k: v for k, v in opt.items() if k != "optimizer"}
            return SimpleNamespace(optimizer=optimizer, kwargs=kwargs)
        else:
            raise ValueError(f"invalid type for optimizer: {type(opt)}")
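

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It
# assumes a working FATE/federatedml installation so the imported Param
# classes resolve; the parameter values below are arbitrary examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    param = FTLParam(
        alpha=0.5,
        epochs=10,
        batch_size=64,
        mode="encrypted",
        optimizer={"optimizer": "SGD", "learning_rate": 0.05},
    )
    # check() validates every sub-param and normalizes `optimizer` into a
    # SimpleNamespace; the parsing step is shown directly here:
    parsed = FTLParam._parse_optimizer(param.optimizer)
    print(parsed.optimizer)  # "SGD"
    print(parsed.kwargs)     # {"learning_rate": 0.05}
    param.check()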