|
@@ -58,13 +58,9 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
self.model.to(device)
|
|
self.model.to(device)
|
|
old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
|
|
old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
|
|
for i in range(conf.local_epoch):
|
|
for i in range(conf.local_epoch):
|
|
- data_count = 0 # delete later
|
|
|
|
batch_loss = []
|
|
batch_loss = []
|
|
for (batched_x1, batched_x2), _ in self.train_loader:
|
|
for (batched_x1, batched_x2), _ in self.train_loader:
|
|
- if data_count >= 50:
|
|
|
|
- break
|
|
|
|
x1, x2 = batched_x1.to(device), batched_x2.to(device)
|
|
x1, x2 = batched_x1.to(device), batched_x2.to(device)
|
|
- data_count += x1.size(0)
|
|
|
|
optimizer.zero_grad()
|
|
optimizer.zero_grad()
|
|
|
|
|
|
if conf.model in [model.MoCo, model.MoCoV2]:
|
|
if conf.model in [model.MoCo, model.MoCoV2]:
|
|
@@ -95,12 +91,8 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
self.loss_minus = 0.0
|
|
self.loss_minus = 0.0
|
|
test_num = 0
|
|
test_num = 0
|
|
optimizer.zero_grad()
|
|
optimizer.zero_grad()
|
|
- data_count = 0 # delete later
|
|
|
|
for (batched_x1, batched_x2), _ in self.train_loader:
|
|
for (batched_x1, batched_x2), _ in self.train_loader:
|
|
- if data_count >= 50:
|
|
|
|
- break
|
|
|
|
x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
|
|
x1, x2 = batched_x1.to(self.device), batched_x2.to(self.device)
|
|
- data_count += x1.size(0)
|
|
|
|
test_num += x1.size(0)
|
|
test_num += x1.size(0)
|
|
|
|
|
|
if conf.model in [model.MoCo, model.MoCoV2]:
|
|
if conf.model in [model.MoCo, model.MoCoV2]:
|
|
@@ -119,7 +111,7 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
if not self.latest_grad:
|
|
if not self.latest_grad:
|
|
self.latest_grad = copy.deepcopy(self.model)
|
|
self.latest_grad = copy.deepcopy(self.model)
|
|
|
|
|
|
- # delete later
|
|
|
|
|
|
+ # delete later: for test
|
|
# all_grads_none = True
|
|
# all_grads_none = True
|
|
# for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
|
|
# for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
|
|
# if p.grad is not None:
|
|
# if p.grad is not None:
|
|
@@ -154,19 +146,18 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
|
|
|
def set_prev_mean_grad(self, mean_grad):
|
|
def set_prev_mean_grad(self, mean_grad):
|
|
if self.prev_mean_grad is None:
|
|
if self.prev_mean_grad is None:
|
|
- print("initing prev_mean_grad")
|
|
|
|
- print(mean_grad)
|
|
|
|
|
|
+ print("Initing prev_mean_grad")
|
|
self.prev_mean_grad = copy.deepcopy(mean_grad)
|
|
self.prev_mean_grad = copy.deepcopy(mean_grad)
|
|
else:
|
|
else:
|
|
- print("setting prev_mean_grad")
|
|
|
|
|
|
+ print("Setting prev_mean_grad")
|
|
self.set_model(self.prev_mean_grad, mean_grad)
|
|
self.set_model(self.prev_mean_grad, mean_grad)
|
|
|
|
|
|
def set_prev_convex_comb_grad(self, convex_comb_grad, momentum=0.0):
|
|
def set_prev_convex_comb_grad(self, convex_comb_grad, momentum=0.0):
|
|
if self.prev_convex_comb_grad is None:
|
|
if self.prev_convex_comb_grad is None:
|
|
- print("initing prev_convex_comb_grad")
|
|
|
|
|
|
+ print("Initing prev_convex_comb_grad")
|
|
self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
|
|
self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
|
|
else:
|
|
else:
|
|
- print("setting prev_convex_comb_grad")
|
|
|
|
|
|
+ print("Setting prev_convex_comb_grad")
|
|
self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
|
|
self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
|
|
|
|
|
|
def set_model(self, old_m, new_m, momentum=0.0):
|
|
def set_model(self, old_m, new_m, momentum=0.0):
|