ranking.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. from __future__ import absolute_import
  2. from collections import defaultdict
  3. import numpy as np
  4. from sklearn.metrics.base import _average_binary_score
  5. from sklearn.metrics import precision_recall_curve, auc
  6. # from sklearn.metrics import average_precision_score
  7. from ..utils import to_numpy
  8. def _unique_sample(ids_dict, num):
  9. mask = np.zeros(num, dtype=np.bool)
  10. for _, indices in ids_dict.items():
  11. i = np.random.choice(indices)
  12. mask[i] = True
  13. return mask
  14. def average_precision_score(y_true, y_score, average="macro",
  15. sample_weight=None):
  16. def _binary_average_precision(y_true, y_score, sample_weight=None):
  17. precision, recall, thresholds = precision_recall_curve(
  18. y_true, y_score, sample_weight=sample_weight)
  19. return auc(recall, precision)
  20. return _average_binary_score(_binary_average_precision, y_true, y_score,
  21. average, sample_weight=sample_weight)
  22. def cmc(distmat, query_ids=None, gallery_ids=None,
  23. query_cams=None, gallery_cams=None, topk=100,
  24. separate_camera_set=False,
  25. single_gallery_shot=False,
  26. first_match_break=False):
  27. distmat = to_numpy(distmat)
  28. m, n = distmat.shape
  29. # Fill up default values
  30. if query_ids is None:
  31. query_ids = np.arange(m)
  32. if gallery_ids is None:
  33. gallery_ids = np.arange(n)
  34. if query_cams is None:
  35. query_cams = np.zeros(m).astype(np.int32)
  36. if gallery_cams is None:
  37. gallery_cams = np.ones(n).astype(np.int32)
  38. # Ensure numpy array
  39. query_ids = np.asarray(query_ids)
  40. gallery_ids = np.asarray(gallery_ids)
  41. query_cams = np.asarray(query_cams)
  42. gallery_cams = np.asarray(gallery_cams)
  43. # Sort and find correct matches
  44. indices = np.argsort(distmat, axis=1)
  45. matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
  46. # Compute CMC for each query
  47. ret = np.zeros(topk)
  48. num_valid_queries = 0
  49. for i in range(m):
  50. # Filter out the same id and same camera
  51. valid = ((gallery_ids[indices[i]] != query_ids[i]) |
  52. (gallery_cams[indices[i]] != query_cams[i]))
  53. if separate_camera_set:
  54. # Filter out samples from same camera
  55. valid &= (gallery_cams[indices[i]] != query_cams[i])
  56. if not np.any(matches[i, valid]): continue
  57. if single_gallery_shot:
  58. repeat = 10
  59. gids = gallery_ids[indices[i][valid]]
  60. inds = np.where(valid)[0]
  61. ids_dict = defaultdict(list)
  62. for j, x in zip(inds, gids):
  63. ids_dict[x].append(j)
  64. else:
  65. repeat = 1
  66. for _ in range(repeat):
  67. if single_gallery_shot:
  68. # Randomly choose one instance for each id
  69. sampled = (valid & _unique_sample(ids_dict, len(valid)))
  70. index = np.nonzero(matches[i, sampled])[0]
  71. else:
  72. index = np.nonzero(matches[i, valid])[0]
  73. delta = 1. / (len(index) * repeat)
  74. for j, k in enumerate(index):
  75. if k - j >= topk: break
  76. if first_match_break:
  77. ret[k - j] += 1
  78. break
  79. ret[k - j] += delta
  80. num_valid_queries += 1
  81. if num_valid_queries == 0:
  82. raise RuntimeError("No valid query")
  83. return ret.cumsum() / num_valid_queries
  84. def mean_ap(distmat, query_ids=None, gallery_ids=None,
  85. query_cams=None, gallery_cams=None):
  86. distmat = to_numpy(distmat)
  87. m, n = distmat.shape
  88. # Fill up default values
  89. if query_ids is None:
  90. query_ids = np.arange(m)
  91. if gallery_ids is None:
  92. gallery_ids = np.arange(n)
  93. if query_cams is None:
  94. query_cams = np.zeros(m).astype(np.int32)
  95. if gallery_cams is None:
  96. gallery_cams = np.ones(n).astype(np.int32)
  97. # Ensure numpy array
  98. query_ids = np.asarray(query_ids)
  99. gallery_ids = np.asarray(gallery_ids)
  100. query_cams = np.asarray(query_cams)
  101. gallery_cams = np.asarray(gallery_cams)
  102. # Sort and find correct matches
  103. indices = np.argsort(distmat, axis=1)
  104. matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
  105. # Compute AP for each query
  106. aps = []
  107. for i in range(m):
  108. # Filter out the same id and same camera
  109. valid = ((gallery_ids[indices[i]] != query_ids[i]) |
  110. (gallery_cams[indices[i]] != query_cams[i]))
  111. y_true = matches[i, valid]
  112. y_score = -distmat[i][indices[i]][valid]
  113. if not np.any(y_true): continue
  114. aps.append(average_precision_score(y_true, y_score))
  115. if len(aps) == 0:
  116. raise RuntimeError("No valid query")
  117. return np.mean(aps)