# knn_monitor.py

import torch
import torch.nn.functional as F
from tqdm import tqdm


# Code adapted from https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb#scrollTo=lzFyFnhbk8hj
# Evaluate a network with a weighted kNN monitor.
def knn_monitor(net, memory_data_loader, test_data_loader, k=200, t=0.1, hide_progress=False, device=None):
    net.eval()
    classes = len(memory_data_loader.dataset.classes)
    total_top1, total_num, feature_bank = 0.0, 0, []
    with torch.no_grad():
        # generate the feature bank from the memory (training) set
        for data, target in tqdm(memory_data_loader, desc='Feature extracting', leave=False, disable=hide_progress):
            if device is None:
                data = data.cuda(non_blocking=True)
            else:
                data = data.to(device, non_blocking=True)
            feature = net(data)
            feature = F.normalize(feature, dim=1)
            feature_bank.append(feature)
        # [D, N]
        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
        # [N]
        feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
        # loop over the test data and predict each label by weighted kNN search
        test_bar = tqdm(test_data_loader, desc='kNN', disable=hide_progress)
        for data, target in test_bar:
            if device is None:
                data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
            else:
                data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            feature = net(data)
            feature = F.normalize(feature, dim=1)
            pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, k, t)
            total_num += data.size(0)
            total_top1 += (pred_labels[:, 0] == target).float().sum().item()
            test_bar.set_postfix({'Accuracy': total_top1 / total_num * 100})
    print("Accuracy: {}".format(total_top1 / total_num * 100))
    return total_top1 / total_num * 100
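
# --- Usage sketch (an assumption, not part of the original notebook code) ---
# knn_monitor expects net(data) to return a flat [B, D] feature vector. One
# minimal way to get that from a standard torchvision classifier is to swap
# its head for an identity, as below; the builder name is hypothetical.
def build_example_encoder():
    import torchvision
    net = torchvision.models.resnet18(num_classes=10)
    net.fc = torch.nn.Identity()  # net(x) now yields the 512-d pooled features
    return net
# With CIFAR-10 memory/test loaders in hand, the call would look like:
#   acc = knn_monitor(build_example_encoder().cuda(), memory_loader, test_loader)
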
# kNN monitor as in InstDisc https://arxiv.org/abs/1805.01978
# Implementation follows https://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    # compute cos similarity between each feature vector and the feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # [B, K]
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
    # reweight neighbors with a softmax-style temperature: exp(sim / t)
    sim_weight = (sim_weight / knn_t).exp()
    # one-hot encode the neighbor labels ---> [B*K, C]
    one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # weighted vote per class: score_c = sum_k exp(sim_k / t) * 1[label_k == c] ---> [B, C]
    pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)
    # classes ranked by score; column 0 is the top-1 prediction
    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels
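
# Quick shape self-check (a hypothetical sketch, not in the original file):
# run knn_predict on random unit-normalized tensors. With B=4 queries, an
# N=32 bank of D=8 dims, and C=3 classes, the returned ranking is [B, C].
if __name__ == "__main__":
    torch.manual_seed(0)
    bank = F.normalize(torch.randn(8, 32), dim=0)   # [D, N], columns unit-norm
    labels = torch.randint(0, 3, (32,))             # [N] class labels
    query = F.normalize(torch.randn(4, 8), dim=1)   # [B, D], rows unit-norm
    ranking = knn_predict(query, bank, labels, classes=3, knn_k=5, knn_t=0.1)
    print(ranking.shape)  # torch.Size([4, 3]); ranking[:, 0] is top-1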