|
@@ -24,8 +24,6 @@ L2 = "l2"
|
|
|
|
|
|
def model_dot_product(w1, w2, requires_grad=True):
|
|
|
""" Return the sum of squared difference between two models. """
|
|
|
- print(w1)
|
|
|
- print(w2)
|
|
|
dot_product = 0.0
|
|
|
for p1, p2 in zip(w1.parameters(), w2.parameters()):
|
|
|
if requires_grad:
|
|
@@ -156,14 +154,18 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
|
|
|
def set_prev_mean_grad(self, mean_grad):
|
|
|
if self.prev_mean_grad is None:
|
|
|
+ print("initing prev_mean_grad")
|
|
|
self.prev_mean_grad = copy.deepcopy(mean_grad)
|
|
|
else:
|
|
|
+ print("setting prev_mean_grad")
|
|
|
self.set_model(self.prev_mean_grad, mean_grad)
|
|
|
|
|
|
def set_prev_convex_comb_grad(self, convex_comb_grad, momentum=0.0):
|
|
|
if self.prev_convex_comb_grad is None:
|
|
|
+ print("initing prev_convex_comb_grad")
|
|
|
self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
|
|
|
else:
|
|
|
+ print("setting prev_convex_comb_grad")
|
|
|
self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
|
|
|
|
|
|
def set_model(self, old_m, new_m, momentum=0.0):
|