123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351 |
- import torch
- import torch as t
- import tqdm
- import numpy as np
- from torch.utils.data import DataLoader
- from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient as SecureAggClient
- from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorServer as SecureAggServer
- from federatedml.nn.dataset.base import Dataset
- from federatedml.nn.homo.trainer.trainer_base import TrainerBase
- from federatedml.util import LOGGER, consts
- from federatedml.optim.convergence import converge_func_factory
class FedAVGTrainer(TrainerBase):
    """
    Federated-averaging trainer: trains a torch model locally and, in federated
    mode, periodically averages it with the other parties through the secure
    aggregator.

    Parameters
    ----------
    epochs: int >0, epochs to train
    batch_size: int, -1 means full batch
    secure_aggregate: bool, default is True, whether to use secure aggregation. if enabled, will add random number
                      mask to local models. These random number masks will eventually cancel out to get 0.
    weighted_aggregation: bool, whether add weight to each local model when doing aggregation.
                          if True, according to the original paper, weight of a client is: n_local / n_global,
                          where n_local is the sample number locally and n_global is the sample number of
                          all clients.
                          if False, simply averaging these models.
    early_stop: None, 'diff' or 'abs'. if None, disable early stop; if 'diff', use the loss difference between
                two epochs as early stop condition, if differences < tol, stop training; if 'abs', if loss < tol,
                stop training
    tol: float, tol value for early stop
    aggregate_every_n_epoch: None or int. if None, aggregate model at the end of every epoch, if int, aggregate
                             every n epochs.
    cuda: bool, use cuda or not
    pin_memory: bool, for pytorch DataLoader
    shuffle: bool, for pytorch DataLoader
    data_loader_worker: int, for pytorch DataLoader, number of workers when loading data
    validation_freqs: None or int. if int, validate your model and send validate results to fate-board every n epoch.
                      if is binary classification task, will use metrics 'auc', 'ks', 'gain', 'lift', 'precision'
                      if is multi classification task, will use metrics 'precision', 'recall', 'accuracy'
                      if is regression task, will use metrics 'mse', 'mae', 'rmse', 'explained_variance', 'r2_score'
    checkpoint_save_freqs: save model every n epoch, if None, will not save checkpoint.
    task_type: str, 'auto', 'binary', 'multi', 'regression'
               this option decides the return format of this trainer, and the evaluation type when running
               validation. if auto, will automatically infer your task type from labels and predict results.
    """

    def __init__(self, epochs=10, batch_size=512,
                 early_stop=None, tol=0.0001,
                 secure_aggregate=True, weighted_aggregation=True, aggregate_every_n_epoch=None,
                 cuda=False, pin_memory=True, shuffle=True, data_loader_worker=0,
                 validation_freqs=None,
                 checkpoint_save_freqs=None,
                 task_type='auto'
                 ):
        super(FedAVGTrainer, self).__init__()

        # training schedule
        self.epochs = epochs
        self.tol = tol
        self.validation_freq = validation_freqs
        self.save_freq = checkpoint_save_freqs

        # task type drives evaluation and the prediction output format
        self.task_type = task_type
        task_type_allow = [
            consts.BINARY,
            consts.REGRESSION,
            consts.MULTY,
            'auto']
        assert self.task_type in task_type_allow, 'task type must in {}'.format(
            task_type_allow)

        # aggregation settings
        self.secure_aggregate = secure_aggregate
        self.weighted_aggregation = weighted_aggregation
        self.aggregate_every_n_epoch = aggregate_every_n_epoch

        # device settings: fail fast when cuda is requested but unavailable
        self.cuda = cuda
        if not torch.cuda.is_available() and self.cuda:
            raise ValueError('Cuda is not available on this machine')

        # DataLoader settings
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.shuffle = shuffle
        self.data_loader_worker = data_loader_worker

        # early stop settings
        self.early_stop = early_stop
        early_stop_type = ['diff', 'abs']
        if early_stop is not None:
            assert early_stop in early_stop_type, 'early stop type must be in {}, but got {}' \
                .format(early_stop_type, early_stop)

        # suffix handed to the secure aggregator client and server so their
        # messages can be matched
        self.comm_suffix = 'fedavg'

        # parameter validation
        self.check_trainer_param([self.epochs,
                                  self.validation_freq,
                                  self.save_freq,
                                  self.aggregate_every_n_epoch],
                                 ['epochs',
                                  'validation_freq',
                                  'save_freq',
                                  'aggregate_every_n_epoch'],
                                 self.is_pos_int,
                                 '{} is not a positive int')
        self.check_trainer_param(
            [self.secure_aggregate, self.weighted_aggregation, self.pin_memory],
            ['secure_aggregate', 'weighted_aggregation', 'pin_memory'],
            self.is_bool,
            '{} is not a bool')
        self.check_trainer_param(
            [self.tol], ['tol'], self.is_float, '{} is not a float')

    def _is_aggregation_epoch(self, epoch_idx):
        """Return True if models are aggregated at the end of epoch ``epoch_idx`` (0-based).

        Shared by the client-side and the server-side loops so that both sides
        agree on which epochs take part in aggregation.
        """
        if self.aggregate_every_n_epoch is None:
            return True
        return (epoch_idx + 1) % self.aggregate_every_n_epoch == 0

    def _train_one_epoch(self, epoch_idx, data_loader, optimizer, loss, sample_count):
        """Run one optimisation pass over ``data_loader``.

        Returns the sum of ``batch_loss * batch_len`` over all batches divided
        by ``sample_count``.
        """
        epoch_loss = 0.0
        seen = 0  # number of samples consumed so far in this epoch

        # progress bar only in local mode, for a better user interface
        batches = data_loader if self.fed_mode else tqdm.tqdm(data_loader)

        for batch_idx, (batch_data, batch_label) in enumerate(batches, start=1):
            if self.cuda:
                batch_data = self.to_cuda(batch_data)
                batch_label = self.to_cuda(batch_label)

            optimizer.zero_grad()
            pred = self.model(batch_data)
            batch_loss = loss(pred, batch_label)
            batch_loss.backward()
            optimizer.step()

            # .cpu() is a no-op for CPU tensors, so one expression covers both devices
            batch_loss_np = batch_loss.detach().cpu().numpy()

            # The last batch may be shorter than batch_size; the running count
            # must be advanced for that to be detected.
            # NOTE(review): weighting by batch length assumes `loss` returns a
            # per-batch mean - confirm for custom loss functions.
            batch_len = min(self.batch_size, sample_count - seen)
            seen += batch_len
            epoch_loss += batch_loss_np * batch_len

            if self.fed_mode:
                LOGGER.debug(
                    'epoch {} batch {} finished'.format(
                        epoch_idx, batch_idx))

        return epoch_loss / sample_count

    def _run_validation(self, epoch_idx, train_set, validate_set):
        """Evaluate on the train set and, if given, on the validate set.

        Returns the result of the last ``self.evaluation`` call.
        """
        ids_t, pred_t, label_t = self._predict(train_set)
        evaluation_summary = self.evaluation(
            ids_t,
            pred_t,
            label_t,
            dataset_type='train',
            epoch_idx=epoch_idx,
            task_type=self.task_type)
        if validate_set is not None:
            ids_v, pred_v, label_v = self._predict(validate_set)
            evaluation_summary = self.evaluation(
                ids_v,
                pred_v,
                label_v,
                dataset_type='validate',
                epoch_idx=epoch_idx,
                task_type=self.task_type)
        return evaluation_summary

    def train(
            self,
            train_set: Dataset,
            validate_set: Dataset = None,
            optimizer: t.optim.Optimizer = None,
            loss=None,
            extra_dict=None):
        """Train ``self.model`` on ``train_set``; aggregate with other parties in fed mode.

        Parameters
        ----------
        train_set: Dataset, training data
        validate_set: Dataset or None, evaluated every ``validation_freq`` epochs when given
        optimizer: torch optimizer, required
        loss: callable ``loss(pred, label)`` returning a tensor, required
        extra_dict: unused by this trainer; defaults to None instead of a shared mutable dict

        Raises
        ------
        ValueError: if ``optimizer`` or ``loss`` is None
        """
        if self.cuda:
            self.model = self.model.cuda()

        if optimizer is None:
            raise ValueError(
                'FedAVGTrainer requires an optimizer, but got None, please specify optimizer in the '
                'job configuration')
        if loss is None:
            raise ValueError(
                'FedAVGTrainer requires a loss function, but got None, please specify loss function in the'
                ' job configuration')

        sample_count = len(train_set)

        # -1 means full batch; a batch larger than the data set is clamped as well
        if self.batch_size > sample_count or self.batch_size == -1:
            self.batch_size = sample_count
        dl = DataLoader(
            train_set,
            batch_size=self.batch_size,
            pin_memory=self.pin_memory,
            shuffle=self.shuffle,
            num_workers=self.data_loader_worker)

        # total number of aggregation rounds, used for logging only
        cur_agg_round = 0
        if self.aggregate_every_n_epoch is not None:
            aggregate_round = self.epochs // self.aggregate_every_n_epoch
        else:
            aggregate_round = self.epochs

        # initialise the aggregation client; in local mode there is no federation
        if self.fed_mode:
            sample_num = sample_count if self.weighted_aggregation else 1.0
            client_agg = SecureAggClient(
                True, aggregate_weight=sample_num, communicate_match_suffix=self.comm_suffix)
        else:
            client_agg = None

        # running variables
        cur_epoch = 0
        loss_history = []
        need_stop = False
        evaluation_summary = {}

        for i in range(self.epochs):
            cur_epoch = i
            LOGGER.info('epoch is {}'.format(i))

            epoch_loss = self._train_one_epoch(i, dl, optimizer, loss, sample_count)
            self.callback_loss(epoch_loss, i)
            loss_history.append(float(epoch_loss))
            LOGGER.info('epoch loss is {}'.format(epoch_loss))

            # federation process, skipped in local mode
            if client_agg is not None and self._is_aggregation_epoch(i):
                # model averaging
                self.model = client_agg.model_aggregation(self.model)

                # aggregate loss and get the converge status
                converge_status = client_agg.loss_aggregation(epoch_loss)
                cur_agg_round += 1
                LOGGER.info(
                    'model averaging finished, aggregate round {}/{}'.format(
                        cur_agg_round, aggregate_round))
                if converge_status:
                    LOGGER.info('early stop triggered, stop training')
                    need_stop = True

            # validation process
            if self.validation_freq and ((i + 1) % self.validation_freq == 0):
                LOGGER.info('running validation')
                evaluation_summary = self._run_validation(i, train_set, validate_set)

            # checkpoint process
            if self.save_freq is not None and ((i + 1) % self.save_freq == 0):
                self.checkpoint(
                    i, self.model, optimizer, converge_status=need_stop, loss_history=loss_history)
                LOGGER.info('save checkpoint : epoch {}'.format(i))

            if need_stop:
                break

        # post-process: persist the final model and report a summary
        best_epoch = int(np.array(loss_history).argmin())
        self.save(model=self.model, optimizer=optimizer, epoch_idx=cur_epoch, loss_history=loss_history,
                  converge_status=need_stop, best_epoch=best_epoch)
        self.summary({
            'best_epoch': best_epoch,
            'loss_history': loss_history,
            'need_stop': need_stop,
            'metrics_summary': evaluation_summary
        })

    def _predict(self, dataset: Dataset):
        """Run the model over ``dataset`` without gradients.

        Returns ``(sample_ids, predictions, labels)``; predictions and labels
        are the per-batch tensors concatenated along dimension 0. The dataset
        and the model are put in eval mode for the pass and switched back to
        train mode afterwards.
        """
        pred_result = []
        labels = []

        dataset.eval()
        self.model.eval()

        if not dataset.has_sample_ids():
            dataset.init_sid_and_getfunc(prefix=dataset.get_type())

        # -1 means full batch; DataLoader itself rejects a non-positive batch
        # size, which matters when predict() runs before train() resolved it
        batch_size = len(dataset) if self.batch_size == -1 else self.batch_size

        with torch.no_grad():
            for batch_data, batch_label in DataLoader(dataset, batch_size):
                if self.cuda:
                    batch_data = self.to_cuda(batch_data)
                pred = self.model(batch_data)
                pred_result.append(pred)
                labels.append(batch_label)

            # torch.cat is available in every torch version; torch.concat is a newer alias
            ret_rs = torch.cat(pred_result, dim=0)
            ret_label = torch.cat(labels, dim=0)

        # switch back to train mode
        dataset.train()
        self.model.train()

        return dataset.get_sample_ids(), ret_rs, ret_label

    def predict(self, dataset: Dataset):
        """Predict on ``dataset``.

        In fed mode returns ``self.format_predict_result(...)``; in local mode
        returns the raw ``(predictions, labels)`` pair.
        """
        ids, ret_rs, ret_label = self._predict(dataset)
        if self.fed_mode:
            return self.format_predict_result(
                ids, ret_rs, ret_label, task_type=self.task_type)
        return ret_rs, ret_label

    def server_aggregate_procedure(self, extra_data=None):
        """Server-side loop: aggregate models and losses on every aggregation epoch.

        Stops early when the loss aggregation reports convergence.
        ``extra_data`` is unused by this trainer; it defaults to None instead
        of a shared mutable dict.
        """
        # convergence checking is enabled only when early stop is configured
        check_converge = False
        converge_func = None
        if self.early_stop:
            check_converge = True
            converge_func = converge_func_factory(
                self.early_stop, self.tol).is_converge
            LOGGER.info(
                'check early stop, converge func is {}'.format(converge_func))

        LOGGER.info('server running aggregate procedure')
        server_agg = SecureAggServer(True, communicate_match_suffix=self.comm_suffix)

        converge_status = False
        for i in range(self.epochs):
            # epochs without aggregation need no server-side work
            if not self._is_aggregation_epoch(i):
                continue

            # model aggregate
            server_agg.model_aggregation()

            # loss aggregate
            agg_loss, converge_status = server_agg.loss_aggregation(
                check_converge=check_converge, converge_func=converge_func)
            self.callback_loss(agg_loss, i)

            # checkpoint process
            if self.save_freq is not None and ((i + 1) % self.save_freq == 0):
                self.checkpoint(epoch_idx=i)
                LOGGER.info('save checkpoint : epoch {}'.format(i))

            # check stop condition
            if converge_status:
                LOGGER.debug('stop triggered, stop aggregation')
                break

        LOGGER.info('server aggregation process done')
|