|
@@ -58,6 +58,8 @@ def run():
|
|
|
parser.add_argument('--gpu', default=0, type=int)
|
|
|
parser.add_argument('--run_count', default=0, type=int)
|
|
|
|
|
|
+ parser.add_argument('--use_pgfed', default=False, type=bool)
|
|
|
+
|
|
|
args = parser.parse_args()
|
|
|
print("arguments: ", args)
|
|
|
|
|
@@ -164,8 +166,13 @@ def run():
|
|
|
|
|
|
model = get_model(args.model, args.encoder_network, args.predictor_network)
|
|
|
easyfl.register_model(model)
|
|
|
- easyfl.register_client(FedSSLWithPgFedClient)
|
|
|
- easyfl.register_server(FedSSLWithPgFedServer)
|
|
|
+ if args.use_pgfed:
|
|
|
+ easyfl.register_client(FedSSLWithPgFedClient)
|
|
|
+ easyfl.register_server(FedSSLWithPgFedServer)
|
|
|
+ else:
|
|
|
+ easyfl.register_client(FedSSLClient)
|
|
|
+ easyfl.register_server(FedSSLServer)
|
|
|
+
|
|
|
easyfl.init(config, init_all=True)
|
|
|
easyfl.run()
|
|
|
|