linear_evaluation.py

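"""Linear evaluation of a pre-trained self-supervised encoder (e.g. SimSiam).

The encoder is loaded from a checkpoint and kept frozen; features are
extracted once for the train and test splits, and a single linear layer
(logistic regression) is trained on top of them with cross-entropy loss.

Example invocation (the checkpoint name follows the --model_path help text
and is only illustrative):

    python linear_evaluation.py --model_path model-10.pt --dataset cifar10 \
        --encoder_network resnet18
"""
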
import argparse
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn

from easyfl.datasets.data import CIFAR100
from eval_dataset import get_data_loaders
from model import get_encoder_network


def inference(loader, model, device):
    """Run the frozen encoder over a loader and collect features and labels."""
    feature_vector = []
    labels_vector = []
    model.eval()
    for step, (x, y) in enumerate(loader):
        x = x.to(device)

        # get encoding (no gradients needed for the frozen encoder)
        with torch.no_grad():
            h = model(x)

        h = h.squeeze()
        feature_vector.extend(h.detach().cpu().numpy())
        labels_vector.extend(y.numpy())

        if step % 5 == 0:
            print(f"Step [{step}/{len(loader)}]\t Computing features...")

    feature_vector = np.array(feature_vector)
    labels_vector = np.array(labels_vector)
    print(f"Features shape {feature_vector.shape}")
    return feature_vector, labels_vector


def get_features(model, train_loader, test_loader, device):
    """Extract features for the train and test splits with the frozen encoder."""
    train_X, train_y = inference(train_loader, model, device)
    test_X, test_y = inference(test_loader, model, device)
    return train_X, train_y, test_X, test_y


def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test, batch_size):
    train = torch.utils.data.TensorDataset(
        torch.from_numpy(X_train), torch.from_numpy(y_train)
    )
    train_loader = torch.utils.data.DataLoader(
        train, batch_size=batch_size, shuffle=False
    )

    test = torch.utils.data.TensorDataset(
        torch.from_numpy(X_test), torch.from_numpy(y_test)
    )
    test_loader = torch.utils.data.DataLoader(
        test, batch_size=batch_size, shuffle=False
    )
    return train_loader, test_loader


def test_result(test_loader, logreg, device, model_path):
    """Evaluate the trained linear classifier on the held-out features."""
    print("### Calculating final testing performance ###")
    logreg.eval()
    metrics = defaultdict(list)
    for step, (h, y) in enumerate(test_loader):
        h = h.to(device)
        y = y.to(device)

        outputs = logreg(h)

        # calculate accuracy and save metrics
        accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0)
        metrics["Accuracy/test"].append(accuracy)

    print(f"Final test performance: {model_path}")
    for k, v in metrics.items():
        print(f"{k}: {np.array(v).mean():.4f}")
    return np.array(metrics["Accuracy/test"]).mean()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="cifar10", type=str)
    parser.add_argument("--model_path", required=True, type=str, help="Path to pre-trained model (e.g. model-10.pt)")
    parser.add_argument("--model", default="simsiam", type=str, help="Name of the self-supervised model (e.g. simsiam)")
    parser.add_argument("--image_size", default=32, type=int, help="Image size")
    parser.add_argument("--learning_rate", default=3e-3, type=float, help="Initial learning rate.")
    parser.add_argument("--batch_size", default=512, type=int, help="Batch size for training.")
    parser.add_argument("--num_epochs", default=200, type=int, help="Number of epochs to train for.")
    parser.add_argument("--encoder_network", default="resnet18", type=str, help="Encoder network architecture.")
    parser.add_argument("--num_workers", default=8, type=int, help="Number of data workers (caution with nodes!)")
    parser.add_argument("--fc", default="identity", help="Options: identity, remove")
    args = parser.parse_args()
    print(args)

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # get data loaders
    train_loader, test_loader = get_data_loaders(args.dataset, args.image_size, args.batch_size, args.num_workers)

    # get pre-trained encoder and load the checkpoint
    resnet = get_encoder_network(args.model, args.encoder_network)
    resnet.load_state_dict(torch.load(args.model_path, map_location=device))
    resnet = resnet.to(device)
    num_features = list(resnet.children())[-1].in_features
    if args.fc == "remove":
        resnet = nn.Sequential(*list(resnet.children())[:-1])  # throw away the fc layer
    else:
        resnet.fc = nn.Identity()

    n_classes = 10
    if args.dataset == CIFAR100:
        n_classes = 100

    # linear classifier (logistic regression) trained on top of the frozen features
    logreg = nn.Sequential(nn.Linear(num_features, n_classes))
    logreg = logreg.to(device)

    # loss / optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=logreg.parameters(), lr=args.learning_rate)

    # compute features (only needs to be done once, since the encoder is never updated)
    print("Creating features from pre-trained model")
    (train_X, train_y, test_X, test_y) = get_features(
        resnet, train_loader, test_loader, device
    )

    train_loader, test_loader = create_data_loaders_from_arrays(
        train_X, train_y, test_X, test_y, 2048
    )

    # train the linear classifier
    logreg.train()
    for epoch in range(args.num_epochs):
        metrics = defaultdict(list)
        for step, (h, y) in enumerate(train_loader):
            h = h.to(device)
            y = y.to(device)

            outputs = logreg(h)

            loss = criterion(outputs, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # calculate accuracy and save metrics
            accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0)
            metrics["Loss/train"].append(loss.item())
            metrics["Accuracy/train"].append(accuracy)

        print(f"Epoch [{epoch}/{args.num_epochs}]: " + "\t".join(
            [f"{k}: {np.array(v).mean():.4f}" for k, v in metrics.items()]))

        if epoch % 100 == 0:
            print(f"======epoch {epoch}======")
            test_result(test_loader, logreg, device, args.model_path)

    test_result(test_loader, logreg, device, args.model_path)