소스 검색

feat: add momentum update

shellmiao 1 년 전
부모
커밋
5f32b20b32
1개의 변경된 파일5개의 추가작업 그리고 1개의 파일을 삭제
  1. 5 1
      applications/fedssl/client_with_pgfed.py

+ 5 - 1
applications/fedssl/client_with_pgfed.py

@@ -140,4 +140,8 @@ class FedSSLWithPgFedClient(FedSSLClient):
         if self.prev_convex_comb_grad is None:
             self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
         else:
-            self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
+            self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
+
+    def set_model(self, old_m, new_m, momentum=0.0):
+        for p_old, p_new in zip(old_m.parameters(), new_m.parameters()):
+            p_old.data = (1 - momentum) * p_new.data.clone() + momentum * p_old.data.clone()