# clientfomo.py (5.3 KB)
  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. import copy
  5. from flcore.clients.clientbase import Client
  6. from torch.utils.data import DataLoader
  7. from utils.data_utils import read_client_data
  8. class clientFomo(Client):
  9. def __init__(self, args, id, train_samples, test_samples, **kwargs):
  10. super().__init__(args, id, train_samples, test_samples, **kwargs)
  11. self.num_clients = args.num_clients
  12. self.old_model = copy.deepcopy(self.model)
  13. self.received_ids = []
  14. self.received_models = []
  15. self.weight_vector = torch.zeros(self.num_clients, device=self.device)
  16. self.criterion = nn.CrossEntropyLoss()
  17. self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
  18. self.val_ratio = 0.2
  19. self.train_samples = self.train_samples * (1-self.val_ratio)
  20. def train(self):
  21. trainloader, val_loader = self.load_train_data()
  22. self.aggregate_parameters(val_loader)
  23. self.clone_model(self.model, self.old_model)
  24. # self.model.to(self.device)
  25. self.model.train()
  26. max_local_steps = self.local_steps
  27. for step in range(max_local_steps):
  28. for x, y in trainloader:
  29. if type(x) == type([]):
  30. x[0] = x[0].to(self.device)
  31. else:
  32. x = x.to(self.device)
  33. y = y.to(self.device)
  34. self.optimizer.zero_grad()
  35. output = self.model(x)
  36. loss = self.criterion(output, y)
  37. loss.backward()
  38. self.optimizer.step()
  39. # self.model.cpu()
  40. def standard_train(self):
  41. trainloader, val_loader = self.load_train_data()
  42. self.model.train()
  43. optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
  44. # 1 epoch
  45. for i, (x, y) in enumerate(trainloader):
  46. if type(x) == type([]):
  47. x[0] = x[0].to(self.device)
  48. else:
  49. x = x.to(self.device)
  50. y = y.to(self.device)
  51. optimizer.zero_grad()
  52. output = self.model(x)
  53. loss = self.criterion(output, y)
  54. loss.backward()
  55. optimizer.step()
  56. def load_train_data(self, batch_size=None):
  57. if batch_size == None:
  58. batch_size = self.batch_size
  59. train_data = read_client_data(self.dataset, self.id, is_train=True)
  60. val_idx = -int(self.val_ratio*len(train_data))
  61. val_data = train_data[val_idx:]
  62. train_data = train_data[:val_idx]
  63. trainloader = DataLoader(train_data, self.batch_size, drop_last=True, shuffle=True)
  64. val_loader = DataLoader(val_data, self.batch_size, drop_last=self.has_BatchNorm, shuffle=True)
  65. return trainloader, val_loader
  66. def receive_models(self, ids, models):
  67. self.received_ids = ids
  68. self.received_models = models
  69. def weight_cal(self, val_loader):
  70. weight_list = []
  71. L = self.recalculate_loss(self.old_model, val_loader)
  72. for received_model in self.received_models:
  73. params_dif = []
  74. for param_n, param_i in zip(received_model.parameters(), self.old_model.parameters()):
  75. params_dif.append((param_n - param_i).view(-1))
  76. params_dif = torch.cat(params_dif)
  77. d = L - self.recalculate_loss(received_model, val_loader)
  78. if d > 0:
  79. weight_list.append((d / (torch.norm(params_dif) + 1e-5)).item())
  80. else:
  81. weight_list.append(0.0)
  82. if len(weight_list) != 0:
  83. weight_list = np.array(weight_list)
  84. weight_list /= (np.sum(weight_list) + 1e-10)
  85. self.weight_vector_update(weight_list)
  86. return torch.tensor(weight_list)
  87. def weight_vector_update(self, weight_list):
  88. self.weight_vector = np.zeros(self.num_clients)
  89. for w, id in zip(weight_list, self.received_ids):
  90. self.weight_vector[id] += w.item()
  91. self.weight_vector = torch.tensor(self.weight_vector).to(self.device)
  92. def recalculate_loss(self, new_model, val_loader):
  93. L = 0
  94. for x, y in val_loader:
  95. if type(x) == type([]):
  96. x[0] = x[0].to(self.device)
  97. else:
  98. x = x.to(self.device)
  99. y = y.to(self.device)
  100. output = new_model(x)
  101. loss = self.criterion(output, y)
  102. L += (loss * y.shape[0]).item()
  103. return L / len(val_loader.dataset)
  104. def add_parameters(self, w, received_model):
  105. for param, received_param in zip(self.model.parameters(), received_model.parameters()):
  106. param.data += received_param.data.clone() * w
  107. def aggregate_parameters(self, val_loader):
  108. weights = self.weight_cal(val_loader)
  109. if len(weights) > 0 and sum(weights) > 0.0:
  110. for param in self.model.parameters():
  111. param.data.zero_()
  112. for w, received_model in zip(weights, self.received_models):
  113. self.add_parameters(w, received_model)
  114. def weight_scale(self, weights):
  115. weights = torch.maximum(weights, torch.tensor(0))
  116. w_sum = torch.sum(weights)
  117. if w_sum > 0:
  118. weights = [w/w_sum for w in weights]
  119. return torch.tensor(weights)
  120. else:
  121. return torch.tensor([])