123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364 |
- import copy
- import logging
- import time
- from collections import defaultdict
- import scipy.stats
- import torch
- from utils import AverageMeter
- from easyfl.distributed.distributed import CPU
# Logger shared by every trainer in this module.
logger = logging.getLogger(__name__)

# Values compared against ``conf.optimizer.lr_type``:
#   "poly"   -> Trainer.learning_rate_schedule leaves the learning rate alone.
#   "custom" -> Trainer sets up loss-tracking windows and decays the lr itself.
LR_POLY, LR_CUSTOM = "poly", "custom"
class Trainer:
    """Client-side trainer for a multi-task model with optional transference measurement.

    The first entry of ``criteria`` is the loss that is back-propagated; the
    remaining entries (``self.loss_keys``) are per-task losses. When
    ``conf.lookahead == 'y'``, every ``conf.lookahead_step``-th batch also
    simulates one optimizer step on the shared encoder per task and records how
    that step changes every criterion ("transference").
    """

    def __init__(self, cid, conf, train_loader, model, optimizer, criteria, device=CPU, checkpoint=None):
        """Set up the trainer.

        Args:
            cid: Client id, used in log messages.
            conf: Training config; this class reads ``optimizer.lr``,
                ``optimizer.lr_type``, ``batch_size``, ``local_epoch``,
                ``minimum_learning_rate``, ``virtual_batch_multiplier``,
                ``lookahead``, ``lookahead_step`` and ``fp16``.
            train_loader: Sized iterable of ``(input, target)`` batches; ``target``
                is a dict whose values are moved to ``device``.
            model: Model with an ``encoder`` submodule (the shared parameters).
            optimizer: Torch optimizer over the model parameters.
            criteria: Ordered mapping ``name -> fn(output, target)``; the first
                entry is the loss that is optimized.
            device: Where inputs and targets are moved before the forward pass.
            checkpoint: Optional mapping that may restore ``progress_table``,
                ``epoch`` (training resumes at ``epoch + 1``), ``stats`` and
                ``loss_history``.
        """
        self.cid = cid
        self.conf = conf
        self.train_loader = train_loader
        self.model = model
        self.optimizer = optimizer
        self.criteria = criteria
        # Every criterion except the first one (the optimized loss).
        self.loss_keys = list(self.criteria.keys())[1:]
        self.device = device

        self.progress_table = []
        self.stats = []
        self.start_epoch = 0
        self.loss_history = []
        # Trainable encoder parameters; only populated while train() runs.
        self.encoder_trainable = None

        if checkpoint:
            if 'progress_table' in checkpoint:
                self.progress_table = checkpoint['progress_table']
            if 'epoch' in checkpoint:
                self.start_epoch = checkpoint['epoch'] + 1
            if 'stats' in checkpoint:
                self.stats = checkpoint['stats']
            if 'loss_history' in checkpoint:
                self.loss_history = checkpoint['loss_history']

        # State of the custom loss-plateau schedule (see learning_rate_schedule).
        self.lr0 = self.conf.optimizer.lr
        self.lr = self.lr0
        self.ticks = 0
        self.last_tick = 0

        if self.conf.optimizer.lr_type == LR_CUSTOM:
            self.loss_tracking_window = len(train_loader) * self.conf.batch_size / 8
            self.maximum_loss_tracking_window = len(train_loader) * self.conf.batch_size / 2
            logger.info(
                f"Client {self.cid}: loss_tracking_window: {self.loss_tracking_window}, maximum_loss_tracking_window: {self.maximum_loss_tracking_window}")

    def train(self):
        """Train for epochs ``start_epoch .. conf.local_epoch - 1``.

        Stops early once the average learning rate falls below
        ``conf.minimum_learning_rate``.

        Returns:
            dict mapping each name in ``self.loss_keys`` to a list with one
            transference dict (keyed by the names in ``self.loss_keys``) per
            trained epoch.
        """
        self.encoder_trainable = [
            p for p in self.model.encoder.parameters() if p.requires_grad
        ]
        transference = {combined_task: [] for combined_task in self.loss_keys}
        try:
            for epoch in range(self.start_epoch, self.conf.local_epoch):
                # train_epoch() and learning_rate_schedule() read self.epoch.
                self.epoch = epoch
                current_learning_rate = get_average_learning_rate(self.optimizer)
                if current_learning_rate < self.conf.minimum_learning_rate:
                    logger.info(f"Client {self.cid} stop local training because lr too small, lr: {current_learning_rate}.")
                    break

                train_string, train_stats, epoch_transference = self.train_epoch()
                self.progress_table.append(train_string)
                self.stats.append(train_stats)
                for combined_task in self.loss_keys:
                    transference[combined_task].append(epoch_transference[combined_task])
        finally:
            # Release the parameter references even if training raised.
            self.encoder_trainable = None
        return transference

    def train_epoch(self):
        """Run one pass over ``train_loader``.

        Each batch is processed ``conf.virtual_batch_multiplier`` times with
        gradients accumulated, followed by one optimizer step and one call to
        the learning-rate schedule.

        Returns:
            ``([to_print], stats, epoch_transference)``: formatted progress of the
            last batch, epoch-level averages, and accumulated transference
            scores ``epoch_transference[task][other_task]``.
        """
        average_meters = defaultdict(AverageMeter)
        display_values = list(self.criteria)
        multiplier = self.conf.virtual_batch_multiplier

        self.model.train()
        epoch_start_time = time.time()
        epoch_start_time2 = time.time()
        batch_num = 0
        num_data_points = len(self.train_loader) // multiplier
        if num_data_points > 10000:
            num_data_points = num_data_points // 5
        starting_learning_rate = get_average_learning_rate(self.optimizer)
        # Defaults so an empty loader still yields a valid summary.
        current_learning_rate = starting_learning_rate
        to_print = {}

        epoch_transference = {
            combined_task: {recipient_task: 0. for recipient_task in self.loss_keys}
            for combined_task in self.loss_keys
        }
        for i, (inputs, target) in enumerate(self.train_loader):
            inputs = inputs.to(self.device)
            for name, tensor in target.items():
                target[name] = tensor.to(self.device)

            if i == 0:
                epoch_start_time2 = time.time()
            loss_dict = None
            loss = 0

            self.optimizer.zero_grad()
            # Short-circuit: lookahead_step is only read when lookahead is enabled.
            use_lookahead = self.conf.lookahead == 'y' and i % self.conf.lookahead_step == 0

            for _ in range(multiplier):
                data_start = time.time()
                average_meters['data_time'].update(time.time() - data_start)

                if use_lookahead:
                    loss_dict2, loss2, batch_transference = self.train_batch_lookahead(inputs, target)
                else:
                    loss_dict2, loss2, batch_transference = self.train_batch(inputs, target)
                loss += loss2
                if loss_dict is None:
                    loss_dict = loss_dict2
                else:
                    for key, value in loss_dict2.items():
                        loss_dict[key] += value
                if use_lookahead:
                    lookahead_batches = len(self.train_loader) / self.conf.lookahead_step
                    for combined_task in self.loss_keys:
                        for recipient_task in self.loss_keys:
                            epoch_transference[combined_task][recipient_task] += (
                                batch_transference[combined_task][recipient_task] / lookahead_batches)

            loss /= multiplier
            for key, value in loss_dict.items():
                loss_dict[key] = value / multiplier

            self.optimizer.step()
            self.loss_history.append(float(loss))
            ttest_p, z_diff = self.learning_rate_schedule()
            for name, value in loss_dict.items():
                try:
                    average_meters[name].update(value.data)
                except AttributeError:
                    # Plain Python numbers have no ``.data``.
                    average_meters[name].update(value)

            elapsed_time_for_epoch = time.time() - epoch_start_time2
            eta = (elapsed_time_for_epoch / (batch_num + .2)) * (num_data_points - batch_num)
            # Keep eta within [0, 24h): it turns negative once batch_num exceeds
            # num_data_points (e.g. after the //5 reduction above).
            eta = min(max(eta, 0), 24 * 3600 - 1)
            batch_num += 1
            current_learning_rate = get_average_learning_rate(self.optimizer)
            data_share = (100 * average_meters['data_time'].sum / elapsed_time_for_epoch
                          if elapsed_time_for_epoch > 0 else 0.0)
            to_print = {
                'ep': f'{self.epoch}:',
                f'#/{num_data_points}': f'{batch_num}',
                'lr': '{0:0.3g}-{1:0.3g}'.format(starting_learning_rate, current_learning_rate),
                'eta': '{0}'.format(time.strftime("%H:%M:%S", time.gmtime(int(eta)))),
                'd%': '{0:0.2g}'.format(data_share)
            }
            for name in display_values:
                meter = average_meters[name]
                to_print[name] = '{meter.avg:.4f}'.format(meter=meter)
            if batch_num < num_data_points - 1:
                to_print['ETA'] = '{0}'.format(
                    time.strftime("%H:%M:%S", time.gmtime(int(eta + elapsed_time_for_epoch))))
                to_print['ttest'] = '{0:0.3g},{1:0.3g}'.format(z_diff, ttest_p)

        epoch_time = time.time() - epoch_start_time
        stats = {
            'batches': num_data_points,
            'learning_rate': current_learning_rate,
            'epoch_time': epoch_time,
        }
        for name in display_values:
            stats[name] = average_meters[name].avg
        to_print['eta'] = '{0}'.format(time.strftime("%H:%M:%S", time.gmtime(int(epoch_time))))
        logger.info(f"Client {self.cid} training statistics: {stats}")
        return [to_print], stats, epoch_transference

    def _forward_losses(self, x, target):
        """Run the model on ``x`` and evaluate every criterion.

        Returns:
            ``(loss_dict, loss)`` where ``loss`` is a clone of the first
            criterion's value (the one that is optimized).
        """
        output = self.model(x)
        loss_dict = {name: criterion_fn(output, target) for name, criterion_fn in self.criteria.items()}
        first_loss = next(iter(loss_dict))
        return loss_dict, loss_dict[first_loss].clone()

    def _backward(self, loss):
        """Back-propagate ``loss``, through apex AMP loss scaling when ``conf.fp16`` is set."""
        if self.conf.fp16:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    def train_batch(self, x, target):
        """Forward and backward one batch, scaled for gradient accumulation.

        Returns:
            ``(loss_dict, loss, {})``: per-criterion losses, the scaled loss that
            was back-propagated, and an empty transference dict.
        """
        loss_dict, loss = self._forward_losses(x.float(), target)
        # Gradients are accumulated over virtual_batch_multiplier iterations.
        loss = loss / self.conf.virtual_batch_multiplier
        self._backward(loss)
        return loss_dict, loss, {}

    def train_batch_lookahead(self, x, target):
        """Train one batch and measure transference between tasks.

        For each task in ``self.loss_keys``, ``lookahead`` simulates an optimizer
        step on the shared encoder using only that task's loss. The score for
        every criterion is ``(1 - loss_after / loss_before) / lr``.

        Returns:
            ``(loss_dict, loss, rev_transference)`` where ``loss`` is scaled
            exactly like in ``train_batch`` and
            ``rev_transference[recipient][grad_task]`` is the score of a
            ``grad_task`` lookahead step measured on ``recipient``'s criterion.
            Criteria whose name contains ``'Loss'`` are not used as recipients.
        """
        x = x.float()
        loss_dict, loss = self._forward_losses(x, target)
        # lr of param group 0, read once instead of serialising the optimizer per criterion.
        lr = self.optimizer.param_groups[0]['lr']

        # Transference is measured before the real backward pass, for fp16 as well;
        # an fp16 run previously returned empty dicts and made train_epoch fail.
        transference = {}
        for combined_task in self.loss_keys:
            preds = self.lookahead(x, loss_dict[combined_task])
            transference[combined_task] = {
                c_name: ((1.0 - (criterion_fun(preds, target) / loss_dict[c_name])) / lr).detach().cpu().numpy()
                for c_name, criterion_fun in self.criteria.items()
            }

        # No zero_grad() here: lookahead() uses autograd.grad (which leaves .grad
        # untouched), and zeroing would discard gradients accumulated by earlier
        # virtual batches. Scale like train_batch for consistent accumulation/logging.
        loss = loss / self.conf.virtual_batch_multiplier
        self._backward(loss)

        rev_transference = {source: {} for source in transference}
        for grad_task, scores in transference.items():
            for source, score in scores.items():
                if 'Loss' in source:
                    continue
                rev_transference[source][grad_task] = score
        return loss_dict, loss, rev_transference

    def lookahead(self, x, loss):
        """Return the model output on ``x`` after a simulated step on ``loss``.

        The step follows the torch SGD update (weight decay, momentum with
        dampening, nesterov) and is applied only to ``self.encoder_trainable``.
        It uses each parameter's existing momentum buffer, if any, without
        modifying it. The encoder weights are always restored afterwards.
        Gradients come from ``torch.autograd.grad``, so ``.grad`` buffers are
        not touched.
        """
        shared_params = self.encoder_trainable
        init_weights = [param.data for param in shared_params]
        grads = torch.autograd.grad(loss, shared_params, retain_graph=True)

        # Hyper-parameters come from the group each parameter belongs to (fallback: group 0).
        default_group = self.optimizer.param_groups[0]
        group_of = {id(p): group for group in self.optimizer.param_groups for p in group['params']}
        try:
            with torch.no_grad():
                for param, grad in zip(shared_params, grads):
                    group = group_of.get(id(param), default_group)
                    d_p = grad + param * group.get('weight_decay', 0)
                    momentum = group.get('momentum', 0)
                    if momentum:
                        # Per-parameter state lives in optimizer.state, not in the param group.
                        buf = self.optimizer.state.get(param, {}).get('momentum_buffer')
                        if buf is None:
                            buf = d_p
                        else:
                            buf = buf * momentum + (1 - group.get('dampening', 0)) * d_p
                        step = d_p + momentum * buf if group.get('nesterov', False) else buf
                    else:
                        step = d_p
                    param.data = param.data - group['lr'] * step
                output = self.model(x)
        finally:
            for param, init_weight in zip(shared_params, init_weights):
                param.data = init_weight
        return output

    def learning_rate_schedule(self):
        """Custom plateau schedule: decay the learning rate when the loss stalls.

        Only active for ``LR_CUSTOM`` and from epoch 2 on. Once more than
        ``wind`` losses were recorded since the last decay, the oldest 3/8 and
        the newest 3/8 of the last ``wind`` losses are trimmed of 5% outliers on
        each side and compared with a paired t-test. The lr is decayed when the
        newer losses are higher on average (negative statistic) or the
        difference is insignificant (p > 0.99).

        Returns:
            ``(ttest_p, z_diff)``; ``(0, 0)`` when no test was run.
        """
        # Window state only exists for LR_CUSTOM (see __init__); other types
        # (e.g. LR_POLY) are not adjusted here.
        if self.conf.optimizer.lr_type != LR_CUSTOM:
            return 0, 0
        # Don't reduce the learning rate before epoch 2.
        if self.epoch < 2:
            return 0, 0
        ttest_p = 0
        z_diff = 0
        # loss_tracking_window is a float; slice indices must be ints.
        wind = int(self.loss_tracking_window // (self.conf.batch_size * self.conf.virtual_batch_multiplier))
        if wind <= 0 or len(self.loss_history) - self.last_tick <= wind:
            return ttest_p, z_diff

        older = sorted(self.loss_history[-wind:-wind * 5 // 8])
        newer = sorted(self.loss_history[-wind * 3 // 8:])
        older = older[int(len(older) * .05):int(len(older) * .95)]
        newer = newer[int(len(newer) * .05):int(len(newer) * .95)]
        length = min(len(older), len(newer))
        if length < 2:
            # A paired t-test needs at least two pairs.
            return ttest_p, z_diff
        z_diff, ttest_p = scipy.stats.ttest_rel(older[:length], newer[:length], nan_policy='omit')
        if z_diff < 0 or ttest_p > .99:
            self.ticks += 1
            self.last_tick = len(self.loss_history)
            self.adjust_learning_rate()
            self.loss_tracking_window = min(self.maximum_loss_tracking_window, self.loss_tracking_window * 2)
        return ttest_p, z_diff

    def adjust_learning_rate(self):
        """Set the lr to ``lr0 * 0.5 ** ticks`` on every param group."""
        self.lr = self.lr0 * (0.50 ** self.ticks)
        self.set_learning_rate(self.lr)

    def set_learning_rate(self, lr):
        """Assign ``lr`` to every optimizer param group."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def update(self, model, optimizer, device):
        """Swap in a new model, optimizer and device (e.g. for a new round)."""
        self.model = model
        self.optimizer = optimizer
        self.device = device
def get_average_learning_rate(optimizer):
    """Return the optimizer's current learning rate.

    Uses ``optimizer.learning_rate`` when the optimizer exposes that attribute.
    Otherwise returns the mean ``'lr'`` over ``optimizer.param_groups``.
    Only a missing attribute triggers the fallback; other errors propagate
    instead of being silently swallowed.
    """
    try:
        return optimizer.learning_rate
    except AttributeError:
        groups = optimizer.param_groups
        return sum(group['lr'] for group in groups) / len(groups)
|