|
@@ -143,7 +143,9 @@ def get_args():
|
|
help="Local learning rate")
|
|
help="Local learning rate")
|
|
parser.add_argument('-gr', "--global_rounds", type=int, default=3)
|
|
parser.add_argument('-gr', "--global_rounds", type=int, default=3)
|
|
parser.add_argument('-ls', "--local_steps", type=int, default=5)
|
|
parser.add_argument('-ls', "--local_steps", type=int, default=5)
|
|
- parser.add_argument('-algo', "--algorithm", type=str, default="PGFed")
|
|
|
|
|
|
+ parser.add_argument('-algo', "--algorithm", type=str, default="PGFed",
|
|
|
|
+ choices=["Local", "FedAvg", "FedDyn", "pFedMe", "FedFomo", "APFL", "FedRep",
|
|
|
|
+ "LGFedAvg", "FedPer", "PerAvg", "FedRoD", "FedBABU", "PGFed"])
|
|
parser.add_argument('-jr', "--join_ratio", type=float, default=0.25,
|
|
parser.add_argument('-jr', "--join_ratio", type=float, default=0.25,
|
|
help="Ratio of clients per round")
|
|
help="Ratio of clients per round")
|
|
parser.add_argument('-nc', "--num_clients", type=int, default=25,
|
|
parser.add_argument('-nc', "--num_clients", type=int, default=25,
|
|
@@ -176,9 +178,9 @@ def get_args():
|
|
# FedBABU
|
|
# FedBABU
|
|
parser.add_argument('-fts', "--fine_tuning_steps", type=int, default=1)
|
|
parser.add_argument('-fts', "--fine_tuning_steps", type=int, default=1)
|
|
# save directories
|
|
# save directories
|
|
- parser.add_argument("--hist_dir", type=str, default="../", help="dir path for output hist file")
|
|
|
|
- parser.add_argument("--log_dir", type=str, default="../", help="dir path for log (main results) file")
|
|
|
|
- parser.add_argument("--ckpt_dir", type=str, default="../", help="dir path for checkpoints")
|
|
|
|
|
|
+ parser.add_argument("--hist_dir", type=str, default="../results/", help="dir path for output hist file")
|
|
|
|
+ parser.add_argument("--log_dir", type=str, default="../logs/", help="dir path for log (main results) file")
|
|
|
|
+ parser.add_argument("--ckpt_dir", type=str, default="../checkpoints/", help="dir path for checkpoints")
|
|
|
|
|
|
args = parser.parse_args()
|
|
args = parser.parse_args()
|
|
return args
|
|
return args
|