123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- from __future__ import print_function, absolute_import
- import logging
- import time
- from torch.autograd import Variable
- from .evaluation_metrics import accuracy
- from .utils.meters import AverageMeter
- logger = logging.getLogger(__name__)
- class BaseTrainer(object):
- def __init__(self, model, criterion, device, fixed_layer=False):
- super(BaseTrainer, self).__init__()
- self.model = model
- self.criterion = criterion
- self.fixed_layer = fixed_layer
- self.device = device
- def train(self, epoch, data_loader, optimizer, print_freq=1):
- self.model.train()
- batch_time = AverageMeter()
- data_time = AverageMeter()
- losses = AverageMeter()
- precisions = AverageMeter()
- stop_local_training = False
- precision_avg = []
- end = time.time()
- for i, inputs in enumerate(data_loader):
- data_time.update(time.time() - end)
- inputs, targets = self._parse_data(inputs)
- loss, prec1 = self._forward(inputs, targets)
- losses.update(loss.item(), targets.size(0))
- precisions.update(prec1, targets.size(0))
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
- batch_time.update(time.time() - end)
- end = time.time()
- if (i + 1) % print_freq == 0:
- logger.info('Epoch: [{}][{}/{}]\t'
- 'Time {:.3f} ({:.3f})\t'
- 'Data {:.3f} ({:.3f})\t'
- 'Loss {:.3f} ({:.3f})\t'
- 'Prec {:.2%} ({:.2%})\t'
- .format(epoch, i + 1, len(data_loader),
- batch_time.val, batch_time.avg,
- data_time.val, data_time.avg,
- losses.val, losses.avg,
- precisions.val, precisions.avg))
- precision_avg.append(precisions.avg)
- if precisions.val == 1 or precisions.avg > 0.95:
- stop_local_training = True
- return stop_local_training
- def _parse_data(self, inputs):
- raise NotImplementedError
- def _forward(self, inputs, targets):
- raise NotImplementedError
- class Trainer(BaseTrainer):
- def _parse_data(self, inputs):
- x, y = inputs
- inputs = Variable(x.to(self.device), requires_grad=False)
- targets = Variable(y.to(self.device))
- return inputs, targets
- def _forward(self, inputs, targets):
- outputs, _ = self.model(inputs)
- outputs = outputs.to(self.device)
- loss, outputs = self.criterion(outputs, targets)
- prec, = accuracy(outputs.data, targets.data)
- prec = prec[0]
- return loss, prec
|