|
@@ -0,0 +1,212 @@
|
|
|
+import copy
|
|
|
+import torch
|
|
|
+import argparse
|
|
|
+import os
|
|
|
+import time
|
|
|
+import warnings
|
|
|
+import numpy as np
|
|
|
+import torchvision
|
|
|
+
|
|
|
+from flcore.trainmodel.models import *
|
|
|
+
|
|
|
+warnings.simplefilter("ignore")
|
|
|
+torch.manual_seed(0)
|
|
|
+
|
|
|
+def run(args):
|
|
|
+ model_str = args.model
|
|
|
+ for i in range(args.prev, args.times):
|
|
|
+ print(f"\n============= Running time: [{i+1}th/{args.times}] =============", flush=True)
|
|
|
+ print("Creating server and clients ...")
|
|
|
+
|
|
|
+ # Generate args.model
|
|
|
+ if model_str == "cnn":
|
|
|
+ if args.dataset == "mnist" or args.dataset.startswith("organamnist"):
|
|
|
+ args.model = FedAvgCNN(in_features=1, num_classes=args.num_classes, dim=1024).to(args.device)
|
|
|
+ elif args.dataset.upper() == "CIFAR10" or args.dataset.upper() == "CIFAR100" or args.dataset.startswith("Office-home"):
|
|
|
+ args.model = FedAvgCNN(in_features=3, num_classes=args.num_classes, dim=1600).to(args.device)
|
|
|
+ else:
|
|
|
+ args.model = FedAvgCNN(in_features=3, num_classes=args.num_classes, dim=1600).to(args.device)
|
|
|
+
|
|
|
+ else:
|
|
|
+ raise NotImplementedError
|
|
|
+
|
|
|
+ # select algorithm
|
|
|
+ if args.algorithm.startswith("Local"):
|
|
|
+ from flcore.servers.serverlocal import Local
|
|
|
+ server = Local(args, i)
|
|
|
+
|
|
|
+ elif args.algorithm.startswith("FedAvg"):
|
|
|
+ from flcore.servers.serveravg import FedAvg
|
|
|
+ server = FedAvg(args, i)
|
|
|
+
|
|
|
+ elif args.algorithm.startswith("FedDyn"):
|
|
|
+ from flcore.servers.serverdyn import FedDyn
|
|
|
+ server = FedDyn(args, i)
|
|
|
+
|
|
|
+ elif args.algorithm.startswith("pFedMe"):
|
|
|
+ from flcore.servers.serverpfedme import pFedMe
|
|
|
+ server = pFedMe(args, i)
|
|
|
+
|
|
|
+ elif args.algorithm.startswith("FedFomo"):
|
|
|
+ from flcore.servers.serverfomo import FedFomo
|
|
|
+ server = FedFomo(args, i)
|
|
|
+
|
|
|
+ elif args.algorithm.startswith("APFL"):
|
|
|
+ from flcore.servers.serverapfl import APFL
|
|
|
+ server = APFL(args, i)
|
|
|
+
|
|
|
+ elif args.algorithm.startswith("FedRep"):
|
|
|
+ from flcore.servers.serverrep import FedRep
|
|
|
+ args.predictor = copy.deepcopy(args.model.fc)
|
|
|
+ args.model.fc = nn.Identity()
|
|
|
+ args.model = LocalModel(args.model, args.predictor)
|
|
|
+ server = FedRep(args, i)
|
|
|
+
|
|
|
+ elif args.algorithm.startswith("LGFedAvg"):
|
|
|
+ from flcore.servers.serverlgfedavg import LGFedAvg
|
|
|
+ args.predictor = copy.deepcopy(args.model.fc)
|
|
|
+ args.model.fc = nn.Identity()
|
|
|
+ args.model = LocalModel(args.model, args.predictor)
|
|
|
+ server = LGFedAvg(args, i)
|
|
|
+
|
|
|
+ elif args.algorithm.startswith("FedPer"):
|
|
|
+ from flcore.servers.serverper import FedPer
|
|
|
+ args.predictor = copy.deepcopy(args.model.fc)
|
|
|
+ args.model.fc = nn.Identity()
|
|
|
+ args.model = LocalModel(args.model, args.predictor)
|
|
|
+ server = FedPer(args, i)
|
|
|
+
|
|
|
+ elif args.algorithm.startswith("PerAvg"):
|
|
|
+ from flcore.servers.serverperavg import PerAvg
|
|
|
+ server = PerAvg(args, i)
|
|
|
+
|
|
|
+ elif args.algorithm.startswith("FedRoD"):
|
|
|
+ from flcore.servers.serverrod import FedRoD
|
|
|
+ args.predictor = copy.deepcopy(args.model.fc)
|
|
|
+ args.model.fc = nn.Identity()
|
|
|
+ args.model = LocalModel(args.model, args.predictor)
|
|
|
+ server = FedRoD(args, i)
|
|
|
+
|
|
|
+ elif args.algorithm.startswith("FedBABU"):
|
|
|
+ args.predictor = copy.deepcopy(args.model.fc)
|
|
|
+ args.model.fc = nn.Identity()
|
|
|
+ args.model = LocalModel(args.model, args.predictor)
|
|
|
+ from flcore.servers.serverbabu import FedBABU
|
|
|
+ server = FedBABU(args, i)
|
|
|
+
|
|
|
+ elif args.algorithm.startswith("PGFed"):
|
|
|
+ from flcore.servers.serverpgfed import PGFed
|
|
|
+ server = PGFed(args, i)
|
|
|
+
|
|
|
+ else:
|
|
|
+ raise NotImplementedError
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ server.train()
|
|
|
+ if args.dataset.startswith("Office-home") and args.times != 1:
|
|
|
+ import logging
|
|
|
+ m = server.domain_mean_test_accs
|
|
|
+ logging.info(f"domains means and average:\t{m[0]:.6f}\t{m[1]:.6f}\t{m[2]:.6f}\t{m[3]:.6f}\t{server.best_mean_test_acc:.6f}")
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ # # comment the above block and uncomment the following block for fine-tuning on new clients
|
|
|
+ # if len(server.clients) == 100:
|
|
|
+ # old_clients_num = 80
|
|
|
+ # server.new_clients = server.clients[old_clients_num:]
|
|
|
+ # server.clients = server.clients[:old_clients_num]
|
|
|
+ # server.num_clients = old_clients_num
|
|
|
+ # server.join_clients = int(old_clients_num * server.join_ratio)
|
|
|
+ # if not args.algorithm.startswith("Local"):
|
|
|
+ # server.train()
|
|
|
+ # server.prepare_global_model()
|
|
|
+ # n_epochs = 20
|
|
|
+ # print(f"\n\n==> Training for new clients for {n_epochs} epochs")
|
|
|
+ # server.train_new_clients(epochs=n_epochs)
|
|
|
+
|
|
|
+def get_args():
|
|
|
+ parser = argparse.ArgumentParser()
|
|
|
+ # general
|
|
|
+ parser.add_argument('-go', "--goal", type=str, default="cnn",
|
|
|
+ help="The goal for this experiment")
|
|
|
+ parser.add_argument('-dev', "--device", type=str, default="cuda",
|
|
|
+ choices=["cpu", "cuda"])
|
|
|
+ parser.add_argument('-did', "--device_id", type=str, default="0")
|
|
|
+ parser.add_argument('-data', "--dataset", type=str, default="cifar10",
|
|
|
+ choices=["cifar10", "cifar100", "organaminist25", "organaminist50", "organaminist100", "Office-home20"])
|
|
|
+ parser.add_argument('-nb', "--num_classes", type=int, default=10)
|
|
|
+ parser.add_argument('-m', "--model", type=str, default="cnn")
|
|
|
+ parser.add_argument('-p', "--predictor", type=str, default="cnn")
|
|
|
+ parser.add_argument('-lbs', "--batch_size", type=int, default=10)
|
|
|
+ parser.add_argument('-lr', "--local_learning_rate", type=float, default=0.005,
|
|
|
+ help="Local learning rate")
|
|
|
+ parser.add_argument('-gr', "--global_rounds", type=int, default=3)
|
|
|
+ parser.add_argument('-ls', "--local_steps", type=int, default=5)
|
|
|
+ parser.add_argument('-algo', "--algorithm", type=str, default="PGFed")
|
|
|
+ parser.add_argument('-jr', "--join_ratio", type=float, default=0.25,
|
|
|
+ help="Ratio of clients per round")
|
|
|
+ parser.add_argument('-nc', "--num_clients", type=int, default=25,
|
|
|
+ help="Total number of clients")
|
|
|
+ parser.add_argument('-pv', "--prev", type=int, default=0,
|
|
|
+ help="Previous Running times")
|
|
|
+ parser.add_argument('-t', "--times", type=int, default=1,
|
|
|
+ help="Running times")
|
|
|
+ parser.add_argument('-eg', "--eval_gap", type=int, default=1,
|
|
|
+ help="Rounds gap for evaluation")
|
|
|
+
|
|
|
+ # FL algorithms (multiple algs)
|
|
|
+ parser.add_argument('-bt', "--beta", type=float, default=0.0,
|
|
|
+ help="PGFed momentum, average moving parameter for pFedMe, Second learning rate of Per-FedAvg")
|
|
|
+ parser.add_argument('-lam', "--lambdaa", type=float, default=1.0,
|
|
|
+ help="PGFed learning rate for a_i, Regularization weight for pFedMe")
|
|
|
+ parser.add_argument('-mu', "--mu", type=float, default=0,
|
|
|
+ help="PGFed weight for aux risk, pFedMe weight")
|
|
|
+ parser.add_argument('-K', "--K", type=int, default=5,
|
|
|
+ help="Number of personalized training steps for pFedMe")
|
|
|
+ parser.add_argument('-lrp', "--p_learning_rate", type=float, default=0.01,
|
|
|
+ help="pFedMe personalized learning rate to caculate theta aproximately using K steps")
|
|
|
+ # FedFomo
|
|
|
+ parser.add_argument('-M', "--M", type=int, default=8,
|
|
|
+ help="Server only sends M client models to one client at each round")
|
|
|
+ # APFL
|
|
|
+ parser.add_argument('-al', "--alpha", type=float, default=0.5)
|
|
|
+ # FedRep
|
|
|
+ parser.add_argument('-pls', "--plocal_steps", type=int, default=5)
|
|
|
+ # FedBABU
|
|
|
+ parser.add_argument('-fts', "--fine_tuning_steps", type=int, default=1)
|
|
|
+ # save directories
|
|
|
+ parser.add_argument("--hist_dir", type=str, default="../", help="dir path for output hist file")
|
|
|
+ parser.add_argument("--log_dir", type=str, default="../", help="dir path for log (main results) file")
|
|
|
+ parser.add_argument("--ckpt_dir", type=str, default="../", help="dir path for checkpoints")
|
|
|
+
|
|
|
+ args = parser.parse_args()
|
|
|
+ return args
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ total_start = time.time()
|
|
|
+ args = get_args()
|
|
|
+ # os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
|
|
|
+
|
|
|
+ if args.device == "cuda" and not torch.cuda.is_available():
|
|
|
+ print("\ncuda is not avaiable.\n")
|
|
|
+ args.device = "cpu"
|
|
|
+ print("=" * 50)
|
|
|
+
|
|
|
+ print("Algorithm: {}".format(args.algorithm))
|
|
|
+ print("Local batch size: {}".format(args.batch_size))
|
|
|
+ print("Local steps: {}".format(args.local_steps))
|
|
|
+ print("Local learing rate: {}".format(args.local_learning_rate))
|
|
|
+ print("Total number of clients: {}".format(args.num_clients))
|
|
|
+ print("Clients join in each round: {}".format(args.join_ratio))
|
|
|
+ print("Global rounds: {}".format(args.global_rounds))
|
|
|
+ print("Running times: {}".format(args.times))
|
|
|
+ print("Dataset: {}".format(args.dataset))
|
|
|
+ print("Local model: {}".format(args.model))
|
|
|
+ print("Using device: {}".format(args.device))
|
|
|
+
|
|
|
+ if args.device == "cuda":
|
|
|
+ print("Cuda device id: {}".format(os.environ["CUDA_VISIBLE_DEVICES"]))
|
|
|
+ print("=" * 50)
|
|
|
+
|
|
|
+ run(args)
|