main.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. import copy
  2. import torch
  3. import argparse
  4. import os
  5. import time
  6. import warnings
  7. import numpy as np
  8. import torchvision
  9. from flcore.trainmodel.models import *
  10. warnings.simplefilter("ignore")
  11. torch.manual_seed(0)
  12. def run(args):
  13. model_str = args.model
  14. for i in range(args.prev, args.times):
  15. print(f"\n============= Running time: [{i+1}th/{args.times}] =============", flush=True)
  16. print("Creating server and clients ...")
  17. # Generate args.model
  18. if model_str == "cnn":
  19. if args.dataset == "mnist" or args.dataset.startswith("organamnist"):
  20. args.model = FedAvgCNN(in_features=1, num_classes=args.num_classes, dim=1024).to(args.device)
  21. elif args.dataset.upper() == "CIFAR10" or args.dataset.upper() == "CIFAR100" or args.dataset.startswith("Office-home"):
  22. args.model = FedAvgCNN(in_features=3, num_classes=args.num_classes, dim=1600).to(args.device)
  23. else:
  24. args.model = FedAvgCNN(in_features=3, num_classes=args.num_classes, dim=1600).to(args.device)
  25. else:
  26. raise NotImplementedError
  27. # select algorithm
  28. if args.algorithm.startswith("Local"):
  29. from flcore.servers.serverlocal import Local
  30. server = Local(args, i)
  31. elif args.algorithm.startswith("FedAvg"):
  32. from flcore.servers.serveravg import FedAvg
  33. server = FedAvg(args, i)
  34. elif args.algorithm.startswith("FedDyn"):
  35. from flcore.servers.serverdyn import FedDyn
  36. server = FedDyn(args, i)
  37. elif args.algorithm.startswith("pFedMe"):
  38. from flcore.servers.serverpfedme import pFedMe
  39. server = pFedMe(args, i)
  40. elif args.algorithm.startswith("FedFomo"):
  41. from flcore.servers.serverfomo import FedFomo
  42. server = FedFomo(args, i)
  43. elif args.algorithm.startswith("APFL"):
  44. from flcore.servers.serverapfl import APFL
  45. server = APFL(args, i)
  46. elif args.algorithm.startswith("FedRep"):
  47. from flcore.servers.serverrep import FedRep
  48. args.predictor = copy.deepcopy(args.model.fc)
  49. args.model.fc = nn.Identity()
  50. args.model = LocalModel(args.model, args.predictor)
  51. server = FedRep(args, i)
  52. elif args.algorithm.startswith("LGFedAvg"):
  53. from flcore.servers.serverlgfedavg import LGFedAvg
  54. args.predictor = copy.deepcopy(args.model.fc)
  55. args.model.fc = nn.Identity()
  56. args.model = LocalModel(args.model, args.predictor)
  57. server = LGFedAvg(args, i)
  58. elif args.algorithm.startswith("FedPer"):
  59. from flcore.servers.serverper import FedPer
  60. args.predictor = copy.deepcopy(args.model.fc)
  61. args.model.fc = nn.Identity()
  62. args.model = LocalModel(args.model, args.predictor)
  63. server = FedPer(args, i)
  64. elif args.algorithm.startswith("PerAvg"):
  65. from flcore.servers.serverperavg import PerAvg
  66. server = PerAvg(args, i)
  67. elif args.algorithm.startswith("FedRoD"):
  68. from flcore.servers.serverrod import FedRoD
  69. args.predictor = copy.deepcopy(args.model.fc)
  70. args.model.fc = nn.Identity()
  71. args.model = LocalModel(args.model, args.predictor)
  72. server = FedRoD(args, i)
  73. elif args.algorithm.startswith("FedBABU"):
  74. args.predictor = copy.deepcopy(args.model.fc)
  75. args.model.fc = nn.Identity()
  76. args.model = LocalModel(args.model, args.predictor)
  77. from flcore.servers.serverbabu import FedBABU
  78. server = FedBABU(args, i)
  79. elif args.algorithm.startswith("PGFed"):
  80. from flcore.servers.serverpgfed import PGFed
  81. server = PGFed(args, i)
  82. else:
  83. raise NotImplementedError
  84. server.train()
  85. if args.dataset.startswith("Office-home") and args.times != 1:
  86. import logging
  87. m = server.domain_mean_test_accs
  88. logging.info(f"domains means and average:\t{m[0]:.6f}\t{m[1]:.6f}\t{m[2]:.6f}\t{m[3]:.6f}\t{server.best_mean_test_acc:.6f}")
  89. # # comment the above block and uncomment the following block for fine-tuning on new clients
  90. # if len(server.clients) == 100:
  91. # old_clients_num = 80
  92. # server.new_clients = server.clients[old_clients_num:]
  93. # server.clients = server.clients[:old_clients_num]
  94. # server.num_clients = old_clients_num
  95. # server.join_clients = int(old_clients_num * server.join_ratio)
  96. # if not args.algorithm.startswith("Local"):
  97. # server.train()
  98. # server.prepare_global_model()
  99. # n_epochs = 20
  100. # print(f"\n\n==> Training for new clients for {n_epochs} epochs")
  101. # server.train_new_clients(epochs=n_epochs)
  102. def get_args():
  103. parser = argparse.ArgumentParser()
  104. # general
  105. parser.add_argument('-go', "--goal", type=str, default="cnn",
  106. help="The goal for this experiment")
  107. parser.add_argument('-dev', "--device", type=str, default="cuda",
  108. choices=["cpu", "cuda"])
  109. parser.add_argument('-did', "--device_id", type=str, default="0")
  110. parser.add_argument('-data', "--dataset", type=str, default="cifar10",
  111. choices=["cifar10", "cifar100", "organaminist25", "organaminist50", "organaminist100", "Office-home20"])
  112. parser.add_argument('-nb', "--num_classes", type=int, default=10)
  113. parser.add_argument('-m', "--model", type=str, default="cnn")
  114. parser.add_argument('-p', "--predictor", type=str, default="cnn")
  115. parser.add_argument('-lbs', "--batch_size", type=int, default=10)
  116. parser.add_argument('-lr', "--local_learning_rate", type=float, default=0.005,
  117. help="Local learning rate")
  118. parser.add_argument('-gr', "--global_rounds", type=int, default=3)
  119. parser.add_argument('-ls', "--local_steps", type=int, default=5)
  120. parser.add_argument('-algo', "--algorithm", type=str, default="PGFed")
  121. parser.add_argument('-jr', "--join_ratio", type=float, default=0.25,
  122. help="Ratio of clients per round")
  123. parser.add_argument('-nc', "--num_clients", type=int, default=25,
  124. help="Total number of clients")
  125. parser.add_argument('-pv', "--prev", type=int, default=0,
  126. help="Previous Running times")
  127. parser.add_argument('-t', "--times", type=int, default=1,
  128. help="Running times")
  129. parser.add_argument('-eg', "--eval_gap", type=int, default=1,
  130. help="Rounds gap for evaluation")
  131. # FL algorithms (multiple algs)
  132. parser.add_argument('-bt', "--beta", type=float, default=0.0,
  133. help="PGFed momentum, average moving parameter for pFedMe, Second learning rate of Per-FedAvg")
  134. parser.add_argument('-lam', "--lambdaa", type=float, default=1.0,
  135. help="PGFed learning rate for a_i, Regularization weight for pFedMe")
  136. parser.add_argument('-mu', "--mu", type=float, default=0,
  137. help="PGFed weight for aux risk, pFedMe weight")
  138. parser.add_argument('-K', "--K", type=int, default=5,
  139. help="Number of personalized training steps for pFedMe")
  140. parser.add_argument('-lrp', "--p_learning_rate", type=float, default=0.01,
  141. help="pFedMe personalized learning rate to caculate theta aproximately using K steps")
  142. # FedFomo
  143. parser.add_argument('-M', "--M", type=int, default=8,
  144. help="Server only sends M client models to one client at each round")
  145. # APFL
  146. parser.add_argument('-al', "--alpha", type=float, default=0.5)
  147. # FedRep
  148. parser.add_argument('-pls', "--plocal_steps", type=int, default=5)
  149. # FedBABU
  150. parser.add_argument('-fts', "--fine_tuning_steps", type=int, default=1)
  151. # save directories
  152. parser.add_argument("--hist_dir", type=str, default="../", help="dir path for output hist file")
  153. parser.add_argument("--log_dir", type=str, default="../", help="dir path for log (main results) file")
  154. parser.add_argument("--ckpt_dir", type=str, default="../", help="dir path for checkpoints")
  155. args = parser.parse_args()
  156. return args
  157. if __name__ == "__main__":
  158. total_start = time.time()
  159. args = get_args()
  160. # os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
  161. if args.device == "cuda" and not torch.cuda.is_available():
  162. print("\ncuda is not avaiable.\n")
  163. args.device = "cpu"
  164. print("=" * 50)
  165. print("Algorithm: {}".format(args.algorithm))
  166. print("Local batch size: {}".format(args.batch_size))
  167. print("Local steps: {}".format(args.local_steps))
  168. print("Local learing rate: {}".format(args.local_learning_rate))
  169. print("Total number of clients: {}".format(args.num_clients))
  170. print("Clients join in each round: {}".format(args.join_ratio))
  171. print("Global rounds: {}".format(args.global_rounds))
  172. print("Running times: {}".format(args.times))
  173. print("Dataset: {}".format(args.dataset))
  174. print("Local model: {}".format(args.model))
  175. print("Using device: {}".format(args.device))
  176. if args.device == "cuda":
  177. print("Cuda device id: {}".format(os.environ["CUDA_VISIBLE_DEVICES"]))
  178. print("=" * 50)
  179. run(args)