evaluators.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from __future__ import print_function, absolute_import
  2. import logging
  3. import os
  4. import torch
  5. from torch.backends import cudnn
  6. from .evaluation_metrics import cmc, mean_ap
  7. from .feature_extraction import extract_cnn_feature
  8. logger = logging.getLogger(__name__)
  9. # extract features for fed transform format
  10. def extract_features(model, data_loader, device, print_freq=1, metric=None):
  11. cudnn.benchmark = False
  12. model.eval()
  13. features = []
  14. logger.info("extracting features...")
  15. for i, (inputs, targets) in enumerate(data_loader):
  16. inputs = inputs.to(device)
  17. _fcs, pool5s = extract_cnn_feature(model, inputs)
  18. features.extend(pool5s)
  19. return features
  20. def pairwise_distance(query_features, gallery_features, metric=None):
  21. x = torch.cat([f.unsqueeze(0) for f in query_features], 0)
  22. y = torch.cat([f.unsqueeze(0) for f in gallery_features], 0)
  23. m, n = x.size(0), y.size(0)
  24. x = x.view(m, -1)
  25. y = y.view(n, -1)
  26. if metric is not None:
  27. x = metric.transform(x)
  28. y = metric.transform(y)
  29. dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
  30. torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
  31. dist.addmm_(1, -2, x, y.t())
  32. return dist
  33. def evaluate_all(distmat, query_ids, gallery_ids, query_cams, gallery_cams, cmc_topk=(1, 5, 10, 20)):
  34. # Compute mean AP
  35. mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
  36. # Compute all kinds of CMC scores
  37. cmc_configs = {
  38. 'market1501': dict(separate_camera_set=False,
  39. single_gallery_shot=False,
  40. first_match_break=True)}
  41. cmc_scores = {name: cmc(distmat, query_ids, gallery_ids,
  42. query_cams, gallery_cams, **params)
  43. for name, params in cmc_configs.items()}
  44. print('Mean AP: {:4.2%}'.format(mAP))
  45. print('CMC Scores:')
  46. for k in cmc_topk:
  47. print(' top-{:<4}{:12.2%}'
  48. .format(k,
  49. cmc_scores['market1501'][k - 1]))
  50. # Use the allshots cmc top-1 score for validation criterion
  51. return cmc_scores['market1501'][0], cmc_scores['market1501'][4], cmc_scores['market1501'][9], mAP
  52. class Evaluator(object):
  53. def __init__(self, model, test_data, query_id, gallery_id, device, is_print=False):
  54. super(Evaluator, self).__init__()
  55. self.model = model
  56. self.test_data = test_data
  57. self.device = device
  58. gallery_path = [(self.test_data.data[gallery_id]['x'][i],
  59. self.test_data.data[gallery_id]['y'][i])
  60. for i in range(len(self.test_data.data[gallery_id]['y']))]
  61. query_path = [(self.test_data.data[query_id]['x'][i],
  62. self.test_data.data[query_id]['y'][i])
  63. for i in range(len(self.test_data.data[query_id]['y']))]
  64. gallery_cam, gallery_label = get_id(gallery_path)
  65. self.gallery_cam = gallery_cam
  66. self.gallery_label = gallery_label
  67. query_cam, query_label = get_id(query_path)
  68. self.query_cam = query_cam
  69. self.query_label = query_label
  70. def evaluate(self, query_loader, gallery_loader, metric=None):
  71. query_features = extract_features(self.model, query_loader, self.device)
  72. gallery_features = extract_features(self.model, gallery_loader, self.device)
  73. distmat = pairwise_distance(query_features, gallery_features, metric=metric)
  74. return evaluate_all(distmat, self.query_label, self.gallery_label, self.query_cam, self.gallery_cam)
  75. def get_id(img_path):
  76. camera_id = []
  77. labels = []
  78. for p, v in img_path:
  79. filename = os.path.basename(p)
  80. if filename[:3] != 'cam':
  81. label = filename[0:4]
  82. camera = filename.split('c')[1]
  83. camera = camera.split('s')[0]
  84. else:
  85. label = filename.split('_')[2]
  86. camera = filename.split('_')[1]
  87. if label[0:2] == '-1':
  88. labels.append(-1)
  89. else:
  90. labels.append(int(label))
  91. camera_id.append(int(camera[0]))
  92. return camera_id, labels