# trainer.py
  1. import copy
  2. import logging
  3. import time
  4. from collections import defaultdict
  5. import scipy.stats
  6. import torch
  7. from utils import AverageMeter
  8. from easyfl.distributed.distributed import CPU
# Module-level logger for client-side training diagnostics.
logger = logging.getLogger(__name__)

# Supported learning-rate schedule identifiers, read from conf.optimizer.lr_type.
LR_POLY = "poly"      # polynomial decay; adjusted before training, not by this module
LR_CUSTOM = "custom"  # loss-plateau driven decay (see Trainer.learning_rate_schedule)
  12. class Trainer:
  13. def __init__(self, cid, conf, train_loader, model, optimizer, criteria, device=CPU, checkpoint=None):
  14. self.cid = cid
  15. self.conf = conf
  16. self.train_loader = train_loader
  17. self.model = model
  18. self.optimizer = optimizer
  19. self.criteria = criteria
  20. self.loss_keys = list(self.criteria.keys())[1:]
  21. self.device = device
  22. # self.args = args
  23. self.progress_table = []
  24. # self.best_loss = 9e9
  25. self.stats = []
  26. self.start_epoch = 0
  27. self.loss_history = []
  28. self.encoder_trainable = None
  29. # self.code_archive = self.get_code_archive()
  30. if checkpoint:
  31. if 'progress_table' in checkpoint:
  32. self.progress_table = checkpoint['progress_table']
  33. if 'epoch' in checkpoint:
  34. self.start_epoch = checkpoint['epoch'] + 1
  35. # if 'best_loss' in checkpoint:
  36. # self.best_loss = checkpoint['best_loss']
  37. if 'stats' in checkpoint:
  38. self.stats = checkpoint['stats']
  39. if 'loss_history' in checkpoint:
  40. self.loss_history = checkpoint['loss_history']
  41. self.lr0 = self.conf.optimizer.lr
  42. self.lr = self.lr0
  43. self.ticks = 0
  44. self.last_tick = 0
  45. # self.loss_tracking_window = self.conf.loss_tracking_window_initial
  46. # estimated loss tracking window for each client, based on their dataset size, compared with original implementation.
  47. if self.conf.optimizer.lr_type == LR_CUSTOM:
  48. self.loss_tracking_window = len(train_loader) * self.conf.batch_size / 8
  49. self.maximum_loss_tracking_window = len(train_loader) * self.conf.batch_size / 2
  50. logger.info(
  51. f"Client {self.cid}: loss_tracking_window: {self.loss_tracking_window}, maximum_loss_tracking_window: {self.maximum_loss_tracking_window}")
  52. def train(self):
  53. self.encoder_trainable = [
  54. p for p in self.model.encoder.parameters() if p.requires_grad
  55. ]
  56. transference = {combined_task: [] for combined_task in self.loss_keys}
  57. for self.epoch in range(self.start_epoch, self.conf.local_epoch):
  58. current_learning_rate = get_average_learning_rate(self.optimizer)
  59. # Stop training when learning rate is smaller than minimum learning rate
  60. if current_learning_rate < self.conf.minimum_learning_rate:
  61. logger.info(f"Client {self.cid} stop local training because lr too small, lr: {current_learning_rate}.")
  62. break
  63. # Train for one epoch
  64. train_string, train_stats, epoch_transference = self.train_epoch()
  65. self.progress_table.append(train_string)
  66. self.stats.append(train_stats)
  67. for combined_task in self.loss_keys:
  68. transference[combined_task].append(epoch_transference[combined_task])
  69. # # evaluate on validation set
  70. # progress_string = train_string
  71. # loss, progress_string, val_stats = self.validate(progress_string)
  72. #
  73. # self.progress_table.append(progress_string)
  74. # self.stats.append((train_stats, val_stats))
  75. # Clean up to save memory
  76. del self.encoder_trainable
  77. self.encoder_trainable = None
  78. return transference
    def train_epoch(self):
        """Train for a single local epoch.

        Returns a tuple of:
          * a one-element list holding the final progress row (dict of strings),
          * a stats dict (batch count, final lr, epoch time, per-criterion averages),
          * the accumulated task-transference matrix for this epoch.
        """
        average_meters = defaultdict(AverageMeter)
        display_values = []
        for name, func in self.criteria.items():
            display_values.append(name)
        # Switch to train mode
        self.model.train()
        epoch_start_time = time.time()
        epoch_start_time2 = time.time()
        batch_num = 0
        # Number of effective (virtual) batches in this epoch.
        num_data_points = len(self.train_loader) // self.conf.virtual_batch_multiplier
        if num_data_points > 10000:
            num_data_points = num_data_points // 5
        starting_learning_rate = get_average_learning_rate(self.optimizer)
        # Initialize the task-affinity (transference) matrix:
        # source task -> recipient task -> accumulated transference.
        epoch_transference = {}
        for combined_task in self.loss_keys:
            epoch_transference[combined_task] = {}
            for recipient_task in self.loss_keys:
                epoch_transference[combined_task][recipient_task] = 0.
        # NOTE(review): `input` shadows the builtin; kept as-is for fidelity.
        for i, (input, target) in enumerate(self.train_loader):
            input = input.to(self.device)
            for n, t in target.items():
                target[n] = t.to(self.device)
            # self.percent = batch_num / num_data_points
            if i == 0:
                # Restart the clock after the first batch's warm-up cost.
                epoch_start_time2 = time.time()
            loss_dict = None
            loss = 0
            self.optimizer.zero_grad()
            # Run the (expensive) lookahead transference measurement only
            # every `lookahead_step` batches, and only when enabled ('y').
            _train_batch_lookahead = self.conf.lookahead == 'y' and i % self.conf.lookahead_step == 0
            # Accumulate gradients over multiple runs of input
            for _ in range(self.conf.virtual_batch_multiplier):
                data_start = time.time()
                # NOTE(review): this measures ~0 seconds — data loading happens
                # in the enumerate() above, not between these two lines. The
                # 'd%' column below is therefore effectively meaningless.
                average_meters['data_time'].update(time.time() - data_start)
                # lookahead step 10
                if _train_batch_lookahead:
                    loss_dict2, loss2, batch_transference = self.train_batch_lookahead(input, target)
                else:
                    loss_dict2, loss2, batch_transference = self.train_batch(input, target)
                loss += loss2
                if loss_dict is None:
                    loss_dict = loss_dict2
                else:
                    for key, value in loss_dict2.items():
                        loss_dict[key] += value
                if _train_batch_lookahead:
                    # Average each batch's transference over the number of
                    # lookahead measurements taken across the epoch.
                    for combined_task in self.loss_keys:
                        for recipient_task in self.loss_keys:
                            epoch_transference[combined_task][recipient_task] += (
                                batch_transference[combined_task][recipient_task] / (len(self.train_loader) / self.conf.lookahead_step))
            # divide by the number of accumulations
            loss /= self.conf.virtual_batch_multiplier
            for key, value in loss_dict.items():
                loss_dict[key] = value / self.conf.virtual_batch_multiplier
            # do the weight updates and set gradients back to zero
            self.optimizer.step()
            self.loss_history.append(float(loss))
            # Possibly decay the lr if the loss has plateaued.
            ttest_p, z_diff = self.learning_rate_schedule()
            for name, value in loss_dict.items():
                try:
                    average_meters[name].update(value.data)
                # NOTE(review): bare except — presumably meant to catch
                # AttributeError for plain-float losses; narrows worth confirming.
                except:
                    average_meters[name].update(value)
            elapsed_time_for_epoch = (time.time() - epoch_start_time2)
            # ETA estimate; +.2 avoids division by zero on the first batch.
            eta = (elapsed_time_for_epoch / (batch_num + .2)) * (num_data_points - batch_num)
            if eta >= 24 * 3600:
                eta = 24 * 3600 - 1
            batch_num += 1
            current_learning_rate = get_average_learning_rate(self.optimizer)
            # Progress row for this batch (the last one is returned).
            to_print = {
                'ep': f'{self.epoch}:',
                f'#/{num_data_points}': f'{batch_num}',
                'lr': '{0:0.3g}-{1:0.3g}'.format(starting_learning_rate, current_learning_rate),
                'eta': '{0}'.format(time.strftime("%H:%M:%S", time.gmtime(int(eta)))),
                'd%': '{0:0.2g}'.format(100 * average_meters['data_time'].sum / elapsed_time_for_epoch)
            }
            for name in display_values:
                meter = average_meters[name]
                to_print[name] = '{meter.avg:.4f}'.format(meter=meter)
            if batch_num < num_data_points - 1:
                to_print['ETA'] = '{0}'.format(
                    time.strftime("%H:%M:%S", time.gmtime(int(eta + elapsed_time_for_epoch))))
            to_print['ttest'] = '{0:0.3g},{1:0.3g}'.format(z_diff, ttest_p)
        epoch_time = time.time() - epoch_start_time
        stats = {
            'batches': num_data_points,
            'learning_rate': current_learning_rate,
            'epoch_time': epoch_time,
        }
        for name in display_values:
            meter = average_meters[name]
            stats[name] = meter.avg
        # Overwrite the per-batch ETA with the actual epoch duration.
        to_print['eta'] = '{0}'.format(time.strftime("%H:%M:%S", time.gmtime(int(epoch_time))))
        logger.info(f"Client {self.cid} training statistics: {stats}")
        return [to_print], stats, epoch_transference
  175. def train_batch(self, x, target):
  176. loss_dict = {}
  177. x = x.float()
  178. output = self.model(x)
  179. first_loss = None
  180. for c_name, criterion_fn in self.criteria.items():
  181. if first_loss is None:
  182. first_loss = c_name
  183. loss_dict[c_name] = criterion_fn(output, target)
  184. loss = loss_dict[first_loss].clone()
  185. loss = loss / self.conf.virtual_batch_multiplier
  186. if self.conf.fp16:
  187. from apex import amp
  188. with amp.scale_loss(loss, self.optimizer) as scaled_loss:
  189. scaled_loss.backward()
  190. else:
  191. loss.backward()
  192. return loss_dict, loss, {}
    def train_batch_lookahead(self, x, target):
        """Like train_batch, but additionally estimates task transference.

        For each task loss, simulates a one-step encoder update (via
        `lookahead`), re-evaluates every criterion on the post-step output,
        and records the relative loss improvement per unit learning rate.

        Returns:
            (loss_dict, first criterion's loss, transference matrix keyed
            source_task -> recipient_task).
        """
        loss_dict = {}
        x = x.float()
        output = self.model(x)
        first_loss = None
        for c_name, criterion_fun in self.criteria.items():
            if first_loss is None:
                first_loss = c_name
            loss_dict[c_name] = criterion_fun(output, target)
        loss = loss_dict[first_loss].clone()
        transference = {}
        for combined_task in self.loss_keys:
            transference[combined_task] = {}
        if self.conf.fp16:
            # NOTE(review): in the fp16/apex path no transference is computed;
            # the returned matrix is empty.
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            for combined_task in self.loss_keys:
                # Hypothetical one-step update driven by this task's loss.
                preds = self.lookahead(x, loss_dict[combined_task])
                first_loss = None
                for c_name, criterion_fun in self.criteria.items():
                    if first_loss is None:
                        first_loss = c_name
                    # (1 - new_loss/old_loss) / lr: positive when the step on
                    # `combined_task` also reduced `c_name`'s loss.
                    transference[combined_task][c_name] = (
                        (1.0 - (criterion_fun(preds, target) / loss_dict[c_name])) /
                        self.optimizer.state_dict()['param_groups'][0]['lr']
                    ).detach().cpu().numpy()
            self.optimizer.zero_grad()
            loss.backward()
        # Want to invert the dictionary so it's source_task => gradients on source task.
        rev_transference = {source: {} for source in transference}
        for grad_task in transference:
            for source in transference[grad_task]:
                if 'Loss' in source:
                    # Skip entries whose name contains 'Loss' — presumably the
                    # combined/aggregate loss; confirm against criteria naming.
                    continue
                rev_transference[source][grad_task] = transference[grad_task][
                    source]
        return loss_dict, loss, copy.deepcopy(rev_transference)
    def lookahead(self, x, loss):
        """Simulate one SGD step on the shared encoder and return the model's
        output after that hypothetical update; encoder weights are restored
        before returning.

        The update mimics SGD with momentum and weight decay using the
        optimizer's hyper-parameters; gradients are taken only w.r.t. the
        trainable encoder parameters.
        """
        self.optimizer.zero_grad()
        shared_params = self.encoder_trainable
        # Keep references to the current weight tensors; the update below
        # rebinds param.data to a NEW tensor, so these references survive
        # and are used to restore the weights at the end.
        init_weights = [param.data for param in shared_params]
        grads = torch.autograd.grad(loss, shared_params, retain_graph=True)
        # Compute updated params for the forward pass: SGD w/ 0.9 momentum + 1e-4 weight decay.
        opt_state = self.optimizer.state_dict()['param_groups'][0]
        weight_decay = opt_state['weight_decay']
        # NOTE(review): `param_id` is never used, and 'momentum_buffer' is
        # looked up on the param-group dict — torch keeps momentum buffers in
        # optimizer.state keyed per parameter, not in param_groups — so the
        # else-branch appears unreachable as written. Confirm intent.
        for param, g, param_id in zip(shared_params, grads, opt_state['params']):
            grad = g.clone()
            grad += param * weight_decay
            if 'momentum_buffer' not in opt_state:
                mom_buf = grad
            else:
                mom_buf = opt_state['momentum_buffer']
                mom_buf = mom_buf * opt_state['momentum'] + grad
            param.data = param.data - opt_state['lr'] * mom_buf
            grad = grad.cpu()
            del grad
        # Forward pass with the hypothetically-updated encoder weights.
        with torch.no_grad():
            output = self.model(x)
        # Restore the original encoder weights.
        for param, init_weight in zip(shared_params, init_weights):
            param.data = init_weight
        return output
  256. def learning_rate_schedule(self):
  257. # don't process learning rate if the schedule type is poly, which adjusted before training.
  258. if self.conf.optimizer.lr_type == LR_POLY:
  259. return 0, 0
  260. # don't reduce learning rate until the second epoch has ended.
  261. if self.epoch < 2:
  262. return 0, 0
  263. ttest_p = 0
  264. z_diff = 0
  265. wind = self.loss_tracking_window // (self.conf.batch_size * self.conf.virtual_batch_multiplier)
  266. if len(self.loss_history) - self.last_tick > wind:
  267. a = self.loss_history[-wind:-wind * 5 // 8]
  268. b = self.loss_history[-wind * 3 // 8:]
  269. # remove outliers
  270. a = sorted(a)
  271. b = sorted(b)
  272. a = a[int(len(a) * .05):int(len(a) * .95)]
  273. b = b[int(len(b) * .05):int(len(b) * .95)]
  274. length_ = min(len(a), len(b))
  275. a = a[:length_]
  276. b = b[:length_]
  277. z_diff, ttest_p = scipy.stats.ttest_rel(a, b, nan_policy='omit')
  278. if z_diff < 0 or ttest_p > .99:
  279. self.ticks += 1
  280. self.last_tick = len(self.loss_history)
  281. self.adjust_learning_rate()
  282. self.loss_tracking_window = min(self.maximum_loss_tracking_window, self.loss_tracking_window * 2)
  283. return ttest_p, z_diff
  284. def adjust_learning_rate(self):
  285. self.lr = self.lr0 * (0.50 ** self.ticks)
  286. self.set_learning_rate(self.lr)
  287. def set_learning_rate(self, lr):
  288. for param_group in self.optimizer.param_groups:
  289. param_group['lr'] = lr
  290. def update(self, model, optimizer, device):
  291. self.model = model
  292. self.optimizer = optimizer
  293. self.device = device
  294. def get_average_learning_rate(optimizer):
  295. try:
  296. return optimizer.learning_rate
  297. except:
  298. s = 0
  299. for param_group in optimizer.param_groups:
  300. s += param_group['lr']
  301. return s / len(optimizer.param_groups)