clientbase.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import copy
  2. import torch
  3. import torch.nn as nn
  4. import numpy as np
  5. import os
  6. import torch.nn.functional as F
  7. from torch.utils.data import DataLoader
  8. from sklearn.preprocessing import label_binarize
  9. from sklearn import metrics
  10. from utils.data_utils import read_client_data
  11. class Client(object):
  12. """
  13. Base class for clients in federated learning.
  14. """
  15. def __init__(self, args, id, train_samples, test_samples, **kwargs):
  16. self.model = copy.deepcopy(args.model)
  17. self.dataset = args.dataset
  18. self.device = args.device
  19. self.id = id # integer
  20. self.num_classes = args.num_classes
  21. self.train_samples = train_samples
  22. self.test_samples = test_samples
  23. self.batch_size = args.batch_size
  24. self.learning_rate = args.local_learning_rate
  25. self.local_steps = args.local_steps
  26. # check BatchNorm
  27. self.has_BatchNorm = False
  28. for layer in self.model.children():
  29. if isinstance(layer, nn.BatchNorm2d):
  30. self.has_BatchNorm = True
  31. break
  32. self.sample_rate = self.batch_size / self.train_samples
  def load_train_data(self, batch_size=None):
      """Build a shuffled ``DataLoader`` over this client's training split.

      Args:
          batch_size: Batch size to use. ``None`` falls back to
              ``self.batch_size``.

      Returns:
          A ``DataLoader`` over the data returned by ``read_client_data``
          for this client, created with ``drop_last=True`` and
          ``shuffle=True``.
      """
      if batch_size is None:
          batch_size = self.batch_size
      train_data = read_client_data(self.dataset, self.id, is_train=True)
      # Never ask for more samples per batch than the split holds; with
      # drop_last=True an oversized batch would leave no batches at all.
      batch_size = min(batch_size, len(train_data))
      return DataLoader(train_data, batch_size, drop_last=True, shuffle=True)
  def load_test_data(self, batch_size=None):
      """Build a ``DataLoader`` over this client's test split.

      Args:
          batch_size: Batch size to use. ``None`` falls back to
              ``self.batch_size``.

      Returns:
          A ``DataLoader`` over the data returned by ``read_client_data``
          for this client, created with ``drop_last=False`` (every test
          sample is kept) and ``shuffle=True``.
      """
      if batch_size is None:
          batch_size = self.batch_size
      test_data = read_client_data(self.dataset, self.id, is_train=False)
      # Cap the batch size at the number of available test samples.
      batch_size = min(batch_size, len(test_data))
      return DataLoader(test_data, batch_size, drop_last=False, shuffle=True)
  45. def set_parameters(self, model):
  46. for new_param, old_param in zip(model.parameters(), self.model.parameters()):
  47. old_param.data = new_param.data.clone()
  def clone_model(self, model, target):
      """Copy every parameter value of ``model`` into ``target``.

      Parameters are paired positionally, in ``parameters()`` order. Only
      the parameter data is cloned; gradients are not copied.

      Args:
          model: Source module.
          target: Destination module, modified in place.
      """
      pairs = zip(model.parameters(), target.parameters())
      for source_param, dest_param in pairs:
          dest_param.data = source_param.data.clone()
  def update_parameters(self, model, new_params):
      """Load ``new_params`` into ``model``, one parameter at a time.

      The model's parameters and ``new_params`` are paired positionally;
      iteration stops at the shorter of the two.

      Args:
          model: Module whose parameters are overwritten in place.
          new_params: Iterable of objects exposing ``.data``; each value is
              cloned before being assigned.
      """
      for current, replacement in zip(model.parameters(), new_params):
          fresh_values = replacement.data.clone()
          current.data = fresh_values
  55. def get_eval_model(self, temp_model=None):
  56. model = self.model_per if hasattr(self, "model_per") else self.model
  57. return model
  58. def standard_train(self):
  59. trainloader = self.load_train_data()
  60. self.model.train()
  61. for p in self.model.parameters():
  62. p.requires_grad = True
  63. optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
  64. # 1 epoch
  65. for i, (x, y) in enumerate(trainloader):
  66. if type(x) == type([]):
  67. x[0] = x[0].to(self.device)
  68. else:
  69. x = x.to(self.device)
  70. y = y.to(self.device)
  71. optimizer.zero_grad()
  72. output = self.model(x)
  73. loss = self.criterion(output, y)
  74. loss.backward()
  75. optimizer.step()
  def test_metrics(self, temp_model=None):
      """Evaluate the evaluation model on this client's local test split.

      Args:
          temp_model: Forwarded unchanged to ``get_eval_model``.

      Returns:
          A tuple ``(test_acc, test_auc, test_loss, test_num)``:
          accuracy over all test samples, micro-averaged ROC AUC (0.0 when
          scikit-learn raises ``ValueError``), the per-sample average of the
          accumulated loss, and the number of test samples seen.

      NOTE(review): ``self.criterion`` is read here but is not assigned in
      this method; it has to be provided elsewhere before this is called.
      """
      testloaderfull = self.load_test_data()
      model = self.get_eval_model(temp_model)
      model.eval()

      test_correct = 0
      test_num = 0
      test_loss = 0.0
      y_prob = []
      y_true = []

      with torch.no_grad():
          for x, y in testloaderfull:
              # When a batch arrives as a list, only its first element is
              # moved to the device; the list is still passed to the model.
              if isinstance(x, list):
                  x[0] = x[0].to(self.device)
              else:
                  x = x.to(self.device)
              y = y.to(self.device)
              output = model(x)

              # Sum up batch loss: the criterion value is scaled by the
              # batch size so the total can be divided by test_num below
              # (presumably the criterion returns a batch mean - confirm).
              test_loss += (self.criterion(output, y.long()) * y.shape[0]).item()
              test_correct += (torch.sum(torch.argmax(output, dim=1) == y)).item()
              test_num += y.shape[0]

              # Raw model outputs are used as scores for the AUC.
              y_prob.append(output.detach().cpu().numpy())
              y_true.append(label_binarize(y.detach().cpu().numpy(), classes=np.arange(self.num_classes)))

      y_prob = np.concatenate(y_prob, axis=0)
      y_true = np.concatenate(y_true, axis=0)

      # Normalise the loss independently of the AUC: a failed AUC
      # computation must not wipe out a perfectly valid loss value.
      test_loss /= test_num

      try:
          test_auc = metrics.roc_auc_score(y_true, y_prob, average='micro')
      except ValueError:
          # The AUC could not be computed for these labels/scores;
          # report 0.0 for the AUC only.
          test_auc = 0.0

      test_acc = test_correct / test_num

      return test_acc, test_auc, test_loss, test_num