ftl_guest.py

import numpy as np

from fate_arch.session import computing_session as session
from federatedml.util import consts
from federatedml.transfer_learning.hetero_ftl.ftl_base import FTL
from federatedml.util import LOGGER
from federatedml.transfer_learning.hetero_ftl.ftl_dataloder import FTLDataLoader
from federatedml.statistic.intersect import RsaIntersectionGuest
from federatedml.model_base import Metric
from federatedml.model_base import MetricMeta
from federatedml.optim.convergence import converge_func_factory
from federatedml.secureprotol.paillier_tensor import PaillierTensor
from federatedml.optim.activation import sigmoid
from federatedml.statistic import data_overview


class FTLGuest(FTL):

    def __init__(self):
        super(FTLGuest, self).__init__()
        self.phi = None  # Φ_A
        self.phi_product = None  # (Φ_A)'(Φ_A) [feature_dim, feature_dim]
        self.overlap_y = None  # y_i ∈ N_c
        self.overlap_y_2 = None  # (y_i ∈ N_c)^2
        self.overlap_ua = None  # u_i ∈ N_AB
        self.constant_k = None  # κ
        self.feat_dim = None  # output feature dimension
        self.send_components = None  # components to send
        self.convergence = None

        self.overlap_y_pt = None  # Paillier tensor of overlap_y
        self.history_loss = []  # list to record history loss

        self.role = consts.GUEST

    def init_intersect_obj(self):
        intersect_obj = RsaIntersectionGuest()
        intersect_obj.guest_party_id = self.component_properties.local_partyid
        intersect_obj.host_party_id_list = self.component_properties.host_party_idlist
        intersect_obj.load_params(self.intersect_param)
        LOGGER.debug('intersect done')
        return intersect_obj

    def check_convergence(self, loss):
        LOGGER.info("check convergence")
        if self.convergence is None:
            self.convergence = converge_func_factory("diff", self.tol)
        return self.convergence.is_converge(loss)

    def compute_phi_and_overlap_ua(self, data_loader: FTLDataLoader):
        """
        compute Φ and ua of overlap samples
        """
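        # Φ_A is accumulated batch by batch as the label-weighted sum of the guest embeddings
        # u^A = nn.predict(x), then normalized by self.data_num below, which corresponds to
        # Φ_A = (1/N_A) * Σ_i y_i * u_i^A; embeddings of overlap samples are collected along
        # the way for reuse when building the components sent to the host.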
        phi = None  # [1, feature_dim]  Φ_A
        overlap_ua = []
        for i in range(len(data_loader)):
            batch_x, batch_y = data_loader[i]
            ua_batch = self.nn.predict(batch_x)  # [batch_size, feature_dim]

            relative_overlap_index = data_loader.get_relative_overlap_index(i)
            if len(relative_overlap_index) != 0:
                if self.verbose:
                    LOGGER.debug('batch {}/{} overlap index is {}'.format(i, len(data_loader),
                                                                          relative_overlap_index))
                overlap_ua.append(ua_batch[relative_overlap_index])

            phi_tmp = np.expand_dims(np.sum(batch_y * ua_batch, axis=0), axis=0)
            if phi is None:
                phi = phi_tmp
            else:
                phi += phi_tmp

        phi = phi / self.data_num

        return phi, overlap_ua

    def batch_compute_components(self, data_loader: FTLDataLoader):
        """
        compute guest components
        """
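        # Three guest-side components are prepared for the host (see the return value below):
        #   y_overlap_2_phi_2 = 0.25 * y_i^2 * (Φ_A)'(Φ_A)   per overlap sample
        #   y_overlap_phi     = -0.5 * y_i * Φ_A
        #   mapping_comp_a    = -κ * u_i^A
        # The host combines them with its own embeddings u^B to evaluate the Taylor-approximated
        # loss and gradients without seeing the guest labels in plaintext (in encrypted mode).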
        phi, overlap_ua = self.compute_phi_and_overlap_ua(data_loader)  # Φ_A [1, feature_dim]
        phi_product = np.matmul(phi.transpose(), phi)  # (Φ_A)'(Φ_A) [feature_dim, feature_dim]

        if self.overlap_y is None:
            self.overlap_y = data_loader.get_overlap_y()  # {C(y)=y} [1, feat_dim]
        if self.overlap_y_2 is None:
            self.overlap_y_2 = self.overlap_y * self.overlap_y  # {D(y)=y^2} [1, feat_dim]

        overlap_ua = np.concatenate(overlap_ua, axis=0)  # [overlap_num, feat_dim]

        # 3 components will be sent to host
        y_overlap_2_phi_2 = 0.25 * np.expand_dims(self.overlap_y_2, axis=2) * phi_product
        y_overlap_phi = -0.5 * self.overlap_y * phi
        mapping_comp_a = -overlap_ua * self.constant_k

        return phi, phi_product, overlap_ua, [y_overlap_2_phi_2, y_overlap_phi, mapping_comp_a]

    def exchange_components(self, comp_to_send, epoch_idx):
        """
        send guest components and get host components
        """
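        # In 'encrypted' mode the guest-side components are Paillier-encrypted before being
        # remoted to the host, and the received host components are wrapped as PaillierTensor
        # so the later gradient/loss computations can be done homomorphically on the ciphertexts.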
        if self.mode == 'encrypted':
            comp_to_send = self.encrypt_tensor(comp_to_send)

        # sending [y_overlap_2_phi_2, y_overlap_phi, mapping_comp_a]
        self.transfer_variable.y_overlap_2_phi_2.remote(comp_to_send[0], suffix=(epoch_idx, ))
        self.transfer_variable.y_overlap_phi.remote(comp_to_send[1], suffix=(epoch_idx, ))
        self.transfer_variable.mapping_comp_a.remote(comp_to_send[2], suffix=(epoch_idx, ))

        # receiving [overlap_ub, overlap_ub_2, mapping_comp_b]
        overlap_ub = self.transfer_variable.overlap_ub.get(idx=0, suffix=(epoch_idx, ))
        overlap_ub_2 = self.transfer_variable.overlap_ub_2.get(idx=0, suffix=(epoch_idx, ))
        mapping_comp_b = self.transfer_variable.mapping_comp_b.get(idx=0, suffix=(epoch_idx, ))

        host_components = [overlap_ub, overlap_ub_2, mapping_comp_b]

        if self.mode == 'encrypted':
            host_paillier_tensors = [PaillierTensor(tb, partitions=self.partitions) for tb in host_components]
            return host_paillier_tensors
        else:
            return host_components

    def decrypt_inter_result(self, encrypted_const, grad_a_overlap, epoch_idx, local_round=-1):
        """
        add a random mask to the encrypted intermediate results, get the decrypted data back
        from the host, and then subtract the random mask
        """
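        # Mask-then-decrypt exchange: the ciphertexts are blinded with freshly generated random
        # numbers before being sent to the host for decryption, so the host only sees masked
        # plaintexts; the masks are subtracted once the decrypted values come back.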
        rand_0 = self.rng_generator.generate_random_number(encrypted_const.shape)
        encrypted_const = encrypted_const + rand_0
        rand_1 = PaillierTensor(self.rng_generator.generate_random_number(grad_a_overlap.shape),
                                partitions=self.partitions)
        grad_a_overlap = grad_a_overlap + rand_1

        self.transfer_variable.guest_side_const.remote(encrypted_const, suffix=(epoch_idx, local_round,))
        self.transfer_variable.guest_side_gradients.remote(grad_a_overlap.get_obj(), suffix=(epoch_idx, local_round,))
        const = self.transfer_variable.decrypted_guest_const.get(suffix=(epoch_idx, local_round,), idx=0)
        grad = self.transfer_variable.decrypted_guest_gradients.get(suffix=(epoch_idx, local_round,), idx=0)

        const = const - rand_0
        grad_a_overlap = PaillierTensor(grad, partitions=self.partitions) - rand_1

        return const, grad_a_overlap

    def decrypt_host_data(self, epoch_idx, local_round=-1):
        inter_grad = self.transfer_variable.host_side_gradients.get(suffix=(epoch_idx, local_round, 'host_de_send'),
                                                                    idx=0)
        inter_grad_pt = PaillierTensor(inter_grad, partitions=self.partitions)
        self.transfer_variable.decrypted_host_gradients.remote(inter_grad_pt.decrypt(self.encrypter).get_obj(),
                                                               suffix=(epoch_idx, local_round, 'host_de_get'))

    def decrypt_loss_val(self, encrypted_loss, epoch_idx):
        self.transfer_variable.encrypted_loss.remote(encrypted_loss, suffix=(epoch_idx, 'send_loss'))
        decrypted_loss = self.transfer_variable.decrypted_loss.get(idx=0, suffix=(epoch_idx, 'get_loss'))
        return decrypted_loss

    def compute_backward_gradients(self, host_components, data_loader: FTLDataLoader, epoch_idx, local_round=-1):
        """
        compute backward gradients using host components
        """
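        # const below is the shared constant of the (Taylor-approximated) gradient w.r.t. u^A:
        #   const = Σ_i 0.25 * y_i^2 * Φ_A * (u_i^B)'(u_i^B) - 0.5 * Σ_i y_i * u_i^B
        # It is scaled by alpha * y / data_num for both overlap and non-overlap samples, and the
        # overlap gradients additionally receive the alignment term mapping_comp_b from the host.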
        # host components are Paillier tensors in 'encrypted' mode and numpy arrays in 'plain' mode
        overlap_ub, overlap_ub_2, mapping_comp_b = host_components[0], host_components[1], host_components[2]

        y_overlap_2_phi = np.expand_dims(self.overlap_y_2 * self.phi, axis=1)

        if self.mode == 'plain':

            loss_grads_const_part1 = 0.25 * np.squeeze(np.matmul(y_overlap_2_phi, overlap_ub_2), axis=1)
            loss_grads_const_part2 = self.overlap_y * overlap_ub

            const = np.sum(loss_grads_const_part1, axis=0) - 0.5 * np.sum(loss_grads_const_part2, axis=0)

            grad_a_nonoverlap = self.alpha * const * \
                data_loader.y[data_loader.get_non_overlap_indexes()] / self.data_num
            grad_a_overlap = self.alpha * const * self.overlap_y / self.data_num + mapping_comp_b

            return np.concatenate([grad_a_overlap, grad_a_nonoverlap], axis=0)

        elif self.mode == 'encrypted':

            loss_grads_const_part1 = overlap_ub_2.matmul_3d(0.25 * y_overlap_2_phi, multiply='right')
            loss_grads_const_part1 = loss_grads_const_part1.squeeze(axis=1)

            if self.overlap_y_pt is None:
                self.overlap_y_pt = PaillierTensor(self.overlap_y, partitions=self.partitions)

            loss_grads_const_part2 = overlap_ub * self.overlap_y_pt

            encrypted_const = loss_grads_const_part1.reduce_sum() - 0.5 * loss_grads_const_part2.reduce_sum()
            grad_a_overlap = self.overlap_y_pt.map_ndarray_product(
                (self.alpha / self.data_num * encrypted_const)) + mapping_comp_b

            const, grad_a_overlap = self.decrypt_inter_result(
                encrypted_const, grad_a_overlap, epoch_idx=epoch_idx, local_round=local_round)
            self.decrypt_host_data(epoch_idx, local_round=local_round)

            grad_a_nonoverlap = self.alpha * const * \
                data_loader.y[data_loader.get_non_overlap_indexes()] / self.data_num

            return np.concatenate([grad_a_overlap.numpy(), grad_a_nonoverlap], axis=0)

    def compute_loss(self, host_components, epoch_idx, overlap_num):
        """
        compute training loss
        """
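        # loss_y (part1 + part2 + part3 below) is the second-order Taylor approximation of the
        # logistic loss around 0: log(1 + exp(-y * u^B Φ_A')) ≈ log 2 - 0.5 * y * u^B Φ_A'
        # + 0.125 * (u^B Φ_A')^2, summed over the overlap samples; loss_overlap is the feature
        # alignment term -κ * Σ_i u_i^A · u_i^B. Both are averaged over overlap_num, with the
        # Taylor part weighted by alpha.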
        overlap_ub, overlap_ub_2, mapping_comp_b = host_components[0], host_components[1], host_components[2]

        if self.mode == 'plain':

            loss_overlap = np.sum((-self.overlap_ua * self.constant_k) * overlap_ub)

            ub_phi = np.matmul(overlap_ub, self.phi.transpose())
            part1 = -0.5 * np.sum(self.overlap_y * ub_phi)
            part2 = 1.0 / 8 * np.sum(ub_phi * ub_phi)
            part3 = len(self.overlap_y) * np.log(2)
            loss_y = part1 + part2 + part3

            return self.alpha * (loss_y / overlap_num) + loss_overlap / overlap_num

        elif self.mode == 'encrypted':

            loss_overlap = overlap_ub.element_wise_product((-self.overlap_ua * self.constant_k))
            loss_overlap_sum = np.sum(loss_overlap.reduce_sum())

            ub_phi = overlap_ub.T.fast_matmul_2d(self.phi.transpose())
            part1 = -0.5 * np.sum((self.overlap_y * ub_phi))

            ub_2 = overlap_ub_2.reduce_sum()
            enc_phi_uB_2_phi = np.matmul(np.matmul(self.phi, ub_2), self.phi.transpose())
            part2 = 1 / 8 * np.sum(enc_phi_uB_2_phi)
            part3 = len(self.overlap_y) * np.log(2)
            loss_y = part1 + part2 + part3

            en_loss = (self.alpha / self.overlap_num) * loss_y + loss_overlap_sum / overlap_num

            loss_val = self.decrypt_loss_val(en_loss, epoch_idx)
            return loss_val

    @staticmethod
    def sigmoid(x):
        return np.array(list(map(sigmoid, x)))

    def generate_summary(self):
        summary = {'loss_history': self.history_loss,
                   'best_iteration': self.callback_variables.best_iteration,
                   'validation_metrics': self.callback_variables.validation_summary}
        return summary

    def check_host_number(self):
        host_num = len(self.component_properties.host_party_idlist)
        LOGGER.info('host number is {}'.format(host_num))
        if host_num != 1:
            raise ValueError('only 1 host party is allowed')

    def fit(self, data_inst, validate_data=None):

        LOGGER.debug('in training, partitions is {}'.format(data_inst.partitions))
        LOGGER.info('start to fit a ftl model, '
                    'run mode is {}, '
                    'communication efficient mode is {}'.format(self.mode, self.comm_eff))

        self.check_host_number()

        data_loader, self.x_shape, self.data_num, self.overlap_num = self.prepare_data(self.init_intersect_obj(),
                                                                                       data_inst, guest_side=True)
        self.input_dim = self.x_shape[0]

        # cache data_loader for faster validation
        self.cache_dataloader[self.get_dataset_key(data_inst)] = data_loader

        self.partitions = data_inst.partitions
        LOGGER.debug('self partitions is {}'.format(self.partitions))

        self.initialize_nn(input_shape=self.x_shape)
        self.feat_dim = self.nn._model.output_shape[1]
        self.constant_k = 1 / self.feat_dim
        self.callback_list.on_train_begin(train_data=data_inst, validate_data=validate_data)

        self.callback_meta("loss",
                           "train",
                           MetricMeta(name="train",
                                      metric_type="LOSS",
                                      extra_metas={"unit_name": "iters"}))
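
        # Main training phase: the first set of components is computed below, then every epoch
        # exchanges components with the host once and reuses them for self.local_round local
        # gradient updates (several rounds when the communication-efficient mode is enabled).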
        # compute intermediate results of the first epoch
        self.phi, self.phi_product, self.overlap_ua, self.send_components = self.batch_compute_components(data_loader)

        for epoch_idx in range(self.epochs):

            LOGGER.debug('fitting epoch {}'.format(epoch_idx))
            self.callback_list.on_epoch_begin(epoch_idx)

            host_components = self.exchange_components(self.send_components, epoch_idx=epoch_idx)

            loss = None
            for local_round_idx in range(self.local_round):

                if self.comm_eff:
                    LOGGER.debug('running local iter {}'.format(local_round_idx))

                grads = self.compute_backward_gradients(host_components, data_loader, epoch_idx=epoch_idx,
                                                        local_round=local_round_idx)
                self.update_nn_weights(grads, data_loader, epoch_idx, decay=self.comm_eff)

                if local_round_idx == 0:
                    loss = self.compute_loss(host_components, epoch_idx, len(data_loader.get_overlap_indexes()))

                if local_round_idx + 1 != self.local_round:
                    self.phi, self.overlap_ua = self.compute_phi_and_overlap_ua(data_loader)

            self.callback_metric("loss", "train", [Metric(epoch_idx, loss)])
            self.history_loss.append(loss)

            # updating variables for the next epoch
            if epoch_idx + 1 == self.epochs:
                # only phi needs updating in the last epoch
                self.phi, _ = self.compute_phi_and_overlap_ua(data_loader)
            else:
                # compute phi, phi_product, overlap_ua etc. for the next epoch
                self.phi, self.phi_product, self.overlap_ua, self.send_components = self.batch_compute_components(
                    data_loader)

            self.callback_list.on_epoch_end(epoch_idx)

            # check n_iter_no_change
            if self.n_iter_no_change is True:
                if self.check_convergence(loss):
                    self.sync_stop_flag(epoch_idx, stop_flag=True)
                    break
                else:
                    self.sync_stop_flag(epoch_idx, stop_flag=False)

            LOGGER.debug('fitting epoch {} done, loss is {}'.format(epoch_idx, loss))

        self.callback_list.on_train_end()
        self.callback_meta("loss",
                           "train",
                           MetricMeta(name="train",
                                      metric_type="LOSS",
                                      extra_metas={"Best": min(self.history_loss)}))

        self.set_summary(self.generate_summary())
        LOGGER.debug('fitting ftl model done')
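
    # Prediction flow (implemented below): the host sends its embeddings u^B for the
    # intersected sample ids; the guest scores them against its own Φ_A via
    # sigmoid(u^B · Φ_A') and thresholds the probabilities with predict_param.threshold.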

    def predict(self, data_inst):

        LOGGER.debug('guest start to predict')

        data_loader_key = self.get_dataset_key(data_inst)
        data_inst_ = data_overview.header_alignment(data_inst, self.store_header)

        if data_loader_key in self.cache_dataloader:
            data_loader = self.cache_dataloader[data_loader_key]
        else:
            data_loader, _, _, _ = self.prepare_data(self.init_intersect_obj(), data_inst_, guest_side=True)
            self.cache_dataloader[data_loader_key] = data_loader

        LOGGER.debug('try to get predict u from host, suffix is {}'.format((0, 'host_u')))
        host_predicts = self.transfer_variable.predict_host_u.get(idx=0, suffix=(0, 'host_u'))

        predict_score = np.matmul(host_predicts, self.phi.transpose())
        predicts = self.sigmoid(predict_score)  # convert scores to probabilities
        predicts = list(map(float, predicts))

        predict_tb = session.parallelize(zip(data_loader.get_overlap_keys(), predicts,), include_key=True,
                                         partition=data_inst.partitions)
        threshold = self.predict_param.threshold
        predict_result = self.predict_score_to_output(data_inst_, predict_tb, classes=[0, 1], threshold=threshold)

        LOGGER.debug('ftl guest prediction done')
        return predict_result

    def export_model(self):
        model_param = self.get_model_param()
        model_param.phi_a.extend(self.phi.tolist()[0])
        return {"FTLGuestMeta": self.get_model_meta(), "FTLHostParam": model_param}

    def load_model(self, model_dict):
        model_param = None
        model_meta = None
        for _, value in model_dict["model"].items():
            for model in value:
                if model.endswith("Meta"):
                    model_meta = value[model]
                if model.endswith("Param"):
                    model_param = value[model]

        LOGGER.info("load model")
        self.set_model_meta(model_meta)
        self.set_model_param(model_param)
        self.phi = np.array([model_param.phi_a])