# ftl_base.py

import copy
import json
import functools
import numpy as np
from federatedml.util import LOGGER
from federatedml.transfer_learning.hetero_ftl.backend.nn_model import get_nn_builder
from federatedml.model_base import ModelBase
from federatedml.param.ftl_param import FTLParam
from federatedml.transfer_learning.hetero_ftl.backend.tf_keras.nn_model import KerasNNModel
from federatedml.util.classify_label_checker import ClassifyLabelChecker
from federatedml.transfer_variable.transfer_class.ftl_transfer_variable import FTLTransferVariable
from federatedml.transfer_learning.hetero_ftl.ftl_dataloder import FTLDataLoader
from federatedml.transfer_learning.hetero_ftl.backend.tf_keras.data_generator import KerasSequenceDataConverter
from federatedml.nn.backend.utils import rng as random_number_generator
from federatedml.secureprotol import PaillierEncrypt
from federatedml.util import consts
from federatedml.secureprotol.paillier_tensor import PaillierTensor
from federatedml.protobuf.generated.ftl_model_param_pb2 import FTLModelParam
from federatedml.protobuf.generated.ftl_model_meta_pb2 import FTLModelMeta, FTLPredictParam, FTLOptimizerParam


class FTL(ModelBase):

    def __init__(self):
        super(FTL, self).__init__()

        # input parameters
        self.nn_define = None
        self.alpha = None
        self.tol = None
        self.learning_rate = None
        self.n_iter_no_change = None
        self.validation_freqs = None
        self.early_stopping_rounds = None
        self.use_first_metric_only = None
        self.optimizer = None
        self.intersect_param = None
        self.config_type = 'keras'
        self.comm_eff = None
        self.local_round = 1

        # runtime variables
        self.verbose = False
        self.nn: KerasNNModel = None
        self.nn_builder = None
        self.model_param = FTLParam()
        self.x_shape = None
        self.input_dim = None
        self.data_num = 0
        self.overlap_num = 0
        self.transfer_variable = FTLTransferVariable()
        self.data_convertor = KerasSequenceDataConverter()
        self.mode = 'plain'
        self.encrypter = None
        self.partitions = 16
        self.batch_size = None
        self.epochs = None
        self.store_header = None  # header of the input data table
        self.model_float_type = np.float32
        self.cache_dataloader = {}
        self.validation_strategy = None

    def _init_model(self, param: FTLParam):

        self.nn_define = param.nn_define
        self.alpha = param.alpha
        self.tol = param.tol
        self.n_iter_no_change = param.n_iter_no_change
        self.validation_freqs = param.validation_freqs
        self.optimizer = param.optimizer
        self.intersect_param = param.intersect_param
        self.batch_size = param.batch_size
        self.epochs = param.epochs
        self.mode = param.mode
        self.comm_eff = param.communication_efficient
        self.local_round = param.local_round

        assert 'learning_rate' in self.optimizer.kwargs, 'optimizer setting must contain learning_rate'
        self.learning_rate = self.optimizer.kwargs['learning_rate']

        if not self.comm_eff:
            self.local_round = 1
            LOGGER.debug('communication efficient mode is not enabled, local_round set as 1')

        self.encrypter = self.generate_encrypter(param)
        self.predict_param = param.predict_param
        self.rng_generator = random_number_generator.RandomNumberGenerator()

    @staticmethod
    def debug_data_inst(data_inst):
        collect_data = list(data_inst.collect())
        LOGGER.debug('showing Table')
        for d in collect_data:
            LOGGER.debug('key {} id {}, features {} label {}'.format(d[0], d[1].inst_id, d[1].features, d[1].label))

    @staticmethod
    def reset_label(inst, mapping):
        new_inst = copy.deepcopy(inst)
        new_inst.label = mapping[new_inst.label]
        return new_inst

    @staticmethod
    def check_label(data_inst):
        """
        Check labels. FTL only supports binary classification, and labels must be 1 or -1;
        any other binary label set is remapped onto {-1, 1}.
        """
        LOGGER.debug('checking label')
        label_checker = ClassifyLabelChecker()
        num_class, class_set = label_checker.validate_label(data_inst)
        if num_class != 2:
            raise ValueError(
                'ftl only supports binary classification, however {} labels are provided.'.format(num_class))
        if 1 in class_set and -1 in class_set:
            return data_inst
        else:
            sorted_class_set = sorted(list(class_set))
            new_label_mapping = {sorted_class_set[1]: 1, sorted_class_set[0]: -1}
            reset_label = functools.partial(FTL.reset_label, mapping=new_label_mapping)
            new_table = data_inst.mapValues(reset_label)
            new_table.schema = copy.deepcopy(data_inst.schema)
            return new_table
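
    # Example of the remapping performed by check_label: for a table labelled with {0, 1},
    # sorted_class_set is [0, 1] and new_label_mapping is {1: 1, 0: -1}, so instances
    # labelled 0 become -1 and instances labelled 1 stay 1.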

    def generate_encrypter(self, param) -> PaillierEncrypt:
        LOGGER.info("generate encrypter")
        if param.encrypt_param.method.lower() == consts.PAILLIER.lower():
            encrypter = PaillierEncrypt()
            encrypter.generate_key(param.encrypt_param.key_length)
        else:
            raise NotImplementedError("encrypt method not supported yet!!!")
        return encrypter

    def encrypt_tensor(self, components, return_dtable=True):
        """
        transform numpy arrays into Paillier tensors and encrypt them
        """
        encrypted_tensors = []
        for comp in components:
            encrypted_tensor = PaillierTensor(comp, partitions=self.partitions)
            if return_dtable:
                encrypted_tensors.append(encrypted_tensor.encrypt(self.encrypter).get_obj())
            else:
                encrypted_tensors.append(encrypted_tensor.encrypt(self.encrypter))
        return encrypted_tensors
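
    # Note: each component above is wrapped in a PaillierTensor and encrypted with this
    # party's Paillier key; with return_dtable=True the underlying distributed table
    # (via get_obj()) is returned, e.g. for remote transfer, otherwise the PaillierTensor
    # wrapper itself is kept.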

    def learning_rate_decay(self, learning_rate, epoch):
        """
        learning rate decay
        """
        return learning_rate / np.sqrt(epoch + 1)
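
    # Worked example of the decay schedule: with learning_rate=0.01 the decayed rate is
    # 0.01 at epoch 0, 0.01 / 2 = 0.005 at epoch 3 and roughly 0.0033 at epoch 8,
    # i.e. learning_rate / sqrt(epoch + 1).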

    def sync_stop_flag(self, num_round, stop_flag=None):
        """
        sync the stop flag used for n_iter_no_change early stopping (guest sends, host receives)
        """
        LOGGER.info("sync stop flag, current round is {}".format(num_round))
        if self.role == consts.GUEST:
            self.transfer_variable.stop_flag.remote(stop_flag,
                                                    role=consts.HOST,
                                                    idx=-1,
                                                    suffix=(num_round,))
        elif self.role == consts.HOST:
            return self.transfer_variable.stop_flag.get(idx=0, suffix=(num_round,))
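
    # Usage sketch (illustrative): the guest computes the convergence flag and pushes it,
    # e.g. self.sync_stop_flag(epoch_idx, stop_flag=converge_flag), while the host calls
    # stop_flag = self.sync_stop_flag(epoch_idx) to receive the flag for the same round.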

    def prepare_data(self, intersect_obj, data_inst, guest_side=False):
        """
        find intersecting ids and prepare the dataloader
        """
        if guest_side:
            data_inst = self.check_label(data_inst)

        overlap_samples = intersect_obj.run_intersect(data_inst)  # find intersecting ids
        overlap_samples = intersect_obj.get_value_from_data(overlap_samples, data_inst)
        non_overlap_samples = data_inst.subtractByKey(overlap_samples)

        LOGGER.debug('num of overlap/non-overlap samples: {}/{}'.format(overlap_samples.count(),
                                                                        non_overlap_samples.count()))

        if overlap_samples.count() == 0:
            raise ValueError('no overlap samples')
        if guest_side and non_overlap_samples.count() == 0:
            raise ValueError('non-overlapping samples are required on the guest side')

        self.store_header = data_inst.schema['header']
        LOGGER.debug('data inst header is {}'.format(self.store_header))
        LOGGER.debug('has {} overlap samples'.format(overlap_samples.count()))

        batch_size = self.batch_size
        if self.batch_size == -1:
            batch_size = data_inst.count() + 1  # make sure it is larger than the sample number

        data_loader = FTLDataLoader(non_overlap_samples=non_overlap_samples,
                                    batch_size=batch_size, overlap_samples=overlap_samples, guest_side=guest_side)

        LOGGER.debug("data details are: {}".format(data_loader.data_basic_info()))

        return data_loader, data_loader.x_shape, data_inst.count(), len(data_loader.get_overlap_indexes())
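
    # Note on the return value: prepare_data returns the FTLDataLoader together with the
    # local feature shape, the total sample count and the number of overlapping samples,
    # e.g. data_loader, x_shape, data_num, overlap_num = self.prepare_data(...).
    # A batch_size of -1 is treated as full-batch training (one batch larger than the dataset).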

    def get_model_float_type(self, nn):
        weights = nn.get_trainable_weights()
        self.model_float_type = weights[0].dtype

    def initialize_nn(self, input_shape):
        """
        initializing nn weights
        """
        loss = "keep_predict_loss"
        self.nn_builder = get_nn_builder(config_type=self.config_type)
        self.nn = self.nn_builder(loss=loss, nn_define=self.nn_define, optimizer=self.optimizer, metrics=None,
                                  input_shape=input_shape)
        self.get_model_float_type(self.nn)
        LOGGER.debug('printing nn layers structure')
        for layer in self.nn._model.layers:
            LOGGER.debug('input shape {}, output shape {}'.format(layer.input_shape, layer.output_shape))

    def generate_mask(self, shape):
        """
        generate random number mask
        """
        return self.rng_generator.generate_random_number(shape)

    def _batch_gradient_update(self, X, grads):
        """
        compute and update gradients for all samples
        """
        data = self.data_convertor.convert_data(X, grads)
        self.nn.train(data)

    def _get_mini_batch_gradient(self, X_batch, backward_grads_batch):
        """
        compute gradients for a mini batch
        """
        X_batch = X_batch.astype(self.model_float_type)
        backward_grads_batch = backward_grads_batch.astype(self.model_float_type)
        grads = self.nn.get_weight_gradients(X_batch, backward_grads_batch)
        return grads

    def update_nn_weights(self, backward_grads, data_loader: FTLDataLoader, epoch_idx, decay=False):
        """
        updating bottom nn model weights using backward gradients
        """
        LOGGER.debug('updating grads at epoch {}'.format(epoch_idx))
        assert len(data_loader.x) == len(backward_grads)

        weight_grads = []
        for i in range(len(data_loader)):
            start, end = data_loader.get_batch_indexes(i)
            batch_x = data_loader.x[start: end]
            batch_grads = backward_grads[start: end]
            batch_weight_grads = self._get_mini_batch_gradient(batch_x, batch_grads)
            if len(weight_grads) == 0:
                weight_grads.extend(batch_weight_grads)
            else:
                for w, bw in zip(weight_grads, batch_weight_grads):
                    w += bw

        if decay:
            new_learning_rate = self.learning_rate_decay(self.learning_rate, epoch_idx)
            self.nn.set_learning_rate(new_learning_rate)
            LOGGER.debug('epoch {} optimizer details are {}'.format(epoch_idx, self.nn.export_optimizer_config()))

        self.nn.apply_gradients(weight_grads)
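
    # Note: the loop above accumulates (sums) the weight gradients of every mini batch and
    # applies them in a single apply_gradients call, so each invocation performs one
    # full-pass update; when decay=True the optimizer learning rate is first rescaled via
    # learning_rate_decay before the accumulated gradients are applied.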

    def export_nn(self):
        return self.nn.export_model()

    @staticmethod
    def get_dataset_key(data_inst):
        return id(data_inst)

    def get_model_meta(self):
        model_meta = FTLModelMeta()
        model_meta.config_type = self.config_type
        model_meta.nn_define = json.dumps(self.nn_define)
        model_meta.batch_size = self.batch_size
        model_meta.epochs = self.epochs
        model_meta.tol = self.tol
        model_meta.input_dim = self.input_dim

        predict_param = FTLPredictParam()
        optimizer_param = FTLOptimizerParam()
        optimizer_param.optimizer = self.optimizer.optimizer
        optimizer_param.kwargs = json.dumps(self.optimizer.kwargs)

        model_meta.optimizer_param.CopyFrom(optimizer_param)
        model_meta.predict_param.CopyFrom(predict_param)

        return model_meta

    def get_model_param(self):
        model_param = FTLModelParam()
        model_bytes = self.nn.export_model()
        model_param.model_bytes = model_bytes
        model_param.header.extend(list(self.store_header))
        return model_param

    def set_model_meta(self, model_meta):
        self.config_type = model_meta.config_type
        self.nn_define = json.loads(model_meta.nn_define)
        self.batch_size = model_meta.batch_size
        self.epochs = model_meta.epochs
        self.tol = model_meta.tol

        self.optimizer = FTLParam()._parse_optimizer(FTLParam().optimizer)
        self.input_dim = model_meta.input_dim
        self.optimizer.optimizer = model_meta.optimizer_param.optimizer
        self.optimizer.kwargs = json.loads(model_meta.optimizer_param.kwargs)

        self.initialize_nn((self.input_dim,))

    def set_model_param(self, model_param):
        self.nn.restore_model(model_param.model_bytes)
        self.store_header = list(model_param.header)
        LOGGER.debug('stored header loaded: {}'.format(self.store_header))