# fedavg_trainer.py
import torch
import torch as t
import tqdm
import numpy as np
from torch.utils.data import DataLoader
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient as SecureAggClient
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorServer as SecureAggServer
from federatedml.nn.dataset.base import Dataset
from federatedml.nn.homo.trainer.trainer_base import TrainerBase
from federatedml.util import LOGGER, consts
from federatedml.optim.convergence import converge_func_factory


class FedAVGTrainer(TrainerBase):
    """
    Parameters
    ----------
    epochs: int > 0, number of epochs to train
    batch_size: int, -1 means full batch
    secure_aggregate: bool, default is True, whether to use secure aggregation.
        If enabled, a random-number mask is added to each local model; these
        masks eventually cancel out to 0 during aggregation.
    weighted_aggregation: bool, whether to weight each local model when
        aggregating. If True, following the original paper, the weight of a
        client is n_local / n_global, where n_local is the local sample count
        and n_global is the sample count over all clients. If False, the
        models are simply averaged.
    early_stop: None, 'diff' or 'abs'. If None, early stopping is disabled;
        if 'diff', the loss difference between two epochs is the early-stop
        condition and training stops when the difference < tol; if 'abs',
        training stops when loss < tol.
    tol: float, tolerance value for early stopping
    aggregate_every_n_epoch: None or int. If None, aggregate the model at the
        end of every epoch; if int, aggregate every n epochs.
    cuda: bool, whether to use CUDA
    pin_memory: bool, for the PyTorch DataLoader
    shuffle: bool, for the PyTorch DataLoader
    data_loader_worker: int, for the PyTorch DataLoader, number of workers to
        use when loading data
    validation_freqs: None or int. If int, validate the model and send the
        validation results to fate-board every n epochs.
        Binary classification tasks use the metrics 'auc', 'ks', 'gain',
        'lift', 'precision'; multi-class classification tasks use
        'precision', 'recall', 'accuracy'; regression tasks use 'mse', 'mae',
        'rmse', 'explained_variance', 'r2_score'.
    checkpoint_save_freqs: save the model every n epochs; if None, no
        checkpoint is saved.
    task_type: str, 'auto', 'binary', 'multi' or 'regression'.
        This option decides the return format of this trainer and the
        evaluation type used when running validation. If 'auto', the task
        type is automatically inferred from the labels and prediction
        results.
    """

    def __init__(self, epochs=10, batch_size=512,  # training parameters
                 early_stop=None, tol=0.0001,  # early stop parameters
                 secure_aggregate=True, weighted_aggregation=True, aggregate_every_n_epoch=None,  # federation
                 cuda=False, pin_memory=True, shuffle=True, data_loader_worker=0,  # GPU & dataloader
                 validation_freqs=None,  # validation configuration
                 checkpoint_save_freqs=None,  # checkpoint configuration
                 task_type='auto'
                 ):

        super(FedAVGTrainer, self).__init__()

        # training parameters
        self.epochs = epochs
        self.tol = tol
        self.validation_freq = validation_freqs
        self.save_freq = checkpoint_save_freqs

        self.task_type = task_type
        task_type_allow = [
            consts.BINARY,
            consts.REGRESSION,
            consts.MULTY,
            'auto']
        assert self.task_type in task_type_allow, 'task type must be in {}'.format(
            task_type_allow)

        # aggregation parameters
        self.secure_aggregate = secure_aggregate
        self.weighted_aggregation = weighted_aggregation
        self.aggregate_every_n_epoch = aggregate_every_n_epoch

        # GPU
        self.cuda = cuda
        if not torch.cuda.is_available() and self.cuda:
            raise ValueError('CUDA is not available on this machine')

        # data loader
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.shuffle = shuffle
        self.data_loader_worker = data_loader_worker

        self.early_stop = early_stop
        early_stop_type = ['diff', 'abs']
        if early_stop is not None:
            assert early_stop in early_stop_type, 'early stop type must be in {}, but got {}' \
                .format(early_stop_type, early_stop)

        # communication suffix
        self.comm_suffix = 'fedavg'

        # check parameter correctness
        self.check_trainer_param([self.epochs,
                                  self.validation_freq,
                                  self.save_freq,
                                  self.aggregate_every_n_epoch],
                                 ['epochs',
                                  'validation_freq',
                                  'save_freq',
                                  'aggregate_every_n_epoch'],
                                 self.is_pos_int,
                                 '{} is not a positive int')
        self.check_trainer_param([self.secure_aggregate, self.weighted_aggregation, self.pin_memory], [
            'secure_aggregate', 'weighted_aggregation', 'pin_memory'], self.is_bool, '{} is not a bool')
        self.check_trainer_param(
            [self.tol], ['tol'], self.is_float, '{} is not a float')
    def train(
            self,
            train_set: Dataset,
            validate_set: Dataset = None,
            optimizer: t.optim.Optimizer = None,
            loss=None,
            extra_dict={}):

        if self.cuda:
            self.model = self.model.cuda()

        if optimizer is None:
            raise ValueError(
                'FedAVGTrainer requires an optimizer, but got None, please specify optimizer in the '
                'job configuration')
        if loss is None:
            raise ValueError(
                'FedAVGTrainer requires a loss function, but got None, please specify loss function in the'
                ' job configuration')

        if self.batch_size > len(train_set) or self.batch_size == -1:
            self.batch_size = len(train_set)
        dl = DataLoader(
            train_set,
            batch_size=self.batch_size,
            pin_memory=self.pin_memory,
            shuffle=self.shuffle,
            num_workers=self.data_loader_worker)
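        # for example (illustrative numbers, not from the source): with
        # batch_size=-1 and 1,000 training samples, the loader yields a
        # single full batch of 1,000 samples per epoch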

        # compute the number of aggregation rounds
        cur_agg_round = 0
        if self.aggregate_every_n_epoch is not None:
            aggregate_round = self.epochs // self.aggregate_every_n_epoch
        else:
            aggregate_round = self.epochs
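        # illustrative schedule (assumed numbers, not from the source): with
        # epochs=10 and aggregate_every_n_epoch=2, aggregation runs after
        # epochs 2, 4, 6, 8 and 10, so aggregate_round = 10 // 2 = 5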

        # initialize fed avg client
        if self.fed_mode:
            if self.weighted_aggregation:
                sample_num = len(train_set)
            else:
                sample_num = 1.0

            # honor the secure_aggregate flag instead of hard-coding True
            client_agg = SecureAggClient(
                self.secure_aggregate, aggregate_weight=sample_num, communicate_match_suffix=self.comm_suffix)
        else:
            client_agg = None

        # running vars
        cur_epoch = 0
        loss_history = []
        need_stop = False
        evaluation_summary = {}

        # training process
        for i in range(self.epochs):

            cur_epoch = i
            LOGGER.info('epoch is {}'.format(i))
            epoch_loss = 0.0
            batch_idx = 0
            acc_num = 0

            # for a better user interface, show a progress bar in local mode
            if not self.fed_mode:
                to_iterate = tqdm.tqdm(dl)
            else:
                to_iterate = dl

            for batch_data, batch_label in to_iterate:

                if self.cuda:
                    batch_data, batch_label = self.to_cuda(
                        batch_data), self.to_cuda(batch_label)

                optimizer.zero_grad()
                pred = self.model(batch_data)
                batch_loss = loss(pred, batch_label)
                batch_loss.backward()
                optimizer.step()

                batch_loss_np = batch_loss.detach().numpy(
                ) if not self.cuda else batch_loss.cpu().detach().numpy()
                if acc_num + self.batch_size > len(train_set):
                    batch_len = len(train_set) - acc_num
                else:
                    batch_len = self.batch_size
                epoch_loss += batch_loss_np * batch_len
                # track processed samples so the last (possibly smaller)
                # batch gets the correct length; the original never updated
                # acc_num
                acc_num += batch_len
                batch_idx += 1

                if self.fed_mode:
                    LOGGER.debug(
                        'epoch {} batch {} finished'.format(
                            i, batch_idx))
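            # note: epoch_loss accumulates batch_loss * batch_len, so the
            # division below yields a per-sample average and a smaller final
            # batch is not over-weighted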
            # loss compute
            epoch_loss = epoch_loss / len(train_set)
            self.callback_loss(epoch_loss, i)
            loss_history.append(float(epoch_loss))
            LOGGER.info('epoch loss is {}'.format(epoch_loss))

            # federation process; if running in local mode, skip federation
            if client_agg is not None:
                # aggregate unless an n-epoch interval is set and this epoch
                # is off-cycle
                if not (self.aggregate_every_n_epoch is not None and (i + 1) % self.aggregate_every_n_epoch != 0):
                    # model averaging
                    self.model = client_agg.model_aggregation(self.model)
                    # aggregate loss and get converge status
                    converge_status = client_agg.loss_aggregation(epoch_loss)
                    cur_agg_round += 1
                    LOGGER.info(
                        'model averaging finished, aggregate round {}/{}'.format(
                            cur_agg_round, aggregate_round))
                    if converge_status:
                        LOGGER.info('early stop triggered, stop training')
                        need_stop = True

            # validation process
            if self.validation_freq and ((i + 1) % self.validation_freq == 0):
                LOGGER.info('running validation')
                ids_t, pred_t, label_t = self._predict(train_set)
                evaluation_summary = self.evaluation(
                    ids_t,
                    pred_t,
                    label_t,
                    dataset_type='train',
                    epoch_idx=i,
                    task_type=self.task_type)
                if validate_set is not None:
                    ids_v, pred_v, label_v = self._predict(validate_set)
                    evaluation_summary = self.evaluation(
                        ids_v,
                        pred_v,
                        label_v,
                        dataset_type='validate',
                        epoch_idx=i,
                        task_type=self.task_type)

            # save checkpoint process
            if self.save_freq is not None and ((i + 1) % self.save_freq == 0):
                self.checkpoint(
                    i, self.model, optimizer, converge_status=need_stop, loss_history=loss_history)
                LOGGER.info('save checkpoint : epoch {}'.format(i))

            # if the stop condition is met, stop training
            if need_stop:
                break

        # post-process
        best_epoch = int(np.array(loss_history).argmin())
        self.save(model=self.model, optimizer=optimizer, epoch_idx=cur_epoch, loss_history=loss_history,
                  converge_status=need_stop, best_epoch=best_epoch)
        self.summary({
            'best_epoch': best_epoch,
            'loss_history': loss_history,
            'need_stop': need_stop,
            'metrics_summary': evaluation_summary
        })

    def _predict(self, dataset: Dataset):

        pred_result = []

        # switch to eval mode
        dataset.eval()
        self.model.eval()

        if not dataset.has_sample_ids():
            dataset.init_sid_and_getfunc(prefix=dataset.get_type())

        labels = []
        with torch.no_grad():
            for batch_data, batch_label in DataLoader(
                    dataset, self.batch_size):
                if self.cuda:
                    batch_data = self.to_cuda(batch_data)
                pred = self.model(batch_data)
                pred_result.append(pred)
                labels.append(batch_label)

            ret_rs = torch.cat(pred_result, dim=0)
            ret_label = torch.cat(labels, dim=0)

        # switch back to train mode
        dataset.train()
        self.model.train()

        return dataset.get_sample_ids(), ret_rs, ret_label

    def predict(self, dataset: Dataset):

        ids, ret_rs, ret_label = self._predict(dataset)

        if self.fed_mode:
            return self.format_predict_result(
                ids, ret_rs, ret_label, task_type=self.task_type)
        else:
            return ret_rs, ret_label

    def server_aggregate_procedure(self, extra_data={}):

        # converge status
        check_converge = False
        converge_func = None
        if self.early_stop:
            check_converge = True
            converge_func = converge_func_factory(
                self.early_stop, self.tol).is_converge
            LOGGER.info(
                'check early stop, converge func is {}'.format(converge_func))

        LOGGER.info('server running aggregate procedure')
        # honor the secure_aggregate flag instead of hard-coding True
        server_agg = SecureAggServer(self.secure_aggregate, communicate_match_suffix=self.comm_suffix)

        # aggregate and broadcast models
        for i in range(self.epochs):
            # aggregate unless an n-epoch interval is set and this epoch is
            # off-cycle; this schedule must stay in sync with the client side
            if not (self.aggregate_every_n_epoch is not None and (i + 1) % self.aggregate_every_n_epoch != 0):

                # model aggregation
                server_agg.model_aggregation()
                converge_status = False

                # loss aggregation
                agg_loss, converge_status = server_agg.loss_aggregation(
                    check_converge=check_converge, converge_func=converge_func)
                self.callback_loss(agg_loss, i)

                # save checkpoint process
                if self.save_freq is not None and ((i + 1) % self.save_freq == 0):
                    self.checkpoint(epoch_idx=i)
                    LOGGER.info('save checkpoint : epoch {}'.format(i))

                # check stop condition
                if converge_status:
                    LOGGER.debug('stop triggered, stop aggregation')
                    break

        LOGGER.info('server aggregation process done')
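

# ---------------------------------------------------------------------------
# Minimal local usage sketch (editor's illustration, not part of the original
# module). It assumes TrainerBase exposes set_model() and local_mode(); the
# dataset construction is elided because it depends on your FATE deployment.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch import nn

    # toy binary classifier (hypothetical dimensions)
    model = nn.Sequential(nn.Linear(10, 1), nn.Sigmoid())

    trainer = FedAVGTrainer(epochs=3, batch_size=64, task_type='binary')
    trainer.set_model(model)  # assumed TrainerBase helper
    trainer.local_mode()      # assumed: disables federation for local runs

    # with a federatedml Dataset instance `train_ds`, training would look like:
    # trainer.train(train_ds,
    #               optimizer=t.optim.Adam(model.parameters(), lr=0.01),
    #               loss=nn.BCELoss())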