|
@@ -33,8 +33,8 @@ def model_dot_product(w1, w2, requires_grad=True):
|
|
|
return dot_product
|
|
|
|
|
|
class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
- def __init__(self, id, cid, conf, train_data, test_data, device, sleep_time=0):
|
|
|
- super(FedSSLWithPgFedClient, self).__init__(id, cid, conf, train_data, test_data, device, sleep_time)
|
|
|
+ def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
|
|
|
+ super(FedSSLWithPgFedClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
|
|
|
self._local_model = None
|
|
|
self.DAPU_predictor = LOCAL
|
|
|
self.encoder_distance = 1
|