# serverbase.py

import torch
import os
import numpy as np
import h5py
import copy
import time
import sys
import random
import logging
from utils.data_utils import read_client_data


class Server(object):
    def __init__(self, args, times):
        # Set up the main attributes
        self.device = args.device
        self.dataset = args.dataset
        self.global_rounds = args.global_rounds
        self.local_steps = args.local_steps
        self.batch_size = args.batch_size
        self.learning_rate = args.local_learning_rate
        self.global_model = copy.deepcopy(args.model)
        self.num_clients = args.num_clients
        self.join_ratio = args.join_ratio
        self.join_clients = int(self.num_clients * self.join_ratio)
        self.algorithm = args.algorithm
        self.goal = args.goal
        self.top_cnt = 100
        self.best_mean_test_acc = -1.0
        self.clients = []
        self.selected_clients = []
        self.uploaded_weights = []
        self.uploaded_ids = []
        self.uploaded_models = []
        self.rs_test_acc = []
        self.rs_test_auc = []
        self.rs_test_loss = []
        self.rs_train_loss = []
        self.clients_test_accs = []
        self.domain_mean_test_accs = []
        self.times = times
        self.eval_gap = args.eval_gap
        self.set_seed(self.times)
        self.set_path(args)
        # preprocess dataset name: pick the Dirichlet alpha used when the data was partitioned
        if self.dataset.startswith("cifar"):
            dir_alpha = 0.3
        elif self.dataset == "organamnist25":
            dir_alpha = 1.0
        elif self.dataset.startswith("organamnist"):
            dir_alpha = 0.3 if self.num_clients == 20 else 1.0
        else:
            dir_alpha = float("nan")
        self.actual_dataset = f"{self.dataset}-{self.num_clients}clients_alpha{dir_alpha:.1f}"
        logger_fn = os.path.join(args.log_dir, f"{args.algorithm}-{self.actual_dataset}.log")
        self.set_logger(save=True, fn=logger_fn)

    def set_seed(self, seed):
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
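        # Disabling cuDNN autotuning and forcing deterministic kernels below trades
        # some GPU speed for run-to-run reproducibility of the seeded experiments.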
        torch.backends.cudnn.enabled = False
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    def set_logger(self, save=False, fn=None):
        if save:
            fn = "testlog.log" if fn is None else fn
            logging.basicConfig(
                filename=fn,
                filemode="a",
                format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                level=logging.DEBUG
            )
        else:
            logging.basicConfig(
                format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                level=logging.DEBUG
            )

    def set_path(self, args):
        self.hist_dir = args.hist_dir
        self.log_dir = args.log_dir
        self.ckpt_dir = args.ckpt_dir
        if not os.path.exists(args.hist_dir):
            os.makedirs(args.hist_dir)
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        if not os.path.exists(args.ckpt_dir):
            os.makedirs(args.ckpt_dir)

    def set_clients(self, args, clientObj):
        self.new_clients = None
        for i in range(self.num_clients):
            train_data = read_client_data(self.dataset, i, is_train=True)
            test_data = read_client_data(self.dataset, i, is_train=False)
            client = clientObj(args,
                               id=i,
                               train_samples=len(train_data),
                               test_samples=len(test_data))
            self.clients.append(client)

    def select_clients(self):
        selected_clients = list(np.random.choice(self.clients, self.join_clients, replace=False))
        return selected_clients

    def send_models(self, mode="selected"):
        if mode == "selected":
            assert (len(self.selected_clients) > 0)
            for client in self.selected_clients:
                client.set_parameters(self.global_model)
        elif mode == "all":
            for client in self.clients:
                client.set_parameters(self.global_model)
        else:
            raise NotImplementedError

    def receive_models(self):
        assert (len(self.selected_clients) > 0)
        self.uploaded_weights = []
        tot_samples = 0
        self.uploaded_ids = []
        self.uploaded_models = []
        for client in self.selected_clients:
            self.uploaded_weights.append(client.train_samples)
            tot_samples += client.train_samples
            self.uploaded_ids.append(client.id)
            self.uploaded_models.append(client.model)
        for i, w in enumerate(self.uploaded_weights):
            self.uploaded_weights[i] = w / tot_samples
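
    # aggregate_parameters() and add_parameters() together implement the
    # sample-weighted FedAvg update
    #     global_model = sum_k (n_k / sum_j n_j) * client_model_k,
    # where the normalized weights n_k / sum_j n_j are prepared in receive_models().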
    def aggregate_parameters(self):
        assert (len(self.uploaded_models) > 0)
        self.global_model = copy.deepcopy(self.uploaded_models[0])
        for param in self.global_model.parameters():
            param.data.zero_()
        for w, client_model in zip(self.uploaded_weights, self.uploaded_models):
            self.add_parameters(w, client_model)

    def add_parameters(self, w, client_model):
        for server_param, client_param in zip(self.global_model.parameters(), client_model.parameters()):
            server_param.data += client_param.data.clone() * w

    def prepare_global_model(self):
        pass

    def reset_records(self):
        self.best_mean_test_acc = 0.0
        self.clients_test_accs = []
        self.rs_test_acc = []
        self.rs_test_auc = []
        self.rs_test_loss = []

    def train_new_clients(self, epochs=20):
        self.global_model = self.global_model.to(self.device)
        self.clients = self.new_clients
        self.reset_records()
        for c in self.clients:
            c.model = copy.deepcopy(self.global_model)
        self.evaluate()
        for epoch_idx in range(epochs):
            for c in self.clients:
                c.standard_train()
            print(f"==> New clients epoch: [{epoch_idx+1:2d}/{epochs}] | Evaluating local models...", flush=True)
            self.evaluate()
        print(f"==> Best mean global accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
        self.save_results(fn=self.hist_result_fn)
        message_res = f"\tnew_clients_test_acc:{self.best_mean_test_acc:.6f}"
        logging.info(self.message_hp + message_res)

    def save_global_model(self, model_path=None, state=None):
        if model_path is None:
            model_path = os.path.join("models", self.dataset)
        if not os.path.exists(model_path):
            os.makedirs(model_path)
        model_path = os.path.join(model_path, self.algorithm + "_server" + ".pt")
        if state is None:
            torch.save({"global_model": self.global_model.cpu().state_dict()}, model_path)
        else:
            torch.save(state, model_path)

    def load_model(self, model_path=None):
        if model_path is None:
            model_path = os.path.join("models", self.dataset)
            model_path = os.path.join(model_path, self.algorithm + "_server" + ".pt")
        assert (os.path.exists(model_path))
        self.global_model = torch.load(model_path)

    def save_results(self, fn=None):
        if fn is None:
            algo = self.dataset + "_" + self.algorithm
            result_path = self.hist_dir
        if (len(self.rs_test_acc)):
            if fn is None:
                algo = algo + "_" + self.goal + "_" + str(self.times+1)
                file_path = os.path.join(result_path, "{}.h5".format(algo))
            else:
                file_path = fn
            print("File path: " + file_path)
            with h5py.File(file_path, 'w') as hf:
                hf.create_dataset('rs_test_acc', data=self.rs_test_acc)
                hf.create_dataset('rs_test_auc', data=self.rs_test_auc)
                hf.create_dataset('rs_test_loss', data=self.rs_test_loss)
                hf.create_dataset('clients_test_accs', data=self.clients_test_accs)
                # hf.create_dataset('rs_train_loss', data=self.rs_train_loss)
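
    # Reading the saved history back is straightforward with h5py (sketch; the
    # file name below is illustrative, not one this module guarantees to exist):
    #   with h5py.File("hist/FedAvg_cifar10-20clients_alpha0.3.h5", "r") as hf:
    #       test_acc = np.array(hf["rs_test_acc"])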
    def test_metrics(self, temp_model=None):
        """A personalized evaluation scheme: per-client test metrics are collected
        as-is and are not weighted by each client's num_samples."""
        test_accs, test_aucs, test_losses, test_nums = [], [], [], []
        for c in self.clients:
            test_acc, test_auc, test_loss, test_num = c.test_metrics(temp_model)
            test_accs.append(test_acc)
            test_aucs.append(test_auc)
            test_losses.append(test_loss)
            test_nums.append(test_num)
        ids = [c.id for c in self.clients]
        return ids, test_accs, test_aucs, test_losses, test_nums

    # evaluate selected clients
    def evaluate(self, temp_model=None, mode="personalized"):
        ids, test_accs, test_aucs, test_losses, test_nums = self.test_metrics(temp_model)
        self.clients_test_accs.append(copy.deepcopy(test_accs))
        if mode == "personalized":
            mean_test_acc, mean_test_auc, mean_test_loss = np.mean(test_accs), np.mean(test_aucs), np.mean(test_losses)
        elif mode == "global":
            mean_test_acc, mean_test_auc, mean_test_loss = np.average(test_accs, weights=test_nums), np.average(test_aucs, weights=test_nums), np.average(test_losses, weights=test_nums)
        else:
            raise NotImplementedError
        # compute per-domain mean accuracies for Office-Home (4 domains)
        if self.dataset.startswith("Office-home") and (mean_test_acc > self.best_mean_test_acc):
            self.best_mean_test_acc = mean_test_acc
            self.domain_mean_test_accs = np.mean(np.array(test_accs).reshape(4, -1), axis=1)
        self.best_mean_test_acc = max(mean_test_acc, self.best_mean_test_acc)
        self.rs_test_acc.append(mean_test_acc)
        self.rs_test_auc.append(mean_test_auc)
        self.rs_test_loss.append(mean_test_loss)
        print(f"==> test_loss: {mean_test_loss:.5f} | mean_test_accs: {mean_test_acc*100:.2f}% | best_acc: {self.best_mean_test_acc*100:.2f}%\n")

    def check_early_stopping(self, thresh=0.0):
        # Early stopping: signal when the last four mean test accuracies all fall below thresh
        if thresh == 0.0:
            if (self.dataset == "cifar100"):
                thresh = 0.2
            elif (self.dataset == "cifar10"):
                thresh = 0.6
            elif (self.dataset.startswith("organamnist")):
                thresh = 0.8
            else:
                thresh = 0.23
        return (self.rs_test_acc[-1] < thresh) and (self.rs_test_acc[-2] < thresh) and (self.rs_test_acc[-3] < thresh) and (self.rs_test_acc[-4] < thresh)
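

# Example driver loop (sketch): a concrete subclass of Server, together with a
# matching client class, would typically be run along these lines. `ServerAvg`,
# `clientAvg`, and `args` are illustrative names, not part of this module.
#
#   server = ServerAvg(args, times=0)
#   server.set_clients(args, clientObj=clientAvg)
#   for rnd in range(server.global_rounds):
#       server.selected_clients = server.select_clients()
#       server.send_models(mode="selected")
#       for client in server.selected_clients:
#           client.train()                     # local updates on each client
#       server.receive_models()
#       server.aggregate_parameters()
#       if rnd % server.eval_gap == 0:
#           server.evaluate()
#   server.save_results()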