# ftl_host.py
import numpy as np

from federatedml.secureprotol.paillier_tensor import PaillierTensor
from federatedml.statistic import data_overview
from federatedml.statistic.intersect import RsaIntersectionHost
from federatedml.transfer_learning.hetero_ftl.ftl_base import FTL
from federatedml.transfer_learning.hetero_ftl.ftl_dataloder import FTLDataLoader
from federatedml.util import LOGGER
from federatedml.util import consts
from federatedml.util.io_check import assert_io_num_rows_equal
  10. class FTLHost(FTL):
  11. def __init__(self):
  12. super(FTLHost, self).__init__()
  13. self.overlap_ub = None # u_b
  14. self.overlap_ub_2 = None # u_b squared
  15. self.mapping_comp_b = None
  16. self.constant_k = None # κ
  17. self.feat_dim = None # output feature dimension
  18. self.m_b = None # random mask
  19. self.role = consts.HOST
  20. def init_intersect_obj(self):
  21. LOGGER.debug('creating intersect obj done')
  22. intersect_obj = RsaIntersectionHost()
  23. intersect_obj.host_party_id = self.component_properties.local_partyid
  24. intersect_obj.host_party_id_list = self.component_properties.host_party_idlist
  25. intersect_obj.load_params(self.intersect_param)
  26. return intersect_obj
  27. def batch_compute_components(self, data_loader: FTLDataLoader):
  28. """
  29. compute host components
  30. """
  31. overlap_ub = []
  32. for i in range(len(data_loader)):
  33. batch_x = data_loader[i]
  34. ub_batch = self.nn.predict(batch_x)
  35. overlap_ub.append(ub_batch)
  36. overlap_ub = np.concatenate(overlap_ub, axis=0)
  37. overlap_ub_2 = np.matmul(np.expand_dims(overlap_ub, axis=2), np.expand_dims(overlap_ub, axis=1))
  38. mapping_comp_b = - overlap_ub * self.constant_k
  39. if self.verbose:
  40. LOGGER.debug('overlap_ub is {}'.format(overlap_ub))
  41. LOGGER.debug('overlap_ub_2 is {}'.format(overlap_ub_2))
  42. return overlap_ub, overlap_ub_2, mapping_comp_b
  43. def exchange_components(self, comp_to_send, epoch_idx):
  44. """
  45. compute host components and sent to guest
  46. """
  47. if self.mode == 'encrypted':
  48. comp_to_send = self.encrypt_tensor(comp_to_send)
  49. # receiving guest components
  50. y_overlap_2_phi_2 = self.transfer_variable.y_overlap_2_phi_2.get(idx=0, suffix=(epoch_idx, ))
  51. y_overlap_phi = self.transfer_variable.y_overlap_phi.get(idx=0, suffix=(epoch_idx, ))
  52. mapping_comp_a = self.transfer_variable.mapping_comp_a.get(idx=0, suffix=(epoch_idx, ))
  53. guest_components = [y_overlap_2_phi_2, y_overlap_phi, mapping_comp_a]
  54. # sending host components
  55. self.transfer_variable.overlap_ub.remote(comp_to_send[0], suffix=(epoch_idx, ))
  56. self.transfer_variable.overlap_ub_2.remote(comp_to_send[1], suffix=(epoch_idx, ))
  57. self.transfer_variable.mapping_comp_b.remote(comp_to_send[2], suffix=(epoch_idx, ))
  58. if self.mode == 'encrypted':
  59. guest_paillier_tensors = [PaillierTensor(tb, partitions=self.partitions) for tb in guest_components]
  60. return guest_paillier_tensors
  61. else:
  62. return guest_components
  63. def decrypt_guest_data(self, epoch_idx, local_round=-1):
  64. encrypted_consts = self.transfer_variable.guest_side_const.get(suffix=(epoch_idx, local_round, ),
  65. idx=0)
  66. grad_table = self.transfer_variable.guest_side_gradients.get(suffix=(epoch_idx, local_round, ),
  67. idx=0)
  68. inter_grad = PaillierTensor(grad_table, partitions=self.partitions)
  69. decrpyted_grad = inter_grad.decrypt(self.encrypter)
  70. decrypted_const = self.encrypter.recursive_decrypt(encrypted_consts)
  71. self.transfer_variable.decrypted_guest_const.remote(decrypted_const,
  72. suffix=(epoch_idx, local_round, ))
  73. self.transfer_variable.decrypted_guest_gradients.remote(decrpyted_grad.get_obj(),
  74. suffix=(epoch_idx, local_round, ))
  75. def decrypt_inter_result(self, loss_grad_b, epoch_idx, local_round=-1):
  76. rand_0 = PaillierTensor(
  77. self.rng_generator.generate_random_number(
  78. loss_grad_b.shape),
  79. partitions=self.partitions)
  80. grad_a_overlap = loss_grad_b + rand_0
  81. self.transfer_variable.host_side_gradients.remote(grad_a_overlap.get_obj(),
  82. suffix=(epoch_idx, local_round, 'host_de_send'))
  83. de_loss_grad_b = self.transfer_variable.decrypted_host_gradients\
  84. .get(suffix=(epoch_idx, local_round, 'host_de_get'), idx=0)
  85. de_loss_grad_b = PaillierTensor(de_loss_grad_b, partitions=self.partitions) - rand_0
  86. return de_loss_grad_b
  87. def compute_backward_gradients(self, guest_components, data_loader: FTLDataLoader, epoch_idx, local_round=-1):
  88. """
  89. compute host bottom model gradients
  90. """
  91. y_overlap_2_phi_2, y_overlap_phi, mapping_comp_a = guest_components[0], guest_components[1], guest_components[2]
  92. ub_overlap_ex = np.expand_dims(self.overlap_ub, axis=1)
  93. if self.mode == 'plain':
  94. ub_overlap_y_overlap_2_phi_2 = np.matmul(ub_overlap_ex, y_overlap_2_phi_2)
  95. l1_grad_b = np.squeeze(ub_overlap_y_overlap_2_phi_2, axis=1) + y_overlap_phi
  96. loss_grad_b = self.alpha * l1_grad_b + mapping_comp_a
  97. return loss_grad_b
  98. if self.mode == 'encrypted':
  99. ub_overlap_ex = np.expand_dims(self.overlap_ub, axis=1)
  100. ub_overlap_y_overlap_2_phi_2 = y_overlap_2_phi_2.matmul_3d(ub_overlap_ex, multiply='right')
  101. ub_overlap_y_overlap_2_phi_2 = ub_overlap_y_overlap_2_phi_2.squeeze(axis=1)
  102. l1_grad_b = ub_overlap_y_overlap_2_phi_2 + y_overlap_phi
  103. en_loss_grad_b = l1_grad_b * self.alpha + mapping_comp_a
  104. self.decrypt_guest_data(epoch_idx, local_round=local_round)
  105. loss_grad_b = self.decrypt_inter_result(en_loss_grad_b, epoch_idx, local_round=local_round)
  106. return loss_grad_b.numpy()
  107. def compute_loss(self, epoch_idx):
  108. """
  109. help guest compute ftl loss. plain mode will skip/ in encrypted mode will decrypt received loss
  110. """
  111. if self.mode == 'plain':
  112. return
  113. elif self.mode == 'encrypted':
  114. encrypted_loss = self.transfer_variable.encrypted_loss.get(idx=0, suffix=(epoch_idx, 'send_loss'))
  115. rs = self.encrypter.recursive_decrypt(encrypted_loss)
  116. self.transfer_variable.decrypted_loss.remote(rs, suffix=(epoch_idx, 'get_loss'))
  117. def fit(self, data_inst, validate_data=None):
  118. LOGGER.info('start to fit a ftl model, '
  119. 'run mode is {},'
  120. 'communication efficient mode is {}'.format(self.mode, self.comm_eff))
  121. data_loader, self.x_shape, self.data_num, self.overlap_num = self.prepare_data(self.init_intersect_obj(),
  122. data_inst, guest_side=False)
  123. self.input_dim = self.x_shape[0]
  124. # cache data_loader for faster validation
  125. self.cache_dataloader[self.get_dataset_key(data_inst)] = data_loader
  126. self.partitions = data_inst.partitions
  127. self.initialize_nn(input_shape=self.x_shape)
  128. self.feat_dim = self.nn._model.output_shape[1]
  129. self.constant_k = 1 / self.feat_dim
  130. self.callback_list.on_train_begin(data_inst, validate_data)
  131. for epoch_idx in range(self.epochs):
  132. LOGGER.debug('fitting epoch {}'.format(epoch_idx))
  133. self.callback_list.on_epoch_begin(epoch_idx)
  134. self.overlap_ub, self.overlap_ub_2, self.mapping_comp_b = self.batch_compute_components(data_loader)
  135. send_components = [self.overlap_ub, self.overlap_ub_2, self.mapping_comp_b]
  136. guest_components = self.exchange_components(send_components, epoch_idx)
  137. for local_round_idx in range(self.local_round):
  138. if self.comm_eff:
  139. LOGGER.debug('running local iter {}'.format(local_round_idx))
  140. grads = self.compute_backward_gradients(guest_components, data_loader, epoch_idx,
  141. local_round=local_round_idx)
  142. self.update_nn_weights(grads, data_loader, epoch_idx, decay=self.comm_eff)
  143. if local_round_idx == 0:
  144. self.compute_loss(epoch_idx)
  145. if local_round_idx + 1 != self.local_round:
  146. self.overlap_ub, self.overlap_ub_2, self.mapping_comp_b = self.batch_compute_components(data_loader)
  147. self.callback_list.on_epoch_end(epoch_idx)
  148. if self.n_iter_no_change is True:
  149. stop_flag = self.sync_stop_flag(epoch_idx)
  150. if stop_flag:
  151. break
  152. LOGGER.debug('fitting epoch {} done'.format(epoch_idx))
  153. self.callback_list.on_train_end()
  154. self.set_summary(self.generate_summary())
  155. def generate_summary(self):
  156. summary = {"best_iteration": self.callback_variables.best_iteration}
  157. return summary
  158. @assert_io_num_rows_equal
  159. def predict(self, data_inst):
  160. LOGGER.debug('host start to predict')
  161. self.transfer_variable.predict_host_u.disable_auto_clean()
  162. data_loader_key = self.get_dataset_key(data_inst)
  163. data_inst_ = data_overview.header_alignment(data_inst, self.store_header)
  164. if data_loader_key in self.cache_dataloader:
  165. data_loader = self.cache_dataloader[data_loader_key]
  166. else:
  167. data_loader, _, _, _ = self.prepare_data(self.init_intersect_obj(), data_inst_, guest_side=False)
  168. self.cache_dataloader[data_loader_key] = data_loader
  169. ub_batches = []
  170. for i in range(len(data_loader)):
  171. batch_x = data_loader[i]
  172. ub_batch = self.nn.predict(batch_x)
  173. ub_batches.append(ub_batch)
  174. predicts = np.concatenate(ub_batches, axis=0)
  175. self.transfer_variable.predict_host_u.remote(predicts, suffix=(0, 'host_u'))
  176. LOGGER.debug('ftl host prediction done')
  177. return None
  178. def export_model(self):
  179. return {"FTLHostMeta": self.get_model_meta(), "FTLHostParam": self.get_model_param()}
  180. def load_model(self, model_dict):
  181. model_param = None
  182. model_meta = None
  183. for _, value in model_dict["model"].items():
  184. for model in value:
  185. if model.endswith("Meta"):
  186. model_meta = value[model]
  187. if model.endswith("Param"):
  188. model_param = value[model]
  189. LOGGER.info("load model")
  190. self.set_model_meta(model_meta)
  191. self.set_model_param(model_param)