|
@@ -61,7 +61,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
data_count = 0 # delete later
|
|
data_count = 0 # delete later
|
|
batch_loss = []
|
|
batch_loss = []
|
|
for (batched_x1, batched_x2), _ in self.train_loader:
|
|
for (batched_x1, batched_x2), _ in self.train_loader:
|
|
- if data_count >= 500:
|
|
|
|
|
|
+ if data_count >= 50:
|
|
break
|
|
break
|
|
x1, x2 = batched_x1.to(device), batched_x2.to(device)
|
|
x1, x2 = batched_x1.to(device), batched_x2.to(device)
|
|
data_count += x1.size(0)
|
|
data_count += x1.size(0)
|
|
@@ -97,7 +97,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
optimizer.zero_grad()
|
|
optimizer.zero_grad()
|
|
data_count = 0 # delete later
|
|
data_count = 0 # delete later
|
|
for (batched_x1, batched_x2), _ in self.train_loader:
|
|
for (batched_x1, batched_x2), _ in self.train_loader:
|
|
- if data_count >= 500:
|
|
|
|
|
|
+ if data_count >= 50:
|
|
break
|
|
break
|
|
x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
|
|
x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
|
|
data_count += x1.size(0)
|
|
data_count += x1.size(0)
|