import copy
import json
import functools

import numpy as np

from federatedml.util import LOGGER
from federatedml.transfer_learning.hetero_ftl.backend.nn_model import get_nn_builder
from federatedml.model_base import ModelBase
from federatedml.param.ftl_param import FTLParam
from federatedml.transfer_learning.hetero_ftl.backend.tf_keras.nn_model import KerasNNModel
from federatedml.util.classify_label_checker import ClassifyLabelChecker
from federatedml.transfer_variable.transfer_class.ftl_transfer_variable import FTLTransferVariable
from federatedml.transfer_learning.hetero_ftl.ftl_dataloder import FTLDataLoader
from federatedml.transfer_learning.hetero_ftl.backend.tf_keras.data_generator import KerasSequenceDataConverter
from federatedml.nn.backend.utils import rng as random_number_generator
from federatedml.secureprotol import PaillierEncrypt
from federatedml.util import consts
from federatedml.secureprotol.paillier_tensor import PaillierTensor
from federatedml.protobuf.generated.ftl_model_param_pb2 import FTLModelParam
from federatedml.protobuf.generated.ftl_model_meta_pb2 import FTLModelMeta, FTLPredictParam, FTLOptimizerParam
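

# Base class shared by the guest and host sides of hetero federated transfer
# learning (FTL). It holds the Keras bottom model, the Paillier encrypter and
# the transfer variables, and provides helpers for label checking, sample-id
# intersection, tensor encryption, learning-rate decay and bottom-model weight
# updates. The role-specific training logic lives in the guest/host subclasses,
# which are not part of this file.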
class FTL(ModelBase):

    def __init__(self):
        super(FTL, self).__init__()

        # input parameters
        self.nn_define = None
        self.alpha = None
        self.tol = None
        self.learning_rate = None
        self.n_iter_no_change = None
        self.validation_freqs = None
        self.early_stopping_rounds = None
        self.use_first_metric_only = None
        self.optimizer = None
        self.intersect_param = None
        self.config_type = 'keras'
        self.comm_eff = None
        self.local_round = 1

        # runtime variables
        self.verbose = False
        self.nn: KerasNNModel = None
        self.nn_builder = None
        self.model_param = FTLParam()
        self.x_shape = None
        self.input_dim = None
        self.data_num = 0
        self.overlap_num = 0
        self.transfer_variable = FTLTransferVariable()
        self.data_convertor = KerasSequenceDataConverter()
        self.mode = 'plain'
        self.encrypter = None
        self.partitions = 16
        self.batch_size = None
        self.epochs = None
        self.store_header = None  # header of the input data table
        self.model_float_type = np.float32
        self.cache_dataloader = {}
        self.validation_strategy = None

    def _init_model(self, param: FTLParam):

        self.nn_define = param.nn_define
        self.alpha = param.alpha
        self.tol = param.tol
        self.n_iter_no_change = param.n_iter_no_change
        self.validation_freqs = param.validation_freqs
        self.optimizer = param.optimizer
        self.intersect_param = param.intersect_param
        self.batch_size = param.batch_size
        self.epochs = param.epochs
        self.mode = param.mode
        self.comm_eff = param.communication_efficient
        self.local_round = param.local_round

        assert 'learning_rate' in self.optimizer.kwargs, 'optimizer setting must contain learning_rate'
        self.learning_rate = self.optimizer.kwargs['learning_rate']

        if not self.comm_eff:
            self.local_round = 1
            LOGGER.debug('communication efficient mode is not enabled, local_round is set to 1')

        self.encrypter = self.generate_encrypter(param)
        self.predict_param = param.predict_param
        self.rng_generator = random_number_generator.RandomNumberGenerator()
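
    # Note: the parsed optimizer setting is expected to expose `.optimizer`
    # (the optimizer name) and `.kwargs` (a dict that must contain
    # "learning_rate", as asserted in _init_model above); get_model_meta and
    # set_model_meta below serialize and restore it in exactly this form.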

    @staticmethod
    def debug_data_inst(data_inst):
        collect_data = list(data_inst.collect())
        LOGGER.debug('showing Table')
        for d in collect_data:
            LOGGER.debug('key {} id {}, features {} label {}'.format(d[0], d[1].inst_id, d[1].features, d[1].label))

    @staticmethod
    def reset_label(inst, mapping):
        new_inst = copy.deepcopy(inst)
        new_inst.label = mapping[new_inst.label]
        return new_inst

    @staticmethod
    def check_label(data_inst):
        """
        Check labels. FTL only supports binary classification, and labels must be 1 or -1.
        """
        LOGGER.debug('checking label')
        label_checker = ClassifyLabelChecker()
        num_class, class_set = label_checker.validate_label(data_inst)
        if num_class != 2:
            raise ValueError(
                'ftl only supports binary classification, however {} labels are provided.'.format(num_class))
        if 1 in class_set and -1 in class_set:
            return data_inst
        else:
            sorted_class_set = sorted(list(class_set))
            new_label_mapping = {sorted_class_set[1]: 1, sorted_class_set[0]: -1}
            reset_label = functools.partial(FTL.reset_label, mapping=new_label_mapping)
            new_table = data_inst.mapValues(reset_label)
            new_table.schema = copy.deepcopy(data_inst.schema)
            return new_table
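
    # For example, original labels {0, 1} are remapped to {-1, 1}: the larger
    # label of the sorted class set becomes 1 and the smaller becomes -1, while
    # data that already uses {-1, 1} is returned unchanged.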

    def generate_encrypter(self, param) -> PaillierEncrypt:
        LOGGER.info("generate encrypter")
        if param.encrypt_param.method.lower() == consts.PAILLIER.lower():
            encrypter = PaillierEncrypt()
            encrypter.generate_key(param.encrypt_param.key_length)
        else:
            raise NotImplementedError("encrypt method not supported yet")
        return encrypter

    def encrypt_tensor(self, components, return_dtable=True):
        """
        transform numpy array into Paillier tensor and encrypt
        """
        encrypted_tensors = []
        for comp in components:
            encrypted_tensor = PaillierTensor(comp, partitions=self.partitions)
            if return_dtable:
                encrypted_tensors.append(encrypted_tensor.encrypt(self.encrypter).get_obj())
            else:
                encrypted_tensors.append(encrypted_tensor.encrypt(self.encrypter))
        return encrypted_tensors
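
    # Each numpy array is wrapped in a PaillierTensor (partitioned for the
    # distributed backend) and encrypted with the encrypter's public key;
    # return_dtable controls whether the underlying distributed table
    # (get_obj()) or the PaillierTensor wrapper is returned.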

    def learning_rate_decay(self, learning_rate, epoch):
        """
        decay the learning rate by 1 / sqrt(epoch + 1)
        """
        return learning_rate / np.sqrt(epoch + 1)
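
    # The decay schedule is lr_t = lr_0 / sqrt(t + 1); e.g. with lr_0 = 0.01
    # the effective rate is 0.01 at epoch 0 and 0.005 at epoch 3.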

    def sync_stop_flag(self, num_round, stop_flag=None):
        """
        sync the stop flag used for n_iter_no_change early stopping
        """
        LOGGER.info("sync stop flag, round is {}".format(num_round))
        if self.role == consts.GUEST:
            self.transfer_variable.stop_flag.remote(stop_flag,
                                                    role=consts.HOST,
                                                    idx=-1,
                                                    suffix=(num_round,))
        elif self.role == consts.HOST:
            return self.transfer_variable.stop_flag.get(idx=0, suffix=(num_round,))
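
    # On the guest side the already-computed stop flag is broadcast to all
    # hosts (idx=-1); on the host side the call waits for the flag of this
    # round and returns it to the caller.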

    def prepare_data(self, intersect_obj, data_inst, guest_side=False):
        """
        find intersecting ids and prepare the dataloader
        """
        if guest_side:
            data_inst = self.check_label(data_inst)

        overlap_samples = intersect_obj.run_intersect(data_inst)  # find intersecting ids
        overlap_samples = intersect_obj.get_value_from_data(overlap_samples, data_inst)
        non_overlap_samples = data_inst.subtractByKey(overlap_samples)

        LOGGER.debug('num of overlap/non-overlap samples: {}/{}'.format(overlap_samples.count(),
                                                                        non_overlap_samples.count()))

        if overlap_samples.count() == 0:
            raise ValueError('no overlap samples')
        if guest_side and non_overlap_samples.count() == 0:
            raise ValueError('non-overlap samples are required on the guest side')

        self.store_header = data_inst.schema['header']
        LOGGER.debug('data inst header is {}'.format(self.store_header))
        LOGGER.debug('has {} overlap samples'.format(overlap_samples.count()))

        batch_size = self.batch_size
        if self.batch_size == -1:
            batch_size = data_inst.count() + 1  # make sure it is larger than the sample number

        data_loader = FTLDataLoader(non_overlap_samples=non_overlap_samples,
                                    batch_size=batch_size, overlap_samples=overlap_samples, guest_side=guest_side)

        LOGGER.debug("data details are: {}".format(data_loader.data_basic_info()))

        return data_loader, data_loader.x_shape, data_inst.count(), len(data_loader.get_overlap_indexes())
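
    # The return value is a tuple: (FTLDataLoader, shape of a single feature
    # vector, total sample count, number of overlapping samples).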

    def get_model_float_type(self, nn):
        weights = nn.get_trainable_weights()
        self.model_float_type = weights[0].dtype

    def initialize_nn(self, input_shape):
        """
        initializing nn weights
        """
        loss = "keep_predict_loss"
        self.nn_builder = get_nn_builder(config_type=self.config_type)
        self.nn = self.nn_builder(loss=loss, nn_define=self.nn_define, optimizer=self.optimizer, metrics=None,
                                  input_shape=input_shape)
        self.get_model_float_type(self.nn)

        LOGGER.debug('printing nn layers structure')
        for layer in self.nn._model.layers:
            LOGGER.debug('input shape {}, output shape {}'.format(layer.input_shape, layer.output_shape))

    def generate_mask(self, shape):
        """
        generate random number mask
        """
        return self.rng_generator.generate_random_number(shape)
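
    # A random mask of the given shape drawn from the RandomNumberGenerator;
    # the guest/host subclasses (not shown here) presumably use such masks to
    # blind encrypted intermediate results before sending them to the other
    # party for decryption.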

    def _batch_gradient_update(self, X, grads):
        """
        compute and update gradients for all samples
        """
        data = self.data_convertor.convert_data(X, grads)
        self.nn.train(data)

    def _get_mini_batch_gradient(self, X_batch, backward_grads_batch):
        """
        compute gradient for a mini batch
        """
        X_batch = X_batch.astype(self.model_float_type)
        backward_grads_batch = backward_grads_batch.astype(self.model_float_type)
        grads = self.nn.get_weight_gradients(X_batch, backward_grads_batch)
        return grads

    def update_nn_weights(self, backward_grads, data_loader: FTLDataLoader, epoch_idx, decay=False):
        """
        updating bottom nn model weights using backward gradients
        """
        LOGGER.debug('updating grads at epoch {}'.format(epoch_idx))
        assert len(data_loader.x) == len(backward_grads)

        weight_grads = []
        for i in range(len(data_loader)):
            start, end = data_loader.get_batch_indexes(i)
            batch_x = data_loader.x[start: end]
            batch_grads = backward_grads[start: end]
            batch_weight_grads = self._get_mini_batch_gradient(batch_x, batch_grads)
            if len(weight_grads) == 0:
                weight_grads.extend(batch_weight_grads)
            else:
                for w, bw in zip(weight_grads, batch_weight_grads):
                    w += bw

        if decay:
            new_learning_rate = self.learning_rate_decay(self.learning_rate, epoch_idx)
            self.nn.set_learning_rate(new_learning_rate)
            LOGGER.debug('epoch {} optimizer details are {}'.format(epoch_idx, self.nn.export_optimizer_config()))

        self.nn.apply_gradients(weight_grads)
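
    # Weight gradients are computed per mini-batch and summed over the whole
    # pass, so apply_gradients is called exactly once per update_nn_weights
    # call; when decay is enabled, the sqrt schedule above is applied first.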

    def export_nn(self):
        return self.nn.export_model()

    @staticmethod
    def get_dataset_key(data_inst):
        return id(data_inst)

    def get_model_meta(self):
        model_meta = FTLModelMeta()

        model_meta.config_type = self.config_type
        model_meta.nn_define = json.dumps(self.nn_define)
        model_meta.batch_size = self.batch_size
        model_meta.epochs = self.epochs
        model_meta.tol = self.tol
        model_meta.input_dim = self.input_dim

        predict_param = FTLPredictParam()
        optimizer_param = FTLOptimizerParam()
        optimizer_param.optimizer = self.optimizer.optimizer
        optimizer_param.kwargs = json.dumps(self.optimizer.kwargs)

        model_meta.optimizer_param.CopyFrom(optimizer_param)
        model_meta.predict_param.CopyFrom(predict_param)

        return model_meta
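
    # The meta message stores the training configuration (nn_define, batch
    # size, epochs, tol, input_dim and the optimizer settings), while the
    # param message below stores the trained model bytes and the data header.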

    def get_model_param(self):
        model_param = FTLModelParam()
        model_bytes = self.nn.export_model()
        model_param.model_bytes = model_bytes
        model_param.header.extend(list(self.store_header))
        return model_param

    def set_model_meta(self, model_meta):
        self.config_type = model_meta.config_type
        self.nn_define = json.loads(model_meta.nn_define)
        self.batch_size = model_meta.batch_size
        self.epochs = model_meta.epochs
        self.tol = model_meta.tol
        self.optimizer = FTLParam()._parse_optimizer(FTLParam().optimizer)
        self.input_dim = model_meta.input_dim
        self.optimizer.optimizer = model_meta.optimizer_param.optimizer
        self.optimizer.kwargs = json.loads(model_meta.optimizer_param.kwargs)
        self.initialize_nn((self.input_dim,))
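
    # When loading a saved model, set_model_meta is expected to run before
    # set_model_param: it rebuilds the nn from nn_define and input_dim, and
    # set_model_param below then restores the trained weights into that nn
    # together with the stored header.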

    def set_model_param(self, model_param):
        self.nn.restore_model(model_param.model_bytes)
        self.store_header = list(model_param.header)
        LOGGER.debug('loaded stored header: {}'.format(self.store_header))