#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections.abc
import copy
from types import SimpleNamespace

from pipeline.param import consts
from pipeline.param.base_param import BaseParam
from pipeline.param.callback_param import CallbackParam
from pipeline.param.encrypt_param import EncryptParam
from pipeline.param.encrypted_mode_calculation_param import EncryptedModeCalculatorParam
from pipeline.param.intersect_param import IntersectParam
from pipeline.param.predict_param import PredictParam


class FTLParam(BaseParam):

    def __init__(self, alpha=1, tol=0.000001,
                 n_iter_no_change=False, validation_freqs=None,
                 optimizer={'optimizer': 'Adam', 'learning_rate': 0.01},
                 nn_define={}, epochs=1, intersect_param=IntersectParam(consts.RSA),
                 config_type='keras', batch_size=-1,
                 encrypte_param=EncryptParam(),
                 encrypted_mode_calculator_param=EncryptedModeCalculatorParam(mode="confusion_opt"),
                 predict_param=PredictParam(), mode='plain', communication_efficient=False,
                 local_round=5, callback_param=CallbackParam()):
  36. """
  37. Args:
  38. alpha: float, a loss coefficient defined in paper, it defines the importance of alignment loss
  39. tol: float, loss tolerance
  40. n_iter_no_change: bool, check loss convergence or not
  41. validation_freqs: None or positive integer or container object in python. Do validation in training process or Not.
  42. if equals None, will not do validation in train process;
  43. if equals positive integer, will validate data every validation_freqs epochs passes;
  44. if container object in python, will validate data if epochs belong to this container.
  45. e.g. validation_freqs = [10, 15], will validate data when epoch equals to 10 and 15.
  46. Default: None
  47. The default value is None, 1 is suggested. You can set it to a number larger than 1 in order to
  48. speed up training by skipping validation rounds. When it is larger than 1, a number which is
  49. divisible by "epochs" is recommended, otherwise, you will miss the validation scores
  50. of last training epoch.
  51. optimizer: optimizer method, accept following types:
  52. 1. a string, one of "Adadelta", "Adagrad", "Adam", "Adamax", "Nadam", "RMSprop", "SGD"
  53. 2. a dict, with a required key-value pair keyed by "optimizer",
  54. with optional key-value pairs such as learning rate.
  55. defaults to "SGD"
  56. nn_define: dict, a dict represents the structure of neural network, it can be output by tf-keras
  57. epochs: int, epochs num
  58. intersect_param: define the intersect method
  59. config_type: now only 'tf-keras' is supported
  60. batch_size: batch size when computing transformed feature embedding, -1 use full data.
  61. encrypte_param: encrypted param
  62. encrypted_mode_calculator_param:
  63. predict_param: predict param
  64. mode:
  65. plain: will not use any encrypt algorithms, data exchanged in plaintext
  66. encrypted: use paillier to encrypt gradients
  67. communication_efficient:
  68. bool, will use communication efficient or not. when communication efficient is enabled, FTL model will
  69. update gradients by several local rounds using intermediate data
  70. local_round: local update round when using communication efficient
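
        Example:
            A minimal illustrative construction (the values here are arbitrary,
            not recommended settings):

                param = FTLParam(epochs=10, validation_freqs=2,
                                 optimizer={'optimizer': 'SGD', 'learning_rate': 0.05})
                param.check()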
  71. """
  72. super(FTLParam, self).__init__()
  73. self.alpha = alpha
  74. self.tol = tol
  75. self.n_iter_no_change = n_iter_no_change
  76. self.validation_freqs = validation_freqs
  77. self.optimizer = optimizer
  78. self.nn_define = nn_define
  79. self.epochs = epochs
  80. self.intersect_param = copy.deepcopy(intersect_param)
  81. self.config_type = config_type
  82. self.batch_size = batch_size
  83. self.encrypted_mode_calculator_param = copy.deepcopy(encrypted_mode_calculator_param)
  84. self.encrypt_param = copy.deepcopy(encrypte_param)
  85. self.predict_param = copy.deepcopy(predict_param)
  86. self.mode = mode
  87. self.communication_efficient = communication_efficient
  88. self.local_round = local_round
  89. self.callback_param = copy.deepcopy(callback_param)

    def check(self):
        self.intersect_param.check()
        self.encrypt_param.check()
        self.encrypted_mode_calculator_param.check()
        self.optimizer = self._parse_optimizer(self.optimizer)

        supported_config_type = ["keras"]
        if self.config_type not in supported_config_type:
            raise ValueError(f"config_type should be one of {supported_config_type}")

        if not isinstance(self.tol, (int, float)):
            raise ValueError("tol should be numeric")

        if not isinstance(self.epochs, int) or self.epochs <= 0:
            raise ValueError("epochs should be a positive integer")

        if self.nn_define and not isinstance(self.nn_define, dict):
            raise ValueError("nn_define should be a dict defining the structure of the neural network")

        if self.batch_size != -1:
            if not isinstance(self.batch_size, int) \
                    or self.batch_size < consts.MIN_BATCH_SIZE:
                raise ValueError(
                    "batch_size {} not supported, should be -1 (use full data) or an integer "
                    "no smaller than {}".format(self.batch_size, consts.MIN_BATCH_SIZE))

        if self.validation_freqs is None:
            pass
        elif isinstance(self.validation_freqs, int):
            if self.validation_freqs < 1:
                raise ValueError("validation_freqs should be larger than 0 when it is an integer")
        elif not isinstance(self.validation_freqs, collections.abc.Container):
            raise ValueError("validation_freqs should be None, a positive integer, or a container")

        assert isinstance(self.communication_efficient, bool), 'communication_efficient must be a boolean'
        assert self.mode in ['encrypted', 'plain'], \
            'mode options: encrypted or plain, but {} is offered'.format(self.mode)

        self.check_positive_integer(self.epochs, 'epochs')
        self.check_positive_number(self.alpha, 'alpha')
        self.check_positive_integer(self.local_round, 'local_round')

    @staticmethod
    def _parse_optimizer(opt):
        """
        Normalize an optimizer config into a SimpleNamespace holding the
        optimizer name and a kwargs dict.

        Examples of accepted values:
            1. "SGD"
            2. {"optimizer": "SGD", "learning_rate": 0.05}
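
        An illustrative doctest; it uses only names defined in this module:

        >>> parsed = FTLParam._parse_optimizer({"optimizer": "SGD", "learning_rate": 0.05})
        >>> parsed.optimizer
        'SGD'
        >>> parsed.kwargs
        {'learning_rate': 0.05}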
  132. """
  133. kwargs = {}
  134. if isinstance(opt, str):
  135. return SimpleNamespace(optimizer=opt, kwargs=kwargs)
  136. elif isinstance(opt, dict):
  137. optimizer = opt.get("optimizer", kwargs)
  138. if not optimizer:
  139. raise ValueError(f"optimizer config: {opt} invalid")
  140. kwargs = {k: v for k, v in opt.items() if k != "optimizer"}
  141. return SimpleNamespace(optimizer=optimizer, kwargs=kwargs)
  142. else:
  143. raise ValueError(f"invalid type for optimize: {type(opt)}")
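

# Illustrative usage sketch, assuming the default sub-parameter objects
# (IntersectParam, EncryptParam, etc.) validate cleanly; the values below are
# arbitrary, not recommended settings.
if __name__ == "__main__":
    ftl_param = FTLParam(
        alpha=0.5,                   # weight of the alignment loss
        epochs=10,
        batch_size=-1,               # -1 means use the full dataset
        optimizer={"optimizer": "SGD", "learning_rate": 0.05},
        mode="plain",                # exchange data in plaintext
        communication_efficient=True,
        local_round=3,
    )
    ftl_param.check()                # raises ValueError on invalid settings
    print(ftl_param.optimizer)       # a SimpleNamespace after parsing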