|
@@ -140,4 +140,8 @@ class FedSSLWithPgFedClient(FedSSLClient):
|
|
|
if self.prev_convex_comb_grad is None:
|
|
|
self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
|
|
|
else:
|
|
|
- self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
|
|
|
+ self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
|
|
|
+
|
|
|
+ def set_model(self, old_m, new_m, momentum=0.0):
|
|
|
+ for p_old, p_new in zip(old_m.parameters(), new_m.parameters()):
|
|
|
+ p_old.data = (1 - momentum) * p_new.data.clone() + momentum * p_old.data.clone()
|