Explorar o código

feat: add momentum update

shellmiao hai 1 ano
pai
achega
5f32b20b32
Modificáronse 1 ficheiros con 5 adicións e 1 borrados
  1. 5 1
      applications/fedssl/client_with_pgfed.py

+ 5 - 1
applications/fedssl/client_with_pgfed.py

@@ -140,4 +140,8 @@ class FedSSLWithPgFedClient(FedSSLClient):
         if self.prev_convex_comb_grad is None:
             self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
         else:
-            self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
+            self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
+
+    def set_model(self, old_m, new_m, momentum=0.0):
+        for p_old, p_new in zip(old_m.parameters(), new_m.parameters()):
+            p_old.data = (1 - momentum) * p_new.data.clone() + momentum * p_old.data.clone()