123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
import argparse
import copy
import importlib
import logging
import os
import time
import warnings

import numpy as np
import torch
import torchvision
from torch import nn

from flcore.trainmodel.models import *
- warnings.simplefilter("ignore")
- torch.manual_seed(0)
# Dataset-name prefixes that select the single-channel CNN configuration.
# The command-line parser in this file spells the dataset "organaminist*",
# while the model selection historically tested for "organamnist"; both
# spellings are accepted so the CLI values actually reach this configuration.
_ORGANAMNIST_PREFIXES = ("organamnist", "organaminist")

# Algorithm-name prefix -> (server module, server class name, split head?).
# "split head" means the model's ``fc`` layer is detached into
# ``args.predictor`` and the model is wrapped in ``LocalModel`` before the
# server is constructed.
_SERVER_REGISTRY = {
    "Local": ("flcore.servers.serverlocal", "Local", False),
    "FedAvg": ("flcore.servers.serveravg", "FedAvg", False),
    "FedDyn": ("flcore.servers.serverdyn", "FedDyn", False),
    "pFedMe": ("flcore.servers.serverpfedme", "pFedMe", False),
    "FedFomo": ("flcore.servers.serverfomo", "FedFomo", False),
    "APFL": ("flcore.servers.serverapfl", "APFL", False),
    "FedRep": ("flcore.servers.serverrep", "FedRep", True),
    "LGFedAvg": ("flcore.servers.serverlgfedavg", "LGFedAvg", True),
    "FedPer": ("flcore.servers.serverper", "FedPer", True),
    "PerAvg": ("flcore.servers.serverperavg", "PerAvg", False),
    "FedRoD": ("flcore.servers.serverrod", "FedRoD", True),
    "FedBABU": ("flcore.servers.serverbabu", "FedBABU", True),
    "PGFed": ("flcore.servers.serverpgfed", "PGFed", False),
}


def _build_cnn(args):
    """Return a fresh ``FedAvgCNN`` for ``args.dataset``, moved to ``args.device``."""
    single_channel = (
        args.dataset == "mnist"
        or args.dataset.startswith(_ORGANAMNIST_PREFIXES)
    )
    # Every other dataset (CIFAR10/100, Office-home, anything else) shares the
    # three-channel configuration.
    in_features, dim = (1, 1024) if single_channel else (3, 1600)
    cnn = FedAvgCNN(in_features=in_features, num_classes=args.num_classes, dim=dim)
    return cnn.to(args.device)


def _split_head(args):
    """Move ``args.model.fc`` into ``args.predictor`` and wrap both in ``LocalModel``."""
    args.predictor = copy.deepcopy(args.model.fc)
    args.model.fc = nn.Identity()
    args.model = LocalModel(args.model, args.predictor)


def _create_server(args, run_idx):
    """Instantiate the server whose registry prefix matches ``args.algorithm``."""
    for prefix, (module_name, class_name, needs_split) in _SERVER_REGISTRY.items():
        if not args.algorithm.startswith(prefix):
            continue
        # Import lazily so only the selected algorithm's module is loaded.
        server_cls = getattr(importlib.import_module(module_name), class_name)
        if needs_split:
            _split_head(args)
        return server_cls(args, run_idx)
    raise NotImplementedError(f"Unsupported algorithm: {args.algorithm!r}")


def run(args):
    """Train the selected algorithm once per run index.

    For every ``i`` in ``range(args.prev, args.times)`` a new model is built,
    the server for ``args.algorithm`` is created with ``(args, i)`` and
    ``server.train()`` is called.

    Args:
        args: Parsed options. ``args.model`` is read as a model name on entry
            and is overwritten with the constructed model object on each run;
            head-splitting algorithms also overwrite ``args.predictor``.

    Raises:
        NotImplementedError: If the model name is not ``"cnn"`` or no
            registered algorithm prefix matches ``args.algorithm``.
    """
    # Remember the model *name*: args.model is replaced by a module below.
    model_str = args.model

    for i in range(args.prev, args.times):
        print(f"\n============= Running time: [{i+1}th/{args.times}] =============", flush=True)
        print("Creating server and clients ...")

        # Generate args.model
        if model_str != "cnn":
            raise NotImplementedError(f"Unsupported model: {model_str!r}")
        args.model = _build_cnn(args)

        # select algorithm
        server = _create_server(args, i)

        server.train()
        if args.dataset.startswith("Office-home") and args.times != 1:
            m = server.domain_mean_test_accs
            logging.info(f"domains means and average:\t{m[0]:.6f}\t{m[1]:.6f}\t{m[2]:.6f}\t{m[3]:.6f}\t{server.best_mean_test_acc:.6f}")

        # # comment the above block and uncomment the following block for fine-tuning on new clients
        # NOTE(review): the nesting of this snippet was reconstructed; verify
        # it before enabling.
        # if len(server.clients) == 100:
        #     old_clients_num = 80
        #     server.new_clients = server.clients[old_clients_num:]
        #     server.clients = server.clients[:old_clients_num]
        #     server.num_clients = old_clients_num
        #     server.join_clients = int(old_clients_num * server.join_ratio)
        # if not args.algorithm.startswith("Local"):
        #     server.train()
        # server.prepare_global_model()
        # n_epochs = 20
        # print(f"\n\n==> Training for new clients for {n_epochs} epochs")
        # server.train_new_clients(epochs=n_epochs)
def get_args(argv=None):
    """Define the experiment's command-line options and parse them.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            behaviour existing callers get.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser()

    # general
    parser.add_argument('-go', "--goal", type=str, default="cnn",
                        help="The goal for this experiment")
    parser.add_argument('-dev', "--device", type=str, default="cuda",
                        choices=["cpu", "cuda"])
    parser.add_argument('-did', "--device_id", type=str, default="0")
    parser.add_argument('-data', "--dataset", type=str, default="cifar10",
                        choices=["cifar10", "cifar100", "organaminist25", "organaminist50", "organaminist100", "Office-home20"])
    parser.add_argument('-nb', "--num_classes", type=int, default=10)
    parser.add_argument('-m', "--model", type=str, default="cnn")
    parser.add_argument('-p', "--predictor", type=str, default="cnn")
    parser.add_argument('-lbs', "--batch_size", type=int, default=10)
    parser.add_argument('-lr', "--local_learning_rate", type=float, default=0.005,
                        help="Local learning rate")
    parser.add_argument('-gr', "--global_rounds", type=int, default=3)
    parser.add_argument('-ls', "--local_steps", type=int, default=5)
    parser.add_argument('-algo', "--algorithm", type=str, default="PGFed",
                        choices=["Local", "FedAvg", "FedDyn", "pFedMe", "FedFomo", "APFL", "FedRep",
                                 "LGFedAvg", "FedPer", "PerAvg", "FedRoD", "FedBABU", "PGFed"])
    parser.add_argument('-jr', "--join_ratio", type=float, default=0.25,
                        help="Ratio of clients per round")
    parser.add_argument('-nc', "--num_clients", type=int, default=25,
                        help="Total number of clients")
    parser.add_argument('-pv', "--prev", type=int, default=0,
                        help="Previous Running times")
    parser.add_argument('-t', "--times", type=int, default=1,
                        help="Running times")
    parser.add_argument('-eg', "--eval_gap", type=int, default=1,
                        help="Rounds gap for evaluation")

    # FL algorithms (multiple algs)
    parser.add_argument('-bt', "--beta", type=float, default=0.0,
                        help="PGFed momentum, average moving parameter for pFedMe, Second learning rate of Per-FedAvg")
    parser.add_argument('-lam', "--lambdaa", type=float, default=1.0,
                        help="PGFed learning rate for a_i, Regularization weight for pFedMe")
    parser.add_argument('-mu', "--mu", type=float, default=0,
                        help="PGFed weight for aux risk, pFedMe weight")
    parser.add_argument('-K', "--K", type=int, default=5,
                        help="Number of personalized training steps for pFedMe")
    parser.add_argument('-lrp', "--p_learning_rate", type=float, default=0.01,
                        help="pFedMe personalized learning rate to caculate theta aproximately using K steps")

    # FedFomo
    parser.add_argument('-M', "--M", type=int, default=8,
                        help="Server only sends M client models to one client at each round")

    # APFL
    parser.add_argument('-al', "--alpha", type=float, default=0.5)

    # FedRep
    parser.add_argument('-pls', "--plocal_steps", type=int, default=5)

    # FedBABU
    parser.add_argument('-fts', "--fine_tuning_steps", type=int, default=1)

    # save directories
    parser.add_argument("--hist_dir", type=str, default="../results/", help="dir path for output hist file")
    parser.add_argument("--log_dir", type=str, default="../logs/", help="dir path for log (main results) file")
    parser.add_argument("--ckpt_dir", type=str, default="../checkpoints/", help="dir path for checkpoints")

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
# Script entry point: parse options, fall back to CPU when CUDA is missing,
# print a summary of the configuration and start the experiment.
if __name__ == "__main__":
    # Wall-clock start of the whole script (recorded but not reported here).
    total_start = time.time()

    args = get_args()

    # os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id

    if args.device == "cuda" and not torch.cuda.is_available():
        print("\ncuda is not avaiable.\n")
        args.device = "cpu"

    print("=" * 50)

    print(f"Algorithm: {args.algorithm}")
    print(f"Local batch size: {args.batch_size}")
    print(f"Local steps: {args.local_steps}")
    print(f"Local learing rate: {args.local_learning_rate}")
    print(f"Total number of clients: {args.num_clients}")
    print(f"Clients join in each round: {args.join_ratio}")
    print(f"Global rounds: {args.global_rounds}")
    print(f"Running times: {args.times}")
    print(f"Dataset: {args.dataset}")
    print(f"Local model: {args.model}")
    print(f"Using device: {args.device}")

    if args.device == "cuda":
        # CUDA_VISIBLE_DEVICES is not set by this script (the assignment above
        # is commented out), so indexing os.environ directly raises KeyError
        # when the caller did not export it. Report a placeholder instead.
        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set")
        print(f"Cuda device id: {visible_devices}")
    print("=" * 50)

    run(args)
|