123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- #
- # Copyright 2019 The FATE Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- import collections
- import copy
- from federatedml.param.intersect_param import IntersectParam
- from types import SimpleNamespace
- from federatedml.param.base_param import BaseParam, deprecated_param
- from federatedml.util import consts
- from federatedml.param.encrypt_param import EncryptParam
- from federatedml.param.encrypted_mode_calculation_param import EncryptedModeCalculatorParam
- from federatedml.param.predict_param import PredictParam
- from federatedml.param.callback_param import CallbackParam
- deprecated_param_list = ["validation_freqs", "metrics"]
- @deprecated_param(*deprecated_param_list)
- class FTLParam(BaseParam):
- def __init__(self, alpha=1, tol=0.000001,
- n_iter_no_change=False, validation_freqs=None, optimizer={'optimizer': 'Adam', 'learning_rate': 0.01},
- nn_define={}, epochs=1, intersect_param=IntersectParam(consts.RSA), config_type='keras', batch_size=-1,
- encrypte_param=EncryptParam(),
- encrypted_mode_calculator_param=EncryptedModeCalculatorParam(mode="confusion_opt"),
- predict_param=PredictParam(), mode='plain', communication_efficient=False,
- local_round=5, callback_param=CallbackParam()):
- """
- Parameters
- ----------
- alpha : float
- a loss coefficient defined in paper, it defines the importance of alignment loss
- tol : float
- loss tolerance
- n_iter_no_change : bool
- check loss convergence or not
- validation_freqs : None or positive integer or container object in python
- Do validation in training process or Not.
- if equals None, will not do validation in train process;
- if equals positive integer, will validate data every validation_freqs epochs passes;
- if container object in python, will validate data if epochs belong to this container.
- e.g. validation_freqs = [10, 15], will validate data when epoch equals to 10 and 15.
- The default value is None, 1 is suggested. You can set it to a number larger than 1 in order to
- speed up training by skipping validation rounds. When it is larger than 1, a number which is
- divisible by "epochs" is recommended, otherwise, you will miss the validation scores
- of last training epoch.
- optimizer : str or dict
- optimizer method, accept following types:
- 1. a string, one of "Adadelta", "Adagrad", "Adam", "Adamax", "Nadam", "RMSprop", "SGD"
- 2. a dict, with a required key-value pair keyed by "optimizer",
- with optional key-value pairs such as learning rate.
- defaults to "SGD"
- nn_define : dict
- a dict represents the structure of neural network, it can be output by tf-keras
- epochs : int
- epochs num
- intersect_param
- define the intersect method
- config_type : {'tf-keras'}
- config type
- batch_size : int
- batch size when computing transformed feature embedding, -1 use full data.
- encrypte_param
- encrypted param
- encrypted_mode_calculator_param
- encrypted mode calculator param:
- predict_param
- predict param
- mode: {"plain", "encrypted"}
- plain: will not use any encrypt algorithms, data exchanged in plaintext
- encrypted: use paillier to encrypt gradients
- communication_efficient: bool
- will use communication efficient or not. when communication efficient is enabled, FTL model will
- update gradients by several local rounds using intermediate data
- local_round: int
- local update round when using communication efficient
- """
- super(FTLParam, self).__init__()
- self.alpha = alpha
- self.tol = tol
- self.n_iter_no_change = n_iter_no_change
- self.validation_freqs = validation_freqs
- self.optimizer = optimizer
- self.nn_define = nn_define
- self.epochs = epochs
- self.intersect_param = copy.deepcopy(intersect_param)
- self.config_type = config_type
- self.batch_size = batch_size
- self.encrypted_mode_calculator_param = copy.deepcopy(encrypted_mode_calculator_param)
- self.encrypt_param = copy.deepcopy(encrypte_param)
- self.predict_param = copy.deepcopy(predict_param)
- self.mode = mode
- self.communication_efficient = communication_efficient
- self.local_round = local_round
- self.callback_param = copy.deepcopy(callback_param)
- def check(self):
- self.intersect_param.check()
- self.encrypt_param.check()
- self.encrypted_mode_calculator_param.check()
- self.optimizer = self._parse_optimizer(self.optimizer)
- supported_config_type = ["keras"]
- if self.config_type not in supported_config_type:
- raise ValueError(f"config_type should be one of {supported_config_type}")
- if not isinstance(self.tol, (int, float)):
- raise ValueError("tol should be numeric")
- if not isinstance(self.epochs, int) or self.epochs <= 0:
- raise ValueError("epochs should be a positive integer")
- if self.nn_define and not isinstance(self.nn_define, dict):
- raise ValueError("bottom_nn_define should be a dict defining the structure of neural network")
- if self.batch_size != -1:
- if not isinstance(self.batch_size, int) \
- or self.batch_size < consts.MIN_BATCH_SIZE:
- raise ValueError(
- " {} not supported, should be larger than 10 or -1 represent for all data".format(self.batch_size))
- for p in deprecated_param_list:
- # if self._warn_to_deprecate_param(p, "", ""):
- if self._deprecated_params_set.get(p):
- if "callback_param" in self.get_user_feeded():
- raise ValueError(f"{p} and callback param should not be set simultaneously,"
- f"{self._deprecated_params_set}, {self.get_user_feeded()}")
- else:
- self.callback_param.callbacks = ["PerformanceEvaluate"]
- break
- descr = "ftl's"
- if self._warn_to_deprecate_param("validation_freqs", descr, "callback_param's 'validation_freqs'"):
- self.callback_param.validation_freqs = self.validation_freqs
- if self._warn_to_deprecate_param("metrics", descr, "callback_param's 'metrics'"):
- self.callback_param.metrics = self.metrics
- if self.validation_freqs is None:
- pass
- elif isinstance(self.validation_freqs, int):
- if self.validation_freqs < 1:
- raise ValueError("validation_freqs should be larger than 0 when it's integer")
- elif not isinstance(self.validation_freqs, collections.Container):
- raise ValueError("validation_freqs should be None or positive integer or container")
- assert isinstance(self.communication_efficient, bool), 'communication efficient must be a boolean'
- assert self.mode in [
- 'encrypted', 'plain'], 'mode options: encrpyted or plain, but {} is offered'.format(
- self.mode)
- self.check_positive_integer(self.epochs, 'epochs')
- self.check_positive_number(self.alpha, 'alpha')
- self.check_positive_integer(self.local_round, 'local round')
- @staticmethod
- def _parse_optimizer(opt):
- """
- Examples:
- 1. "optimize": "SGD"
- 2. "optimize": {
- "optimizer": "SGD",
- "learning_rate": 0.05
- }
- """
- kwargs = {}
- if isinstance(opt, str):
- return SimpleNamespace(optimizer=opt, kwargs=kwargs)
- elif isinstance(opt, dict):
- optimizer = opt.get("optimizer", kwargs)
- if not optimizer:
- raise ValueError(f"optimizer config: {opt} invalid")
- kwargs = {k: v for k, v in opt.items() if k != "optimizer"}
- return SimpleNamespace(optimizer=optimizer, kwargs=kwargs)
- else:
- raise ValueError(f"invalid type for optimize: {type(opt)}")
|