evaluate.py

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable  # legacy wrapper; a no-op since PyTorch 0.4


def extract_feature(model, dataloaders, device, ms=(1,)):
    """Extract an L2-normalized feature for every image in the loader.

    Each image is embedded twice (original and horizontally flipped) and,
    optionally, at each scale in `ms`; the embeddings are summed and then
    normalized, so downstream cosine similarity is a plain dot product.
    """
    features = torch.FloatTensor()
    model = model.to(device)
    with torch.no_grad():  # inference only; avoids building an autograd graph
        for data in dataloaders:
            img, label = data
            n, c, h, w = img.size()
            ff = torch.FloatTensor(n, 512).zero_().to(device)  # assumes 512-dim embeddings
            for i in range(2):
                if i == 1:
                    img = fliplr(img)  # second pass: horizontal flip
                input_img = Variable(img.to(device))
                for scale in ms:
                    if scale != 1:
                        # bicubic is only available in pytorch >= 1.1; rescale from the
                        # original image each pass rather than the previously scaled one
                        scaled_img = nn.functional.interpolate(input_img, scale_factor=scale,
                                                               mode='bicubic', align_corners=False)
                    else:
                        scaled_img = input_img
                    outputs = model(scaled_img)
                    ff += outputs
            # norm feature
            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
            ff = ff.div(fnorm.expand_as(ff))
            features = torch.cat((features, ff.data.cpu()), 0)
    return features
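
# A minimal usage sketch (assumptions, not from the original file): `net` and
# `query_loader` stand in for a trained re-ID backbone that outputs 512-dim
# embeddings and a DataLoader yielding (img, label) batches.
#
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     net.eval()
#     query_features = extract_feature(net, query_loader, device, ms=(1,))
#     # query_features.shape -> torch.Size([num_query_images, 512])
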
def test_evaluate(result, device):
    """Average the per-query AP and CMC over the whole query set.

    The [0] indexing below suggests `result` was loaded with scipy.io.loadmat,
    which stores label/camera vectors as 1 x N arrays.
    """
    query_feature = torch.FloatTensor(result['query_f'])
    query_cam = result['query_cam'][0]
    query_label = result['query_label'][0]
    gallery_feature = torch.FloatTensor(result['gallery_f'])
    gallery_cam = result['gallery_cam'][0]
    gallery_label = result['gallery_label'][0]
    query_feature = query_feature.to(device)
    gallery_feature = gallery_feature.to(device)
    CMC = torch.IntTensor(len(gallery_label)).zero_()
    ap = 0.0
    for i in range(len(query_label)):
        ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], query_cam[i],
                                   gallery_feature, gallery_label, gallery_cam)
        if CMC_tmp[0] == -1:  # no valid gallery match for this query; skip it
            continue
        CMC = CMC + CMC_tmp
        ap += ap_tmp
    CMC = CMC.float()
    CMC = CMC / len(query_label)  # average CMC curve
    # Rank-1, Rank-5, Rank-10, mAP
    return CMC[0], CMC[4], CMC[9], ap / len(query_label)


def evaluate(qf, ql, qc, gf, gl, gc):
    """Rank the gallery for one query and score the ranking."""
    query = qf.view(-1, 1)
    # features are L2-normalized, so this matrix product is cosine similarity
    score = torch.mm(gf, query)
    score = score.squeeze(1).cpu()
    score = score.numpy()
    # predicted ranking: argsort is ascending, so reverse for best-first
    index = np.argsort(score)
    index = index[::-1]
    # good matches: same identity seen from a different camera
    query_index = np.argwhere(gl == ql)
    camera_index = np.argwhere(gc == qc)
    good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
    # junk: distractors (label -1) and same-identity/same-camera shots
    junk_index1 = np.argwhere(gl == -1)
    junk_index2 = np.intersect1d(query_index, camera_index)
    junk_index = np.append(junk_index2, junk_index1)
    CMC_tmp = compute_mAP(index, good_index, junk_index)
    return CMC_tmp


def compute_mAP(index, good_index, junk_index):
    """Compute average precision and the CMC curve for one ranked list."""
    ap = 0
    cmc = torch.IntTensor(len(index)).zero_()
    if good_index.size == 0:  # if empty
        cmc[0] = -1
        return ap, cmc
    # remove junk_index from the ranking
    mask = np.in1d(index, junk_index, invert=True)
    index = index[mask]
    # ranks at which the good matches appear
    ngood = len(good_index)
    mask = np.in1d(index, good_index)
    rows_good = np.argwhere(mask).flatten()
    # CMC is 1 from the rank of the first correct match onwards
    cmc[rows_good[0]:] = 1
    # AP as a trapezoidal approximation of the precision-recall curve
    for i in range(ngood):
        d_recall = 1.0 / ngood
        precision = (i + 1) * 1.0 / (rows_good[i] + 1)
        if rows_good[i] != 0:
            old_precision = i * 1.0 / rows_good[i]
        else:
            old_precision = 1.0
        ap = ap + d_recall * (old_precision + precision) / 2
    return ap, cmc
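
# A worked example for compute_mAP (not in the original file): with the ranking
# index = [3, 0, 2, 1], good_index = [0, 2] and no junk, the correct matches sit
# at ranks 2 and 3, so cmc == [0, 1, 1, 1] and the trapezoidal AP is
# 0.5 * (0.0 + 0.5) / 2 + 0.5 * (0.5 + 2/3) / 2 ≈ 0.417.
#
#     ap, cmc = compute_mAP(np.array([3, 0, 2, 1]), np.array([0, 2]),
#                           np.array([], dtype=int))
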
def fliplr(img):
    """Flip a batch of images horizontally (input is N x C x H x W)."""
    inv_idx = torch.arange(img.size(3) - 1, -1, -1).long()  # reversed width indices
    img_flip = img.index_select(3, inv_idx)
    return img_flip
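

# A minimal end-to-end sketch (assumptions, not part of the original file): it
# presumes features were saved with scipy.io.savemat under the keys read by
# test_evaluate; 'pytorch_result.mat' is a hypothetical file name.
if __name__ == '__main__':
    import scipy.io

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    result = scipy.io.loadmat('pytorch_result.mat')  # hypothetical path
    rank1, rank5, rank10, mAP = test_evaluate(result, device)
    print('Rank@1: %f Rank@5: %f Rank@10: %f mAP: %f' % (rank1, rank5, rank10, mAP))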