Browse Source

fix: fix id

shellmiao 1 year ago
parent
commit
cc550c43cc

+ 2 - 2
applications/fedssl/client.py

@@ -22,8 +22,8 @@ L2 = "l2"
 
 
 class FedSSLClient(BaseClient):
-    def __init__(self, id, cid, conf, train_data, test_data, device, sleep_time=0):
-        super(FedSSLClient, self).__init__(id, cid, conf, train_data, test_data, device, sleep_time)
+    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
+        super(FedSSLClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
         self._local_model = None
         self.DAPU_predictor = LOCAL
         self.encoder_distance = 1

+ 2 - 2
applications/fedssl/client_with_pgfed.py

@@ -33,8 +33,8 @@ def model_dot_product(w1, w2, requires_grad=True):
     return dot_product
 
 class FedSSLWithPgFedClient(FedSSLClient):
-    def __init__(self, id, cid, conf, train_data, test_data, device, sleep_time=0):
-        super(FedSSLWithPgFedClient, self).__init__(id, cid, conf, train_data, test_data, device, sleep_time)
+    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
+        super(FedSSLWithPgFedClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
         self._local_model = None
         self.DAPU_predictor = LOCAL
         self.encoder_distance = 1

+ 7 - 0
applications/fedssl/server_with_pgfed.py

@@ -38,11 +38,18 @@ class FedSSLWithPgFedServer(FedSSLServer):
         self.mean_grad = None
         self.convex_comb_grad = None
 
+    def set_clients(self, clients):
+        self._clients = clients
+        for i, _ in enumerate(self._clients):
+            self._clients[i].id = i
+
     def train(self):
         """Training process of federated learning."""
         self.print_("--- start training ---")
         print(f"\nJoin clients / total clients: {self.conf.server.clients_per_round} / {len(self._clients)}")
 
+
+
         self.selection(self._clients, self.conf.server.clients_per_round)
         self.grouping_for_distributed()
         self.compression()

+ 0 - 2
easyfl/client/base.py

@@ -74,7 +74,6 @@ class BaseClient(object):
         >>>         pass
     """
     def __init__(self,
-                 id,
                  cid,
                  conf,
                  train_data,
@@ -85,7 +84,6 @@ class BaseClient(object):
                  local_port=23000,
                  server_addr="localhost:22999",
                  tracker_addr="localhost:12666"):
-        self.id = id
         self.cid = cid
         self.conf = conf
         self.train_data = train_data

+ 1 - 2
easyfl/coordinator.py

@@ -162,8 +162,7 @@ class Coordinator(object):
         if self.conf.test_mode == TEST_IN_SERVER:
             client_test_data = None
 
-        self.clients = [self._client_class(i,
-                                           u,
+        self.clients = [self._client_class(u,
                                            self.conf.client,
                                            self.train_data,
                                            client_test_data,