Browse Source

feat: add momentum update

shellmiao 1 year ago
parent
commit
5f32b20b32
1 changed files with 5 additions and 1 deletions
  1. 5 1
      applications/fedssl/client_with_pgfed.py

+ 5 - 1
applications/fedssl/client_with_pgfed.py

@@ -140,4 +140,8 @@ class FedSSLWithPgFedClient(FedSSLClient):
         if self.prev_convex_comb_grad is None:
             self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
         else:
-            self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
+            self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
+
+    def set_model(self, old_m, new_m, momentum=0.0):
+        for p_old, p_new in zip(old_m.parameters(), new_m.parameters()):
+            p_old.data = (1 - momentum) * p_new.data.clone() + momentum * p_old.data.clone()