import numpy as np

from federatedml.transfer_learning.hetero_ftl.ftl_base import FTL
from federatedml.statistic.intersect import RsaIntersectionHost
from federatedml.util import LOGGER
from federatedml.transfer_learning.hetero_ftl.ftl_dataloder import FTLDataLoader
from federatedml.util import consts
from federatedml.secureprotol.paillier_tensor import PaillierTensor
from federatedml.util.io_check import assert_io_num_rows_equal
from federatedml.statistic import data_overview
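

# Host side of FATE's hetero federated transfer learning (FTL). The host
# holds features but no labels. A rough sketch of one training epoch, assuming
# the guest-side counterpart drives the loss computation:
#   1. forward the bottom model on the overlapping samples -> u_b
#   2. exchange intermediate components with the guest
#      (Paillier-encrypted when self.mode == 'encrypted')
#   3. compute backward gradients and update the bottom model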
class FTLHost(FTL):

    def __init__(self):
        super(FTLHost, self).__init__()
        self.overlap_ub = None  # u_b: host-side embeddings on overlap samples
        self.overlap_ub_2 = None  # per-sample outer products of u_b
        self.mapping_comp_b = None
        self.constant_k = None  # κ
        self.feat_dim = None  # output feature dimension
        self.m_b = None  # random mask
        self.role = consts.HOST

    def init_intersect_obj(self):
        LOGGER.debug('creating intersect obj')
        intersect_obj = RsaIntersectionHost()
        intersect_obj.host_party_id = self.component_properties.local_partyid
        intersect_obj.host_party_id_list = self.component_properties.host_party_idlist
        intersect_obj.load_params(self.intersect_param)
        return intersect_obj

    def batch_compute_components(self, data_loader: FTLDataLoader):
        """
        Compute host components: run the bottom model on all overlap samples
        to get u_b, then derive the second-order and mapping components.
        """
        overlap_ub = []
        for i in range(len(data_loader)):
            batch_x = data_loader[i]
            ub_batch = self.nn.predict(batch_x)
            overlap_ub.append(ub_batch)

        overlap_ub = np.concatenate(overlap_ub, axis=0)
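
        # overlap_ub has shape (n_overlap, d). The expand_dims/matmul below
        # builds the per-sample outer products u_b u_b^T, giving a tensor of
        # shape (n_overlap, d, d); constant_k is 1/d (set in fit()).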
        overlap_ub_2 = np.matmul(np.expand_dims(overlap_ub, axis=2), np.expand_dims(overlap_ub, axis=1))
        mapping_comp_b = -overlap_ub * self.constant_k

        if self.verbose:
            LOGGER.debug('overlap_ub is {}'.format(overlap_ub))
            LOGGER.debug('overlap_ub_2 is {}'.format(overlap_ub_2))

        return overlap_ub, overlap_ub_2, mapping_comp_b

    def exchange_components(self, comp_to_send, epoch_idx):
        """
        Send host components to the guest and receive guest components.
        """
        if self.mode == 'encrypted':
            comp_to_send = self.encrypt_tensor(comp_to_send)

        # receive guest components
        y_overlap_2_phi_2 = self.transfer_variable.y_overlap_2_phi_2.get(idx=0, suffix=(epoch_idx, ))
        y_overlap_phi = self.transfer_variable.y_overlap_phi.get(idx=0, suffix=(epoch_idx, ))
        mapping_comp_a = self.transfer_variable.mapping_comp_a.get(idx=0, suffix=(epoch_idx, ))
        guest_components = [y_overlap_2_phi_2, y_overlap_phi, mapping_comp_a]

        # send host components
        self.transfer_variable.overlap_ub.remote(comp_to_send[0], suffix=(epoch_idx, ))
        self.transfer_variable.overlap_ub_2.remote(comp_to_send[1], suffix=(epoch_idx, ))
        self.transfer_variable.mapping_comp_b.remote(comp_to_send[2], suffix=(epoch_idx, ))

        if self.mode == 'encrypted':
            guest_paillier_tensors = [PaillierTensor(tb, partitions=self.partitions) for tb in guest_components]
            return guest_paillier_tensors
        else:
            return guest_components
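
    # In encrypted mode the host holds a Paillier keypair: the guest sends
    # constants and gradients encrypted under the host's public key
    # (presumably masked on the guest side, mirroring decrypt_inter_result
    # below), and this helper decrypts them and sends the plaintexts back.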
    def decrypt_guest_data(self, epoch_idx, local_round=-1):

        encrypted_consts = self.transfer_variable.guest_side_const.get(suffix=(epoch_idx, local_round, ),
                                                                       idx=0)
        grad_table = self.transfer_variable.guest_side_gradients.get(suffix=(epoch_idx, local_round, ),
                                                                     idx=0)
        inter_grad = PaillierTensor(grad_table, partitions=self.partitions)
        decrypted_grad = inter_grad.decrypt(self.encrypter)
        decrypted_const = self.encrypter.recursive_decrypt(encrypted_consts)

        self.transfer_variable.decrypted_guest_const.remote(decrypted_const,
                                                            suffix=(epoch_idx, local_round, ))
        self.transfer_variable.decrypted_guest_gradients.remote(decrypted_grad.get_obj(),
                                                                suffix=(epoch_idx, local_round, ))
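
    # Masking round-trip: the host-side gradient below is a ciphertext under
    # the guest's key, so the host cannot decrypt it itself. It blinds the
    # ciphertext with a random tensor, lets the guest decrypt it, and then
    # removes the mask from the returned plaintext, so the guest only ever
    # sees masked values.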
    def decrypt_inter_result(self, loss_grad_b, epoch_idx, local_round=-1):

        rand_0 = PaillierTensor(
            self.rng_generator.generate_random_number(
                loss_grad_b.shape),
            partitions=self.partitions)
        grad_a_overlap = loss_grad_b + rand_0

        self.transfer_variable.host_side_gradients.remote(grad_a_overlap.get_obj(),
                                                          suffix=(epoch_idx, local_round, 'host_de_send'))
        de_loss_grad_b = self.transfer_variable.decrypted_host_gradients\
            .get(suffix=(epoch_idx, local_round, 'host_de_get'), idx=0)
        de_loss_grad_b = PaillierTensor(de_loss_grad_b, partitions=self.partitions) - rand_0

        return de_loss_grad_b
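
    # Gradient of the (Taylor-approximated) FTL loss w.r.t. u_b, assembled
    # from the guest components:
    #   loss_grad_b = alpha * (u_b · y_overlap_2_phi_2 + y_overlap_phi) + mapping_comp_a
    # In encrypted mode the same algebra runs homomorphically on
    # PaillierTensor objects, followed by the masked decryption round-trips
    # above.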
    def compute_backward_gradients(self, guest_components, data_loader: FTLDataLoader, epoch_idx, local_round=-1):
        """
        Compute gradients of the loss w.r.t. the host bottom model output.
        """
        y_overlap_2_phi_2, y_overlap_phi, mapping_comp_a = guest_components[0], guest_components[1], \
            guest_components[2]
        ub_overlap_ex = np.expand_dims(self.overlap_ub, axis=1)

        if self.mode == 'plain':
            ub_overlap_y_overlap_2_phi_2 = np.matmul(ub_overlap_ex, y_overlap_2_phi_2)
            l1_grad_b = np.squeeze(ub_overlap_y_overlap_2_phi_2, axis=1) + y_overlap_phi
            loss_grad_b = self.alpha * l1_grad_b + mapping_comp_a
            return loss_grad_b

        if self.mode == 'encrypted':
            ub_overlap_y_overlap_2_phi_2 = y_overlap_2_phi_2.matmul_3d(ub_overlap_ex, multiply='right')
            ub_overlap_y_overlap_2_phi_2 = ub_overlap_y_overlap_2_phi_2.squeeze(axis=1)
            l1_grad_b = ub_overlap_y_overlap_2_phi_2 + y_overlap_phi
            en_loss_grad_b = l1_grad_b * self.alpha + mapping_comp_a

            self.decrypt_guest_data(epoch_idx, local_round=local_round)
            loss_grad_b = self.decrypt_inter_result(en_loss_grad_b, epoch_idx, local_round=local_round)

            return loss_grad_b.numpy()

    def compute_loss(self, epoch_idx):
        """
        Help the guest compute the FTL loss: in plain mode there is nothing
        to do; in encrypted mode, decrypt the received loss and send it back.
        """
        if self.mode == 'plain':
            return
        elif self.mode == 'encrypted':
            encrypted_loss = self.transfer_variable.encrypted_loss.get(idx=0, suffix=(epoch_idx, 'send_loss'))
            rs = self.encrypter.recursive_decrypt(encrypted_loss)
            self.transfer_variable.decrypted_loss.remote(rs, suffix=(epoch_idx, 'get_loss'))
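
    # Training loop. When comm_eff (communication-efficient mode) is on,
    # each epoch reuses the exchanged components for several local rounds
    # (self.local_round) of gradient updates before the next exchange.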
    def fit(self, data_inst, validate_data=None):

        LOGGER.info('start to fit a ftl model, '
                    'run mode is {}, '
                    'communication efficient mode is {}'.format(self.mode, self.comm_eff))

        data_loader, self.x_shape, self.data_num, self.overlap_num = self.prepare_data(self.init_intersect_obj(),
                                                                                       data_inst, guest_side=False)
        self.input_dim = self.x_shape[0]

        # cache data_loader for faster validation
        self.cache_dataloader[self.get_dataset_key(data_inst)] = data_loader

        self.partitions = data_inst.partitions

        self.initialize_nn(input_shape=self.x_shape)
        self.feat_dim = self.nn._model.output_shape[1]
        self.constant_k = 1 / self.feat_dim  # κ = 1/d, used by the mapping components
        self.callback_list.on_train_begin(data_inst, validate_data)

        for epoch_idx in range(self.epochs):

            LOGGER.debug('fitting epoch {}'.format(epoch_idx))
            self.callback_list.on_epoch_begin(epoch_idx)

            self.overlap_ub, self.overlap_ub_2, self.mapping_comp_b = self.batch_compute_components(data_loader)
            send_components = [self.overlap_ub, self.overlap_ub_2, self.mapping_comp_b]
            guest_components = self.exchange_components(send_components, epoch_idx)

            for local_round_idx in range(self.local_round):

                if self.comm_eff:
                    LOGGER.debug('running local iter {}'.format(local_round_idx))

                grads = self.compute_backward_gradients(guest_components, data_loader, epoch_idx,
                                                        local_round=local_round_idx)
                self.update_nn_weights(grads, data_loader, epoch_idx, decay=self.comm_eff)

                if local_round_idx == 0:
                    self.compute_loss(epoch_idx)

                if local_round_idx + 1 != self.local_round:
                    self.overlap_ub, self.overlap_ub_2, self.mapping_comp_b = \
                        self.batch_compute_components(data_loader)

            self.callback_list.on_epoch_end(epoch_idx)

            if self.n_iter_no_change is True:
                stop_flag = self.sync_stop_flag(epoch_idx)
                if stop_flag:
                    break

            LOGGER.debug('fitting epoch {} done'.format(epoch_idx))

        self.callback_list.on_train_end()
        self.set_summary(self.generate_summary())

    def generate_summary(self):
        summary = {"best_iteration": self.callback_variables.best_iteration}
        return summary
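
    # Prediction: the host only produces its bottom-model embeddings and
    # sends them to the guest, which combines both sides into the final
    # scores; predict() therefore returns None on the host.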
    @assert_io_num_rows_equal
    def predict(self, data_inst):

        LOGGER.debug('host start to predict')

        self.transfer_variable.predict_host_u.disable_auto_clean()

        data_loader_key = self.get_dataset_key(data_inst)
        data_inst_ = data_overview.header_alignment(data_inst, self.store_header)

        if data_loader_key in self.cache_dataloader:
            data_loader = self.cache_dataloader[data_loader_key]
        else:
            data_loader, _, _, _ = self.prepare_data(self.init_intersect_obj(), data_inst_, guest_side=False)
            self.cache_dataloader[data_loader_key] = data_loader

        ub_batches = []
        for i in range(len(data_loader)):
            batch_x = data_loader[i]
            ub_batch = self.nn.predict(batch_x)
            ub_batches.append(ub_batch)

        predicts = np.concatenate(ub_batches, axis=0)

        self.transfer_variable.predict_host_u.remote(predicts, suffix=(0, 'host_u'))

        LOGGER.debug('ftl host prediction done')
        return None

    def export_model(self):
        return {"FTLHostMeta": self.get_model_meta(), "FTLHostParam": self.get_model_param()}

    def load_model(self, model_dict):
        model_param = None
        model_meta = None
        for _, value in model_dict["model"].items():
            for model in value:
                if model.endswith("Meta"):
                    model_meta = value[model]
                if model.endswith("Param"):
                    model_param = value[model]
        LOGGER.info("load model")
        self.set_model_meta(model_meta)
        self.set_model_param(model_param)