import numpy as np

from fate_arch.session import computing_session as session
from federatedml.model_base import Metric, MetricMeta
from federatedml.optim.activation import sigmoid
from federatedml.optim.convergence import converge_func_factory
from federatedml.secureprotol.paillier_tensor import PaillierTensor
from federatedml.statistic import data_overview
from federatedml.statistic.intersect import RsaIntersectionGuest
from federatedml.transfer_learning.hetero_ftl.ftl_base import FTL
from federatedml.transfer_learning.hetero_ftl.ftl_dataloder import FTLDataLoader
from federatedml.util import LOGGER, consts


class FTLGuest(FTL):

    def __init__(self):
        super(FTLGuest, self).__init__()
        self.phi = None  # Φ_A
        self.phi_product = None  # (Φ_A)'(Φ_A) [feature_dim, feature_dim]
        self.overlap_y = None  # y_i, i ∈ N_c
        self.overlap_y_2 = None  # (y_i)^2, i ∈ N_c
        self.overlap_ua = None  # u_i^A, i ∈ N_AB
        self.constant_k = None  # κ
        self.feat_dim = None  # output feature dimension
        self.send_components = None  # components to send to the host
        self.convergence = None
        self.overlap_y_pt = None  # Paillier tensor of overlap_y
        self.history_loss = []  # list to record history loss

        self.role = consts.GUEST

    def init_intersect_obj(self):
        intersect_obj = RsaIntersectionGuest()
        intersect_obj.guest_party_id = self.component_properties.local_partyid
        intersect_obj.host_party_id_list = self.component_properties.host_party_idlist
        intersect_obj.load_params(self.intersect_param)
        LOGGER.debug('intersect object created')
        return intersect_obj

    def check_convergence(self, loss):
        LOGGER.info("check convergence")
        if self.convergence is None:
            self.convergence = converge_func_factory("diff", self.tol)
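        # The "diff" criterion from converge_func_factory stops training once
        # |loss_t - loss_{t-1}| < self.tol, i.e. the absolute change of the
        # training loss between consecutive epochs falls below the tolerance.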
        return self.convergence.is_converge(loss)

    def compute_phi_and_overlap_ua(self, data_loader: FTLDataLoader):
        """
        Compute Φ_A and the representations u_i^A of the overlapping samples.
        """
        phi = None  # [1, feature_dim] Φ_A
        overlap_ua = []
        for i in range(len(data_loader)):
            batch_x, batch_y = data_loader[i]
            ua_batch = self.nn.predict(batch_x)  # [batch_size, feature_dim]

            relative_overlap_index = data_loader.get_relative_overlap_index(i)
            if len(relative_overlap_index) != 0:
                if self.verbose:
                    LOGGER.debug('batch {}/{} overlap index is {}'.format(i, len(data_loader),
                                                                          relative_overlap_index))
                overlap_ua.append(ua_batch[relative_overlap_index])

            phi_tmp = np.expand_dims(np.sum(batch_y * ua_batch, axis=0), axis=0)
            if phi is None:
                phi = phi_tmp
            else:
                phi += phi_tmp

        phi = phi / self.data_num

        return phi, overlap_ua

    def batch_compute_components(self, data_loader: FTLDataLoader):
        """
        Compute the guest-side components to be sent to the host.
        """
        phi, overlap_ua = self.compute_phi_and_overlap_ua(data_loader)  # Φ_A [1, feature_dim]
        phi_product = np.matmul(phi.transpose(), phi)  # (Φ_A)'(Φ_A) [feature_dim, feature_dim]

        if self.overlap_y is None:
            self.overlap_y = data_loader.get_overlap_y()  # {C(y) = y} [overlap_num, 1]
        if self.overlap_y_2 is None:
            self.overlap_y_2 = self.overlap_y * self.overlap_y  # {D(y) = y^2} [overlap_num, 1]

        overlap_ua = np.concatenate(overlap_ua, axis=0)  # [overlap_num, feat_dim]

        # three components to be sent to the host
        y_overlap_2_phi_2 = 0.25 * np.expand_dims(self.overlap_y_2, axis=2) * phi_product
        y_overlap_phi = -0.5 * self.overlap_y * phi
        mapping_comp_a = -overlap_ua * self.constant_k

        return phi, phi_product, overlap_ua, [y_overlap_2_phi_2, y_overlap_phi, mapping_comp_a]

    def exchange_components(self, comp_to_send, epoch_idx):
        """
        Send the guest components to the host and receive the host components.
        """
        if self.mode == 'encrypted':
            comp_to_send = self.encrypt_tensor(comp_to_send)

        # sending [y_overlap_2_phi_2, y_overlap_phi, mapping_comp_a]
        self.transfer_variable.y_overlap_2_phi_2.remote(comp_to_send[0], suffix=(epoch_idx,))
        self.transfer_variable.y_overlap_phi.remote(comp_to_send[1], suffix=(epoch_idx,))
        self.transfer_variable.mapping_comp_a.remote(comp_to_send[2], suffix=(epoch_idx,))

        # receiving [overlap_ub, overlap_ub_2, mapping_comp_b]
        overlap_ub = self.transfer_variable.overlap_ub.get(idx=0, suffix=(epoch_idx,))
        overlap_ub_2 = self.transfer_variable.overlap_ub_2.get(idx=0, suffix=(epoch_idx,))
        mapping_comp_b = self.transfer_variable.mapping_comp_b.get(idx=0, suffix=(epoch_idx,))
        host_components = [overlap_ub, overlap_ub_2, mapping_comp_b]

        if self.mode == 'encrypted':
            host_paillier_tensors = [PaillierTensor(tb, partitions=self.partitions) for tb in host_components]
            return host_paillier_tensors
        else:
            return host_components

    def decrypt_inter_result(self, encrypted_const, grad_a_overlap, epoch_idx, local_round=-1):
        """
        Add a random mask to the encrypted intermediate results, have the host
        decrypt them, and subtract the mask from the returned plaintexts.
        """
        rand_0 = self.rng_generator.generate_random_number(encrypted_const.shape)
        encrypted_const = encrypted_const + rand_0
        rand_1 = PaillierTensor(self.rng_generator.generate_random_number(grad_a_overlap.shape),
                                partitions=self.partitions)
        grad_a_overlap = grad_a_overlap + rand_1

        self.transfer_variable.guest_side_const.remote(encrypted_const, suffix=(epoch_idx, local_round,))
        self.transfer_variable.guest_side_gradients.remote(grad_a_overlap.get_obj(),
                                                           suffix=(epoch_idx, local_round,))
        const = self.transfer_variable.decrypted_guest_const.get(suffix=(epoch_idx, local_round,), idx=0)
        grad = self.transfer_variable.decrypted_guest_gradients.get(suffix=(epoch_idx, local_round,), idx=0)

        const = const - rand_0
        grad_a_overlap = PaillierTensor(grad, partitions=self.partitions) - rand_1

        return const, grad_a_overlap

    def decrypt_host_data(self, epoch_idx, local_round=-1):
        """
        Decrypt the host's masked gradients with the guest-side Paillier key
        and send the plaintexts back to the host.
        """
        inter_grad = self.transfer_variable.host_side_gradients.get(suffix=(epoch_idx, local_round, 'host_de_send'),
                                                                    idx=0)
        inter_grad_pt = PaillierTensor(inter_grad, partitions=self.partitions)
        self.transfer_variable.decrypted_host_gradients.remote(inter_grad_pt.decrypt(self.encrypter).get_obj(),
                                                               suffix=(epoch_idx, local_round, 'host_de_get'))

    def decrypt_loss_val(self, encrypted_loss, epoch_idx):
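        # The loss ciphertext is produced from host-encrypted tensors, so only
        # the host can decrypt it; the plaintext loss is then shared back with
        # the guest (the scalar loss itself is exchanged openly).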
        self.transfer_variable.encrypted_loss.remote(encrypted_loss, suffix=(epoch_idx, 'send_loss'))
        decrypted_loss = self.transfer_variable.decrypted_loss.get(idx=0, suffix=(epoch_idx, 'get_loss'))
        return decrypted_loss

    def compute_backward_gradients(self, host_components, data_loader: FTLDataLoader, epoch_idx, local_round=-1):
        """
        Compute the backward gradients w.r.t. u^A using the host components.
        """
        # host components are either PaillierTensors or numpy arrays
        overlap_ub, overlap_ub_2, mapping_comp_b = host_components[0], host_components[1], host_components[2]

        y_overlap_2_phi = np.expand_dims(self.overlap_y_2 * self.phi, axis=1)

        if self.mode == 'plain':

            loss_grads_const_part1 = 0.25 * np.squeeze(np.matmul(y_overlap_2_phi, overlap_ub_2), axis=1)
            loss_grads_const_part2 = self.overlap_y * overlap_ub

            const = np.sum(loss_grads_const_part1, axis=0) - 0.5 * np.sum(loss_grads_const_part2, axis=0)

            grad_a_nonoverlap = self.alpha * const * data_loader.y[data_loader.get_non_overlap_indexes()] / \
                self.data_num
            grad_a_overlap = self.alpha * const * self.overlap_y / self.data_num + mapping_comp_b

            return np.concatenate([grad_a_overlap, grad_a_nonoverlap], axis=0)

        elif self.mode == 'encrypted':

            loss_grads_const_part1 = overlap_ub_2.matmul_3d(0.25 * y_overlap_2_phi, multiply='right')
            loss_grads_const_part1 = loss_grads_const_part1.squeeze(axis=1)

            if self.overlap_y_pt is None:
                self.overlap_y_pt = PaillierTensor(self.overlap_y, partitions=self.partitions)

            loss_grads_const_part2 = overlap_ub * self.overlap_y_pt

            encrypted_const = loss_grads_const_part1.reduce_sum() - 0.5 * loss_grads_const_part2.reduce_sum()
            grad_a_overlap = self.overlap_y_pt.map_ndarray_product(
                (self.alpha / self.data_num * encrypted_const)) + mapping_comp_b

            const, grad_a_overlap = self.decrypt_inter_result(encrypted_const, grad_a_overlap,
                                                              epoch_idx=epoch_idx, local_round=local_round)
            self.decrypt_host_data(epoch_idx, local_round=local_round)

            grad_a_nonoverlap = self.alpha * const * data_loader.y[data_loader.get_non_overlap_indexes()] / \
                self.data_num

            return np.concatenate([grad_a_overlap.numpy(), grad_a_nonoverlap], axis=0)

    def compute_loss(self, host_components, epoch_idx, overlap_num):
        """
        Compute the training loss.
        """
        overlap_ub, overlap_ub_2, mapping_comp_b = host_components[0], host_components[1], host_components[2]

        if self.mode == 'plain':

            loss_overlap = np.sum((-self.overlap_ua * self.constant_k) * overlap_ub)

            ub_phi = np.matmul(overlap_ub, self.phi.transpose())
            part1 = -0.5 * np.sum(self.overlap_y * ub_phi)
            part2 = 1.0 / 8 * np.sum(ub_phi * ub_phi)
            part3 = len(self.overlap_y) * np.log(2)
            loss_y = part1 + part2 + part3
            return self.alpha * (loss_y / overlap_num) + loss_overlap / overlap_num

        elif self.mode == 'encrypted':

            loss_overlap = overlap_ub.element_wise_product((-self.overlap_ua * self.constant_k))
            loss_overlap_sum = np.sum(loss_overlap.reduce_sum())  # avoid shadowing the built-in sum
            ub_phi = overlap_ub.T.fast_matmul_2d(self.phi.transpose())

            part1 = -0.5 * np.sum((self.overlap_y * ub_phi))
            ub_2 = overlap_ub_2.reduce_sum()
            enc_phi_uB_2_phi = np.matmul(np.matmul(self.phi, ub_2), self.phi.transpose())
            part2 = 1 / 8 * np.sum(enc_phi_uB_2_phi)
            part3 = len(self.overlap_y) * np.log(2)
            loss_y = part1 + part2 + part3

            en_loss = (self.alpha / overlap_num) * loss_y + loss_overlap_sum / overlap_num
            loss_val = self.decrypt_loss_val(en_loss, epoch_idx)

            return loss_val

    @staticmethod
    def sigmoid(x):
        return np.array(list(map(sigmoid, x)))
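
    # Note: for a 1-D float array the map above is equivalent to the
    # vectorized expression 1.0 / (1.0 + np.exp(-x)); the element-wise map
    # simply reuses federatedml.optim.activation.sigmoid unchanged.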

    def generate_summary(self):

        summary = {'loss_history': self.history_loss,
                   'best_iteration': self.callback_variables.best_iteration,
                   'validation_metrics': self.callback_variables.validation_summary}

        return summary

    def check_host_number(self):
        host_num = len(self.component_properties.host_party_idlist)
        LOGGER.info('host number is {}'.format(host_num))
        if host_num != 1:
            raise ValueError('only 1 host party is allowed')

    def fit(self, data_inst, validate_data=None):

        LOGGER.debug('in training, partitions is {}'.format(data_inst.partitions))
        LOGGER.info('start to fit an FTL model, '
                    'run mode is {}, '
                    'communication efficient mode is {}'.format(self.mode, self.comm_eff))

        self.check_host_number()

        data_loader, self.x_shape, self.data_num, self.overlap_num = self.prepare_data(self.init_intersect_obj(),
                                                                                       data_inst, guest_side=True)
        self.input_dim = self.x_shape[0]

        # cache data_loader for faster validation
        self.cache_dataloader[self.get_dataset_key(data_inst)] = data_loader

        self.partitions = data_inst.partitions
        LOGGER.debug('self partitions is {}'.format(self.partitions))

        self.initialize_nn(input_shape=self.x_shape)
        self.feat_dim = self.nn._model.output_shape[1]
        self.constant_k = 1 / self.feat_dim  # κ = 1 / output feature dimension
        self.callback_list.on_train_begin(train_data=data_inst, validate_data=validate_data)

        self.callback_meta("loss",
                           "train",
                           MetricMeta(name="train",
                                      metric_type="LOSS",
                                      extra_metas={"unit_name": "iters"}))

        # compute the intermediate results of the first epoch
        self.phi, self.phi_product, self.overlap_ua, self.send_components = self.batch_compute_components(data_loader)

        for epoch_idx in range(self.epochs):

            LOGGER.debug('fitting epoch {}'.format(epoch_idx))
            self.callback_list.on_epoch_begin(epoch_idx)

            host_components = self.exchange_components(self.send_components, epoch_idx=epoch_idx)

            loss = None

            for local_round_idx in range(self.local_round):

                if self.comm_eff:
                    LOGGER.debug('running local iter {}'.format(local_round_idx))

                grads = self.compute_backward_gradients(host_components, data_loader, epoch_idx=epoch_idx,
                                                        local_round=local_round_idx)
                self.update_nn_weights(grads, data_loader, epoch_idx, decay=self.comm_eff)

                if local_round_idx == 0:
                    loss = self.compute_loss(host_components, epoch_idx, len(data_loader.get_overlap_indexes()))

                if local_round_idx + 1 != self.local_round:
                    self.phi, self.overlap_ua = self.compute_phi_and_overlap_ua(data_loader)

            self.callback_metric("loss", "train", [Metric(epoch_idx, loss)])
            self.history_loss.append(loss)

            # update variables for the next epoch
            if epoch_idx + 1 == self.epochs:
                # after the last epoch only phi is needed (for prediction)
                self.phi, _ = self.compute_phi_and_overlap_ua(data_loader)
            else:
                # compute phi, phi_product, overlap_ua etc. for the next epoch
                self.phi, self.phi_product, self.overlap_ua, self.send_components = self.batch_compute_components(
                    data_loader)

            self.callback_list.on_epoch_end(epoch_idx)

            # check n_iter_no_change
            if self.n_iter_no_change:
                if self.check_convergence(loss):
                    self.sync_stop_flag(epoch_idx, stop_flag=True)
                    break
                else:
                    self.sync_stop_flag(epoch_idx, stop_flag=False)

            LOGGER.debug('fitting epoch {} done, loss is {}'.format(epoch_idx, loss))

        self.callback_list.on_train_end()

        self.callback_meta("loss",
                           "train",
                           MetricMeta(name="train",
                                      metric_type="LOSS",
                                      extra_metas={"Best": min(self.history_loss)}))

        self.set_summary(self.generate_summary())
        LOGGER.debug('fitting FTL model done')

    def predict(self, data_inst):

        LOGGER.debug('guest start to predict')

        data_loader_key = self.get_dataset_key(data_inst)
        data_inst_ = data_overview.header_alignment(data_inst, self.store_header)

        if data_loader_key in self.cache_dataloader:
            data_loader = self.cache_dataloader[data_loader_key]
        else:
            data_loader, _, _, _ = self.prepare_data(self.init_intersect_obj(), data_inst_, guest_side=True)
            self.cache_dataloader[data_loader_key] = data_loader

        LOGGER.debug('try to get predict u from host, suffix is {}'.format((0, 'host_u')))
        host_predicts = self.transfer_variable.predict_host_u.get(idx=0, suffix=(0, 'host_u'))
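
        # Per the FTL formulation, the prediction score for an overlapping
        # sample is σ(u_i^B Φ_A'): the host-side representation projected onto
        # the label-aggregated guest vector Φ_A, squashed into a probability.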
        predict_score = np.matmul(host_predicts, self.phi.transpose())
        predicts = self.sigmoid(predict_score)  # convert scores to probabilities
        predicts = list(map(float, predicts))

        predict_tb = session.parallelize(zip(data_loader.get_overlap_keys(), predicts,), include_key=True,
                                         partition=data_inst.partitions)

        threshold = self.predict_param.threshold
        predict_result = self.predict_score_to_output(data_inst_, predict_tb, classes=[0, 1], threshold=threshold)

        LOGGER.debug('ftl guest prediction done')
        return predict_result

    def export_model(self):
        model_param = self.get_model_param()
        model_param.phi_a.extend(self.phi.tolist()[0])
        return {"FTLGuestMeta": self.get_model_meta(), "FTLGuestParam": model_param}

    def load_model(self, model_dict):

        model_param = None
        model_meta = None
        for _, value in model_dict["model"].items():
            for model in value:
                if model.endswith("Meta"):
                    model_meta = value[model]
                if model.endswith("Param"):
                    model_param = value[model]

        LOGGER.info("load model")
        self.set_model_meta(model_meta)
        self.set_model_param(model_param)

        self.phi = np.array([model_param.phi_a])