|
@@ -64,7 +64,7 @@ class FedSSLWithPgFedServer(FedSSLServer):
|
|
self.track(metric.TRAIN_TIME, train_time)
|
|
self.track(metric.TRAIN_TIME, train_time)
|
|
|
|
|
|
def send_param(self):
|
|
def send_param(self):
|
|
- if self.alpha_mat!=None:
|
|
|
|
|
|
+ if self.alpha_mat==None:
|
|
self.alpha_mat = (torch.ones((len(self._clients), len(self._clients))) / self.conf.server.clients_per_round).to(self.conf.device)
|
|
self.alpha_mat = (torch.ones((len(self._clients), len(self._clients))) / self.conf.server.clients_per_round).to(self.conf.device)
|
|
for client in self.grouped_clients:
|
|
for client in self.grouped_clients:
|
|
client.a_i = self.alpha_mat[client.id]
|
|
client.a_i = self.alpha_mat[client.id]
|