# hetero_ftl.py
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
  16. from pipeline.component.component_base import FateComponent
  17. from pipeline.component.nn.models.sequantial import Sequential
  18. from pipeline.interface import Input
  19. from pipeline.interface import Output
  20. from pipeline.utils.tools import extract_explicit_parameter
  21. from pipeline.param import consts
  22. class HeteroFTL(FateComponent):
  23. @extract_explicit_parameter
  24. def __init__(self, epochs=1, batch_size=-1,
  25. encrypt_param=None, predict_param=None, cv_param=None,
  26. intersect_param={'intersect_method': consts.RSA},
  27. validation_freqs=None, early_stopping_rounds=None, use_first_metric_only=None,
  28. mode='plain', communication_efficient=False, n_iter_no_change=False, tol=1e-5,
  29. local_round=5,
  30. **kwargs):
  31. explicit_parameters = kwargs["explict_parameters"]
  32. explicit_parameters["optimizer"] = None
  33. # explicit_parameters["loss"] = None
  34. # explicit_parameters["metrics"] = None
  35. explicit_parameters["nn_define"] = None
  36. explicit_parameters["config_type"] = "keras"
  37. FateComponent.__init__(self, **explicit_parameters)
  38. if "name" in explicit_parameters:
  39. del explicit_parameters["name"]
  40. for param_key, param_value in explicit_parameters.items():
  41. setattr(self, param_key, param_value)
  42. self.input = Input(self.name, data_type="multi")
  43. self.output = Output(self.name, data_type='single')
  44. self._module_name = "FTL"
  45. self.optimizer = None
  46. self.loss = None
  47. self.config_type = "keras"
  48. self.metrics = None
  49. self.bottom_nn_define = None
  50. self.top_nn_define = None
  51. self.interactive_layer_define = None
  52. self._nn_model = Sequential()
  53. self.nn_define = None
  54. def add_nn_layer(self, layer):
  55. self._nn_model.add(layer)
  56. def compile(self, optimizer,):
  57. self.optimizer = self._nn_model.get_optimizer_config(optimizer)
  58. self.config_type = self._nn_model.get_layer_type()
  59. self.nn_define = self._nn_model.get_network_config()
  60. def __getstate__(self):
  61. state = dict(self.__dict__)
  62. del state["_nn_model"]
  63. return state