semi_supervised_evaluation.py
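"""Semi-supervised evaluation script.

Loads a pre-trained encoder, attaches a linear (or small MLP) classification
head, fine-tunes it on a small labeled fraction of CIFAR-10/CIFAR-100, and
reports test accuracy after each epoch.

Hypothetical invocation (paths and values are illustrative, not from the source):
    python semi_supervised_evaluation.py --model_path ./model-10.pt \
        --dataset cifar10 --label_ratio 0.01 --use_MLP
"""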

import argparse
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn

from easyfl.datasets.data import CIFAR100
from eval_dataset import get_semi_supervised_data_loaders
from model import get_encoder_network


def test_whole(resnet, logreg, device, test_loader, model_path):
    # Note: model_path is accepted but not used inside this function.
    print("### Calculating final testing performance ###")
    resnet.eval()
    logreg.eval()
    metrics = defaultdict(list)
    for step, (h, y) in enumerate(test_loader):
        h = h.to(device)
        y = y.to(device)
        with torch.no_grad():
            outputs = logreg(resnet(h))
        # calculate accuracy and save metrics
        accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0)
        metrics["Accuracy/test"].append(accuracy)
    print("Final test performance: " + "\t".join([f"{k}: {np.array(v).mean()}" for k, v in metrics.items()]))
    return np.array(metrics["Accuracy/test"]).mean()


def finetune_internal(model, epochs, label_loader, test_loader, num_class, device, lr=3e-3):
    model = model.to(device)
    num_features = model.feature_dim
    n_classes = num_class  # e.g. CIFAR-10 has 10 classes

    # fine-tune model: attach a linear classification head on top of the encoder
    logreg = nn.Sequential(nn.Linear(num_features, n_classes))
    logreg = logreg.to(device)

    # loss / optimizer (only the linear head's parameters are optimized)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=logreg.parameters(), lr=lr)

    # Train fine-tuned model
    model.train()
    logreg.train()
    for epoch in range(epochs):
        metrics = defaultdict(list)
        for step, (h, y) in enumerate(label_loader):
            h = h.to(device)
            y = y.to(device)

            outputs = logreg(model(h))
            loss = criterion(outputs, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # calculate accuracy and save metrics
            accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0)
            metrics["Loss/train"].append(loss.item())
            metrics["Accuracy/train"].append(accuracy)
        if epoch % 100 == 0:
            print("======epoch {}======".format(epoch))
            test_whole(model, logreg, device, test_loader, "test_whole")
            # test_whole switches to eval mode; restore train mode before continuing
            model.train()
            logreg.train()

    final_accuracy = test_whole(model, logreg, device, test_loader, "test_whole")
    print(metrics)
    return final_accuracy
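
# Note: finetune_internal is not called anywhere in this script; it looks intended
# for programmatic use elsewhere in the codebase (an assumption). The model passed
# in must expose a `feature_dim` attribute. A hypothetical call:
#   acc = finetune_internal(encoder, epochs=200, label_loader=label_loader,
#                           test_loader=test_loader, num_class=10, device=device)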


class MLP(nn.Module):
    """One-hidden-layer MLP classification head (Linear -> BatchNorm -> ReLU -> Linear)."""

    def __init__(self, dim, projection_size, hidden_size=4096):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, projection_size),
        )

    def forward(self, x):
        return self.net(x)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="cifar10", type=str, help="cifar10/cifar100.")
    parser.add_argument("--model", default="simsiam", type=str, help="Name of the network.")
    parser.add_argument("--encoder_network", default="resnet18", type=str, help="Encoder network architecture.")
    parser.add_argument("--model_path", required=True, type=str, help="Path to pre-trained model (e.g. model-10.pt).")
    parser.add_argument("--image_size", default=32, type=int, help="Image size.")
    parser.add_argument("--learning_rate", default=1e-3, type=float, help="Initial learning rate.")
    parser.add_argument("--batch_size", default=128, type=int, help="Batch size for training.")
    parser.add_argument("--num_epochs", default=100, type=int, help="Number of epochs to train for.")
    parser.add_argument("--data_distribution", default="class", type=str, help="class/iid")
    parser.add_argument("--label_ratio", default=0.01, type=float, help="Ratio of labeled data used for fine-tuning.")
    parser.add_argument("--class_per_client", default=2, type=int,
                        help="For the non-IID setting, number of classes per client (based on CIFAR-10).")
    parser.add_argument("--use_MLP", action="store_true",
                        help="If set, use a one-hidden-layer MLP classifier; otherwise a single linear layer.")
    parser.add_argument("--num_workers", default=8, type=int,
                        help="Number of data loading workers (caution with nodes!).")
    args = parser.parse_args()
    print(args)

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    print('==> Preparing data..')
    class_per_client = args.class_per_client
    n_classes = 10
    if args.dataset == CIFAR100:
        # CIFAR-100 has 10x more classes, so scale the per-client class count accordingly
        class_per_client = 10 * class_per_client
        n_classes = 100
    train_loader, test_loader = get_semi_supervised_data_loaders(args.dataset,
                                                                 args.data_distribution,
                                                                 class_per_client,
                                                                 args.label_ratio,
                                                                 args.batch_size,
                                                                 args.num_workers)

    print('==> Building model..')
    resnet = get_encoder_network(args.model, args.encoder_network)
    resnet.load_state_dict(torch.load(args.model_path, map_location=device))
    resnet = resnet.to(device)
    num_features = list(resnet.children())[-1].in_features
    # drop the encoder's final fc layer so it outputs features
    resnet.fc = nn.Identity()

    # fine-tune model: classification head on top of the encoder features
    if args.use_MLP:
        logreg = MLP(num_features, n_classes, 4096)
        logreg = logreg.to(device)
    else:
        logreg = nn.Sequential(nn.Linear(num_features, n_classes))
        logreg = logreg.to(device)

    # loss / optimizer (only the classification head's parameters are optimized)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=logreg.parameters(), lr=args.learning_rate)

    # Train fine-tuned model
    logreg.train()
    resnet.train()
    accs = []
    for epoch in range(args.num_epochs):
        print("======epoch {}======".format(epoch))
        metrics = defaultdict(list)
        for step, (h, y) in enumerate(train_loader):
            h = h.to(device)
            y = y.to(device)

            outputs = logreg(resnet(h))
            loss = criterion(outputs, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # calculate accuracy and save metrics
            accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0)
            metrics["Loss/train"].append(loss.item())
            metrics["Accuracy/train"].append(accuracy)

        print(f"Epoch [{epoch}/{args.num_epochs}]: " + "\t".join(
            [f"{k}: {np.array(v).mean()}" for k, v in metrics.items()]))

        # evaluate after every epoch and keep track of the accuracies
        acc = test_whole(resnet, logreg, device, test_loader, args.model_path)
        if epoch <= 100:
            accs.append(acc)
        # test_whole switches to eval mode; restore train mode before the next epoch
        logreg.train()
        resnet.train()

    test_whole(resnet, logreg, device, test_loader, args.model_path)
    print(args.model_path)
    print(f"Best accuracy within the first 100 epochs: {max(accs):.4f}")