trainers.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from __future__ import print_function, absolute_import
  2. import logging
  3. import time
  4. from torch.autograd import Variable
  5. from .evaluation_metrics import accuracy
  6. from .utils.meters import AverageMeter
  7. logger = logging.getLogger(__name__)
  8. class BaseTrainer(object):
  9. def __init__(self, model, criterion, device, fixed_layer=False):
  10. super(BaseTrainer, self).__init__()
  11. self.model = model
  12. self.criterion = criterion
  13. self.fixed_layer = fixed_layer
  14. self.device = device
  15. def train(self, epoch, data_loader, optimizer, print_freq=1):
  16. self.model.train()
  17. batch_time = AverageMeter()
  18. data_time = AverageMeter()
  19. losses = AverageMeter()
  20. precisions = AverageMeter()
  21. stop_local_training = False
  22. precision_avg = []
  23. end = time.time()
  24. for i, inputs in enumerate(data_loader):
  25. data_time.update(time.time() - end)
  26. inputs, targets = self._parse_data(inputs)
  27. loss, prec1 = self._forward(inputs, targets)
  28. losses.update(loss.item(), targets.size(0))
  29. precisions.update(prec1, targets.size(0))
  30. optimizer.zero_grad()
  31. loss.backward()
  32. optimizer.step()
  33. batch_time.update(time.time() - end)
  34. end = time.time()
  35. if (i + 1) % print_freq == 0:
  36. logger.info('Epoch: [{}][{}/{}]\t'
  37. 'Time {:.3f} ({:.3f})\t'
  38. 'Data {:.3f} ({:.3f})\t'
  39. 'Loss {:.3f} ({:.3f})\t'
  40. 'Prec {:.2%} ({:.2%})\t'
  41. .format(epoch, i + 1, len(data_loader),
  42. batch_time.val, batch_time.avg,
  43. data_time.val, data_time.avg,
  44. losses.val, losses.avg,
  45. precisions.val, precisions.avg))
  46. precision_avg.append(precisions.avg)
  47. if precisions.val == 1 or precisions.avg > 0.95:
  48. stop_local_training = True
  49. return stop_local_training
  50. def _parse_data(self, inputs):
  51. raise NotImplementedError
  52. def _forward(self, inputs, targets):
  53. raise NotImplementedError
  54. class Trainer(BaseTrainer):
  55. def _parse_data(self, inputs):
  56. x, y = inputs
  57. inputs = Variable(x.to(self.device), requires_grad=False)
  58. targets = Variable(y.to(self.device))
  59. return inputs, targets
  60. def _forward(self, inputs, targets):
  61. outputs, _ = self.model(inputs)
  62. outputs = outputs.to(self.device)
  63. loss, outputs = self.criterion(outputs, targets)
  64. prec, = accuracy(outputs.data, targets.data)
  65. prec = prec[0]
  66. return loss, prec