# main.py 9.2 KB


import argparse
import copy
import importlib
import logging
import os
import time
import warnings

import numpy as np
import torch
import torchvision

from flcore.trainmodel.models import *
  10. warnings.simplefilter("ignore")
  11. torch.manual_seed(0)
  12. def run(args):
  13. model_str = args.model
  14. for i in range(args.prev, args.times):
  15. print(f"\n============= Running time: [{i+1}th/{args.times}] =============", flush=True)
  16. print("Creating server and clients ...")
  17. # Generate args.model
  18. if model_str == "cnn":
  19. if args.dataset == "mnist" or args.dataset.startswith("organamnist"):
  20. args.model = FedAvgCNN(in_features=1, num_classes=args.num_classes, dim=1024).to(args.device)
  21. elif args.dataset.upper() == "CIFAR10" or args.dataset.upper() == "CIFAR100" or args.dataset.startswith("Office-home"):
  22. args.model = FedAvgCNN(in_features=3, num_classes=args.num_classes, dim=1600).to(args.device)
  23. else:
  24. args.model = FedAvgCNN(in_features=3, num_classes=args.num_classes, dim=1600).to(args.device)
  25. else:
  26. raise NotImplementedError
  27. # select algorithm
  28. if args.algorithm.startswith("Local"):
  29. from flcore.servers.serverlocal import Local
  30. server = Local(args, i)
  31. elif args.algorithm.startswith("FedAvg"):
  32. from flcore.servers.serveravg import FedAvg
  33. server = FedAvg(args, i)
  34. elif args.algorithm.startswith("FedDyn"):
  35. from flcore.servers.serverdyn import FedDyn
  36. server = FedDyn(args, i)
  37. elif args.algorithm.startswith("pFedMe"):
  38. from flcore.servers.serverpfedme import pFedMe
  39. server = pFedMe(args, i)
  40. elif args.algorithm.startswith("FedFomo"):
  41. from flcore.servers.serverfomo import FedFomo
  42. server = FedFomo(args, i)
  43. elif args.algorithm.startswith("APFL"):
  44. from flcore.servers.serverapfl import APFL
  45. server = APFL(args, i)
  46. elif args.algorithm.startswith("FedRep"):
  47. from flcore.servers.serverrep import FedRep
  48. args.predictor = copy.deepcopy(args.model.fc)
  49. args.model.fc = nn.Identity()
  50. args.model = LocalModel(args.model, args.predictor)
  51. server = FedRep(args, i)
  52. elif args.algorithm.startswith("LGFedAvg"):
  53. from flcore.servers.serverlgfedavg import LGFedAvg
  54. args.predictor = copy.deepcopy(args.model.fc)
  55. args.model.fc = nn.Identity()
  56. args.model = LocalModel(args.model, args.predictor)
  57. server = LGFedAvg(args, i)
  58. elif args.algorithm.startswith("FedPer"):
  59. from flcore.servers.serverper import FedPer
  60. args.predictor = copy.deepcopy(args.model.fc)
  61. args.model.fc = nn.Identity()
  62. args.model = LocalModel(args.model, args.predictor)
  63. server = FedPer(args, i)
  64. elif args.algorithm.startswith("PerAvg"):
  65. from flcore.servers.serverperavg import PerAvg
  66. server = PerAvg(args, i)
  67. elif args.algorithm.startswith("FedRoD"):
  68. from flcore.servers.serverrod import FedRoD
  69. args.predictor = copy.deepcopy(args.model.fc)
  70. args.model.fc = nn.Identity()
  71. args.model = LocalModel(args.model, args.predictor)
  72. server = FedRoD(args, i)
  73. elif args.algorithm.startswith("FedBABU"):
  74. args.predictor = copy.deepcopy(args.model.fc)
  75. args.model.fc = nn.Identity()
  76. args.model = LocalModel(args.model, args.predictor)
  77. from flcore.servers.serverbabu import FedBABU
  78. server = FedBABU(args, i)
  79. elif args.algorithm.startswith("PGFed"):
  80. from flcore.servers.serverpgfed import PGFed
  81. server = PGFed(args, i)
  82. else:
  83. raise NotImplementedError
  84. server.train()
  85. if args.dataset.startswith("Office-home") and args.times != 1:
  86. import logging
  87. m = server.domain_mean_test_accs
  88. logging.info(f"domains means and average:\t{m[0]:.6f}\t{m[1]:.6f}\t{m[2]:.6f}\t{m[3]:.6f}\t{server.best_mean_test_acc:.6f}")
  89. # # comment the above block and uncomment the following block for fine-tuning on new clients
  90. # if len(server.clients) == 100:
  91. # old_clients_num = 80
  92. # server.new_clients = server.clients[old_clients_num:]
  93. # server.clients = server.clients[:old_clients_num]
  94. # server.num_clients = old_clients_num
  95. # server.join_clients = int(old_clients_num * server.join_ratio)
  96. # if not args.algorithm.startswith("Local"):
  97. # server.train()
  98. # server.prepare_global_model()
  99. # n_epochs = 20
  100. # print(f"\n\n==> Training for new clients for {n_epochs} epochs")
  101. # server.train_new_clients(epochs=n_epochs)
  102. def get_args():
  103. parser = argparse.ArgumentParser()
  104. # general
  105. parser.add_argument('-go', "--goal", type=str, default="cnn",
  106. help="The goal for this experiment")
  107. parser.add_argument('-dev', "--device", type=str, default="cuda",
  108. choices=["cpu", "cuda"])
  109. parser.add_argument('-did', "--device_id", type=str, default="0")
  110. parser.add_argument('-data', "--dataset", type=str, default="cifar10",
  111. choices=["cifar10", "cifar100", "organaminist25", "organaminist50", "organaminist100", "Office-home20"])
  112. parser.add_argument('-nb', "--num_classes", type=int, default=10)
  113. parser.add_argument('-m', "--model", type=str, default="cnn")
  114. parser.add_argument('-p', "--predictor", type=str, default="cnn")
  115. parser.add_argument('-lbs', "--batch_size", type=int, default=10)
  116. parser.add_argument('-lr', "--local_learning_rate", type=float, default=0.005,
  117. help="Local learning rate")
  118. parser.add_argument('-gr', "--global_rounds", type=int, default=3)
  119. parser.add_argument('-ls', "--local_steps", type=int, default=5)
  120. parser.add_argument('-algo', "--algorithm", type=str, default="PGFed")
  121. parser.add_argument('-jr', "--join_ratio", type=float, default=0.25,
  122. help="Ratio of clients per round")
  123. parser.add_argument('-nc', "--num_clients", type=int, default=25,
  124. help="Total number of clients")
  125. parser.add_argument('-pv', "--prev", type=int, default=0,
  126. help="Previous Running times")
  127. parser.add_argument('-t', "--times", type=int, default=1,
  128. help="Running times")
  129. parser.add_argument('-eg', "--eval_gap", type=int, default=1,
  130. help="Rounds gap for evaluation")
  131. # FL algorithms (multiple algs)
  132. parser.add_argument('-bt', "--beta", type=float, default=0.0,
  133. help="PGFed momentum, average moving parameter for pFedMe, Second learning rate of Per-FedAvg")
  134. parser.add_argument('-lam', "--lambdaa", type=float, default=1.0,
  135. help="PGFed learning rate for a_i, Regularization weight for pFedMe")
  136. parser.add_argument('-mu', "--mu", type=float, default=0,
  137. help="PGFed weight for aux risk, pFedMe weight")
  138. parser.add_argument('-K', "--K", type=int, default=5,
  139. help="Number of personalized training steps for pFedMe")
  140. parser.add_argument('-lrp', "--p_learning_rate", type=float, default=0.01,
  141. help="pFedMe personalized learning rate to caculate theta aproximately using K steps")
  142. # FedFomo
  143. parser.add_argument('-M', "--M", type=int, default=8,
  144. help="Server only sends M client models to one client at each round")
  145. # APFL
  146. parser.add_argument('-al', "--alpha", type=float, default=0.5)
  147. # FedRep
  148. parser.add_argument('-pls', "--plocal_steps", type=int, default=5)
  149. # FedBABU
  150. parser.add_argument('-fts', "--fine_tuning_steps", type=int, default=1)
  151. # save directories
  152. parser.add_argument("--hist_dir", type=str, default="../", help="dir path for output hist file")
  153. parser.add_argument("--log_dir", type=str, default="../", help="dir path for log (main results) file")
  154. parser.add_argument("--ckpt_dir", type=str, default="../", help="dir path for checkpoints")
  155. args = parser.parse_args()
  156. return args
  157. if __name__ == "__main__":
  158. total_start = time.time()
  159. args = get_args()
  160. # os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
  161. if args.device == "cuda" and not torch.cuda.is_available():
  162. print("\ncuda is not avaiable.\n")
  163. args.device = "cpu"
  164. print("=" * 50)
  165. print("Algorithm: {}".format(args.algorithm))
  166. print("Local batch size: {}".format(args.batch_size))
  167. print("Local steps: {}".format(args.local_steps))
  168. print("Local learing rate: {}".format(args.local_learning_rate))
  169. print("Total number of clients: {}".format(args.num_clients))
  170. print("Clients join in each round: {}".format(args.join_ratio))
  171. print("Global rounds: {}".format(args.global_rounds))
  172. print("Running times: {}".format(args.times))
  173. print("Dataset: {}".format(args.dataset))
  174. print("Local model: {}".format(args.model))
  175. print("Using device: {}".format(args.device))
  176. if args.device == "cuda":
  177. print("Cuda device id: {}".format(os.environ["CUDA_VISIBLE_DEVICES"]))
  178. print("=" * 50)
  179. run(args)