|
@@ -32,7 +32,7 @@ class FedSSLWithPgFedServer(FedSSLServer):
|
|
|
|
|
|
self.mu = 0
|
|
|
self.momentum = 0.0
|
|
|
- self.alpha_mat = (torch.ones((len(self._clients), len(self._clients))) / self.conf.server.clients_per_round).to(self.conf.device)
|
|
|
+ self.alpha_mat = None
|
|
|
self.uploaded_grads = {}
|
|
|
self.loss_minuses = {}
|
|
|
self.mean_grad = None
|
|
@@ -57,6 +57,8 @@ class FedSSLWithPgFedServer(FedSSLServer):
|
|
|
self.track(metric.TRAIN_TIME, train_time)
|
|
|
|
|
|
def send_param(self):
|
|
|
+ if not self.alpha_mat:
|
|
|
+ self.alpha_mat = (torch.ones((len(self._clients), len(self._clients))) / self.conf.server.clients_per_round).to(self.conf.device)
|
|
|
for client in self.grouped_clients:
|
|
|
client.a_i = self.alpha_mat[client.id]
|
|
|
if len(self.uploaded_grads) == 0:
|