#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy

import torch as t

from pipeline.component.component_base import FateComponent
from pipeline.component.nn.backend.torch.base import Sequential
from pipeline.component.nn.backend.torch import base
from pipeline.interface import Input
from pipeline.interface import Output
from pipeline.utils.tools import extract_explicit_parameter
from pipeline.component.nn.interface import TrainerParam, DatasetParam
from pipeline.component.nn.backend.torch.cust import CustModel
from pipeline.utils.logger import LOGGER

# default parameter dict
DEFAULT_PARAM_DICT = {
    'trainer': TrainerParam(trainer_name='fedavg_trainer'),
    'dataset': DatasetParam(dataset_name='table'),
    'torch_seed': 100,
    'loss': None,
    'optimizer': None,
    'nn_define': None
}
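
# For example, a job that only tunes the aggregation schedule can override
# 'trainer' and inherit the remaining defaults (an illustrative call; the
# parameter names follow TrainerParam as used in __init__ below):
#
#   homo_nn_0 = HomoNN(name='homo_nn_0',
#                      trainer=TrainerParam(trainer_name='fedavg_trainer',
#                                           epochs=30, batch_size=256))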


class HomoNN(FateComponent):
    """
    Parameters
    ----------
    name : str
        name of this component
    trainer : TrainerParam
        trainer param
    dataset : DatasetParam
        dataset param
    torch_seed : int
        global random seed
    loss : FateTorchLoss
        loss function from fate_torch
    optimizer : FateTorchOptimizer
        optimizer from fate_torch
    model : Sequential
        a fate-torch Sequential defining the model structure
    """

    @extract_explicit_parameter
    def __init__(self,
                 name=None,
                 trainer: TrainerParam = TrainerParam(trainer_name='fedavg_trainer', epochs=10, batch_size=512,  # training parameters
                                                      early_stop=None, tol=0.0001,  # early-stop parameters
                                                      secure_aggregate=True, weighted_aggregation=True,
                                                      aggregate_every_n_epoch=None,  # federation parameters
                                                      cuda=False, pin_memory=True, shuffle=True, data_loader_worker=0,  # GPU / DataLoader parameters
                                                      validation_freqs=None),
                 dataset: DatasetParam = DatasetParam(dataset_name='table'),
                 torch_seed: int = 100,
                 loss=None,
                 optimizer: t.optim.Optimizer = None,
                 model: Sequential = None, **kwargs):

        explicit_parameters = copy.deepcopy(DEFAULT_PARAM_DICT)
        if 'name' not in kwargs["explict_parameters"]:
            raise RuntimeError('module name is not set')
        explicit_parameters["name"] = kwargs["explict_parameters"]['name']
        FateComponent.__init__(self, **explicit_parameters)
        kwargs["explict_parameters"].pop('name')

        self.input = Input(self.name, data_type="multi")
        self.output = Output(self.name, data_type='single')
        self._module_name = "HomoNN"
        self._updated = {'trainer': False, 'dataset': False,
                         'torch_seed': False, 'loss': False, 'optimizer': False, 'model': False}
        self._set_param(kwargs["explict_parameters"])
        self._check_parameters()

    def _set_updated(self, attr, status=True):
        if attr in self._updated:
            self._updated[attr] = status
        else:
            raise ValueError('attr {} is not in the update status dict {}'.format(attr, self._updated))

    def _set_param(self, params):
        if "name" in params:
            del params["name"]
        for param_key, param_value in params.items():
            setattr(self, param_key, param_value)

    def _check_parameters(self):

        if hasattr(self, 'trainer') and self.trainer is not None and not self._updated['trainer']:
            assert isinstance(self.trainer, TrainerParam), 'trainer must be a TrainerParam instance'
            self.trainer.check()
            self.trainer = self.trainer.to_dict()  # serialize for job submission
            self._set_updated('trainer', True)

        if hasattr(self, 'dataset') and self.dataset is not None and not self._updated['dataset']:
            assert isinstance(self.dataset, DatasetParam), 'dataset must be a DatasetParam instance'
            self.dataset.check()
            self.dataset = self.dataset.to_dict()  # serialize for job submission
            self._set_updated('dataset', True)

        if hasattr(self, 'model') and self.model is not None and not self._updated['model']:
            if isinstance(self.model, Sequential):
                self.nn_define = self.model.get_network_config()
            elif isinstance(self.model, CustModel):
                # wrap a bare CustModel so that get_network_config() is available
                self.model = Sequential(self.model)
                self.nn_define = self.model.get_network_config()
            else:
                raise RuntimeError('Model must be a fate-torch Sequential, but got {}.'
                                   '\nDo remember to call fate_torch_hook():'
                                   '\n    import torch as t'
                                   '\n    fate_torch_hook(t)'.format(type(self.model)))
            self._set_updated('model', True)

        if hasattr(self, 'optimizer') and self.optimizer is not None and not self._updated['optimizer']:
            if not isinstance(self.optimizer, base.FateTorchOptimizer):
                raise ValueError('please pass a FateTorchOptimizer instance to the Homo-NN component, got {}. '
                                 'Do remember to use fate_torch_hook():\n'
                                 '    import torch as t\n'
                                 '    fate_torch_hook(t)'.format(type(self.optimizer)))
            self.optimizer = self.optimizer.to_dict()  # serialize for job submission
            self._set_updated('optimizer', True)

        if hasattr(self, 'loss') and self.loss is not None and not self._updated['loss']:
            if isinstance(self.loss, base.FateTorchLoss):
                loss_config = self.loss.to_dict()
            elif isinstance(self.loss, type) and issubclass(self.loss, base.FateTorchLoss):
                # a bare loss class: instantiate with default arguments, then serialize
                loss_config = self.loss().to_dict()
            else:
                raise ValueError('unable to parse loss {}: loss must be an instance of a '
                                 'FateTorchLoss subclass or a FateTorchLoss subclass itself, '
                                 'do remember to use fate_torch_hook()'.format(self.loss))
            self.loss = loss_config
            self._set_updated('loss', True)
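
    # Both of the following forms pass the loss check above once fate_torch_hook(t)
    # has patched torch (illustrative; the patched t.nn.CrossEntropyLoss is then a
    # FateTorchLoss subclass):
    #   HomoNN(..., loss=t.nn.CrossEntropyLoss())  # an instance
    #   HomoNN(..., loss=t.nn.CrossEntropyLoss)    # the class itself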

    def component_param(self, **kwargs):
        # reset parameters
        used_attr = set()
        setattr(self, 'model', None)
        if 'model' in kwargs:
            self.model = kwargs['model']
            kwargs.pop('model')
            self._set_updated('model', False)
        for attr in self._component_parameter_keywords:
            if attr in kwargs:
                setattr(self, attr, kwargs[attr])
                self._set_updated(attr, False)
                used_attr.add(attr)
        self._check_parameters()  # check and convert homo-nn parameters
        unused_attr = set(kwargs.keys()).difference(used_attr)
        for attr in unused_attr:
            LOGGER.warning(f"key {attr}, value {kwargs[attr]} is not used")
        self._role_parameter_keywords |= used_attr
        for attr in self.__dict__:
            if attr not in self._component_parameter_keywords:
                continue
            self._component_param[attr] = getattr(self, attr)
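
    # In a pipeline script, component_param is typically called on a
    # role-specific instance (an illustrative call using the standard
    # get_party_instance API; the party_id is a placeholder):
    #   homo_nn_0.get_party_instance(role='host', party_id=10000).component_param(
    #       trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=5))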

    def __getstate__(self):
        state = dict(self.__dict__)
        # drop the live torch model when pickling; its structure is already
        # captured in nn_define
        if "model" in state:
            del state["model"]
        return state
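

# Minimal usage sketch (assumptions: a deployed FATE environment where
# fate_torch_hook is importable from `pipeline`; names such as 'homo_nn_0'
# and the layer sizes are illustrative only):
if __name__ == '__main__':
    from pipeline import fate_torch_hook

    fate_torch_hook(t)  # patch torch so models/losses/optimizers gain to_dict()

    # a tiny binary classifier built with the patched torch
    model = t.nn.Sequential(
        t.nn.Linear(30, 1),
        t.nn.Sigmoid()
    )
    homo_nn_0 = HomoNN(
        name='homo_nn_0',
        model=model,
        loss=t.nn.BCELoss(),
        optimizer=t.optim.Adam(model.parameters(), lr=0.01),
        trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=20, batch_size=128),
        dataset=DatasetParam(dataset_name='table')
    )
    # _check_parameters() ran in __init__, so the serialized network
    # definition is available:
    print(homo_nn_0.nn_define)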