
feat: add pgfed

shellmiao 1 year ago
parent
commit
6c18662332

+ 142 - 0
applications/fedssl/client_with_pgfed.py

@@ -0,0 +1,142 @@
+import copy
+import gc
+import logging
+import time
+from collections import Counter
+
+import numpy as np
+import torch
+import torch._utils
+import torch.nn as nn
+import torch.nn.functional as F
+
+import model
+import utils
+from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA, DYNAMIC_DAPU, DYNAMIC_EMA_ONLINE, SELECTIVE_EMA
+from easyfl.client.base import BaseClient
+from easyfl.distributed.distributed import CPU
+
+from client import FedSSLClient
+
+logger = logging.getLogger(__name__)
+
+L2 = "l2"
+
+def model_dot_product(w1, w2, requires_grad=True):
+    """ Return the sum of squared difference between two models. """
+    dot_product = 0.0
+    for p1, p2 in zip(w1.parameters(), w2.parameters()):
+        if requires_grad:
+            dot_product += torch.sum(p1 * p2)
+        else:
+            dot_product += torch.sum(p1.data * p2.data)
+    return dot_product
+
+class FedSSLWithPgFedClient(FedSSLClient):
+    def __init__(self, id, cid, conf, train_data, test_data, device, sleep_time=0):
+        # `id` is the integer index assigned by the coordinator (it indexes the
+        # server's alpha_mat); assumes the FedSSLClient chain forwards it to BaseClient.
+        super(FedSSLWithPgFedClient, self).__init__(id, cid, conf, train_data, test_data, device, sleep_time)
+        self._local_model = None
+        self.DAPU_predictor = LOCAL
+        self.encoder_distance = 1
+        self.encoder_distances = []
+        self.previous_trained_round = -1
+        self.weight_scaler = None
+
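+        # PGFed-specific state: the latest averaged local gradient, the
+        # personalized aggregation weights a_i, and the auxiliary terms
+        # (peer loss_minus values, mean/convex-combination gradients)
+        # pushed by the server each round.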
+        self.latest_grad = copy.deepcopy(self.model)
+        self.lambdaa = 1.0  # PGFed learning rate for a_i (pFedMe's regularization weight)
+        self.prev_loss_minuses = {}
+        self.prev_mean_grad = None
+        self.prev_convex_comb_grad = None
+        self.a_i = None
+
+    def train(self, conf, device=CPU):
+        start_time = time.time()
+        loss_fn, optimizer = self.pretrain_setup(conf, device)
+        if conf.model in [model.MoCo, model.MoCoV2]:
+            self.model.reset_key_encoder()
+        self.train_loss = []
+        self.model.to(device)
+        old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
+        for i in range(conf.local_epoch):
+            batch_loss = []
+            for (batched_x1, batched_x2), _ in self.train_loader:
+                x1, x2 = batched_x1.to(device), batched_x2.to(device)
+                optimizer.zero_grad()
+
+                if conf.model in [model.MoCo, model.MoCoV2]:
+                    loss = self.model(x1, x2, device)
+                elif conf.model == model.SimCLR:
+                    images = torch.cat((x1, x2), dim=0)
+                    features = self.model(images)
+                    logits, labels = self.info_nce_loss(features)
+                    loss = loss_fn(logits, labels)
+                else:
+                    loss = self.model(x1, x2)
+
+                loss.backward()
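+                # PGFed: fold the server-built convex combination of peer
+                # gradients into the local gradient, then update a_i using
+                # the dot product with the previous mean gradient.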
+                if self.prev_convex_comb_grad is not None:
+                    for p_m, p_prev_conv in zip(self.model.parameters(), self.prev_convex_comb_grad.parameters()):
+                        p_m.grad.data += p_prev_conv.data
+                    dot_prod = model_dot_product(self.model, self.prev_mean_grad, requires_grad=False)
+                    self.update_a_i(dot_prod)
+                optimizer.step()
+                batch_loss.append(loss.item())
+
+                if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor] and conf.momentum_update:
+                    self.model.update_moving_average()
+
+            current_epoch_loss = sum(batch_loss) / len(batch_loss)
+            self.train_loss.append(float(current_epoch_loss))
+
+        # PGFed bookkeeping: estimate loss_minus = E[L_i(theta)] - <latest_grad, theta>
+        # and store the batch-averaged local gradient in latest_grad.
+        self.loss_minus = 0.0
+        sample_num = 0
+        optimizer.zero_grad()
+        for (batched_x1, batched_x2), _ in self.train_loader:
+            x1, x2 = batched_x1.to(device), batched_x2.to(device)
+            sample_num += x1.shape[0]
+            if conf.model in [model.MoCo, model.MoCoV2]:
+                loss = self.model(x1, x2, device)
+            elif conf.model == model.SimCLR:
+                images = torch.cat((x1, x2), dim=0)
+                features = self.model(images)
+                logits, labels = self.info_nce_loss(features)
+                loss = loss_fn(logits, labels)
+            else:
+                loss = self.model(x1, x2)
+            self.loss_minus += (loss * x1.shape[0]).item()
+            loss.backward()
+
+        self.loss_minus /= sample_num
+        for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
+            p_l.data = p.grad.data.clone() / len(self.train_loader)
+        self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
+
+        self.train_time = time.time() - start_time
+
+        # store trained model locally
+        self._local_model = copy.deepcopy(self.model).cpu()
+        self.previous_trained_round = conf.round_id
+        if conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA] or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA]:
+            new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
+            self.encoder_distance = self._calculate_divergence(old_model, new_model)
+            self.encoder_distances.append(self.encoder_distance.item())
+            self.DAPU_predictor = self._DAPU_predictor(self.encoder_distance)
+            if self.conf.auto_scaler == 'y' and self.conf.random_selection:
+                self._calculate_weight_scaler()
+            if (conf.round_id + 1) % 100 == 0:
+                logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")
+
+    def update_a_i(self, dot_prod):
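+        """Projected gradient step on the personalized weights a_i (clamped at 0)."""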
+        for clt_j, mu_loss_minus in self.prev_loss_minuses.items():
+            self.a_i[clt_j] -= self.lambdaa * (mu_loss_minus + dot_prod)
+            self.a_i[clt_j] = max(self.a_i[clt_j], 0.0)
+
+    def set_prev_mean_grad(self, mean_grad):
+        if self.prev_mean_grad is None:
+            self.prev_mean_grad = copy.deepcopy(mean_grad)
+        else:
+            self.set_model(self.prev_mean_grad, mean_grad)
+
+    def set_prev_convex_comb_grad(self, convex_comb_grad, momentum=0.0):
+        if self.prev_convex_comb_grad is None:
+            self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
+        else:
+            self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
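
Note: PGFed personalizes each client's global objective: client i minimizes
L_i(theta) + mu * sum_j a_ij * (L_j(theta_j) + <grad L_j(theta_j), theta - theta_j>),
a first-order approximation of the other clients' risks. Its gradient only adds
a convex combination of peer gradients, which is what the train() loop above
folds in after loss.backward(). A self-contained sketch of that correction on a
toy model (names below are illustrative, not from this repo):

    import torch
    import torch.nn as nn

    def apply_pgfed_correction(local_model, convex_comb_grad):
        # After backward(), add mu * sum_j a_ij * grad_j (stored as the
        # parameters of convex_comb_grad) to every local gradient.
        for p, g in zip(local_model.parameters(), convex_comb_grad.parameters()):
            if p.grad is not None:
                p.grad.data += g.data

    toy = nn.Linear(4, 2)
    peer_grads = nn.Linear(4, 2)  # stands in for the server-built combination
    toy(torch.randn(8, 4)).pow(2).mean().backward()
    apply_pgfed_correction(toy, peer_grads)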

+ 119 - 0
applications/fedssl/server_with_pgfed.py

@@ -0,0 +1,119 @@
+import copy
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+from torchvision import datasets
+
+import time
+import model
+import utils
+from communication import TARGET
+from easyfl.datasets.data import CIFAR100
+from easyfl.distributed import reduce_models
+from easyfl.distributed.distributed import CPU
+from easyfl.server import strategies
+from easyfl.server.base import BaseServer, MODEL, DATA_SIZE
+from easyfl.tracking import metric
+from easyfl.protocol import codec
+from knn_monitor import knn_monitor
+
+from server import FedSSLServer
+
+logger = logging.getLogger(__name__)
+
+
+class FedSSLWithPgFedServer(FedSSLServer):
+    def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
+        super(FedSSLWithPgFedServer, self).__init__(conf, test_data, val_data, is_remote, local_port)
+        self.train_loader = None
+        self.test_loader = None
+
+        self.mu = 0  # PGFed auxiliary-risk weight; 0 disables the auxiliary terms
+        self.momentum = 0.0  # momentum for updating prev_convex_comb_grad
+        self.alpha_mat = (torch.ones((len(self._clients), len(self._clients))) / self.conf.server.clients_per_round).to(self.conf.device)
+        self.uploaded_grads = {}
+        self.loss_minuses = {}
+        self.mean_grad = None
+        self.convex_comb_grad = None
+
+    def train(self):
+        """Training process of federated learning."""
+        self.print_("--- start training ---")
+        self.print_(f"Join clients / total clients: {self.conf.server.clients_per_round} / {len(self._clients)}")
+
+        self.selection(self._clients, self.conf.server.clients_per_round)
+        self.grouping_for_distributed()
+        self.compression()
+
+        begin_train_time = time.time()
+        self.send_param()
+        self.distribution_to_train()
+        self.aggregation()
+
+        train_time = time.time() - begin_train_time
+        self.print_("Server train time: {}".format(train_time))
+        self.track(metric.TRAIN_TIME, train_time)
+
+    def send_param(self):
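+        """Push PGFed state to the selected clients: each client's a_i row of
+        alpha_mat, the previous mean gradient, a client-specific convex
+        combination of uploaded gradients, and the peers' loss_minus terms."""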
+        for client in self.grouped_clients:
+            client.a_i = self.alpha_mat[client.id]
+        if len(self.uploaded_grads) == 0:
+            return
+        self.convex_comb_grad = copy.deepcopy(list(self.uploaded_grads.values())[0])
+        for client in self.grouped_clients:
+            client.set_prev_mean_grad(self.mean_grad)
+            mu_a_i = self.alpha_mat[client.id] * self.mu
+            grads, weights = [], []
+            for clt_idx, grad in self.uploaded_grads.items():
+                weights.append(mu_a_i[clt_idx])
+                grads.append(grad)
+            self.model_weighted_sum(self.convex_comb_grad, grads, weights)
+            client.set_prev_convex_comb_grad(self.convex_comb_grad, momentum=self.momentum)
+            client.prev_loss_minuses = copy.deepcopy(self.loss_minuses)
+
+    def distribution_to_train_locally(self):
+        """Conduct training sequentially for selected clients in the group."""
+        uploaded_models = {}
+        uploaded_weights = {}
+        uploaded_metrics = []
+        for client in self.grouped_clients:
+            # Update client config before training
+            self.conf.client.task_id = self.conf.task_id
+            self.conf.client.round_id = self._current_round
+
+            uploaded_request = client.run_train(self._compressed_model, self.conf.client)
+            uploaded_content = uploaded_request.content
+
+            client_model = self.decompression(codec.unmarshal(uploaded_content.data))
+            uploaded_models[client.cid] = client_model
+            uploaded_weights[client.cid] = uploaded_content.data_size
+            self.receive_param(client)
+
+            uploaded_metrics.append(metric.ClientMetric.from_proto(uploaded_content.metric))
+
+        self.set_client_uploads_train(uploaded_models, uploaded_weights, uploaded_metrics)
+
+    def receive_param(self, client):
+        # Rebuild PGFed state from all selected clients; the `client`
+        # argument is unused but kept for the per-client call site above.
+        self.uploaded_ids = []
+        self.uploaded_grads = {}
+        self.loss_minuses = {}
+        for clt in self.selected_clients:
+            self.uploaded_ids.append(clt.id)
+            self.alpha_mat[clt.id] = clt.a_i
+            self.uploaded_grads[clt.id] = clt.latest_grad
+            self.loss_minuses[clt.id] = clt.loss_minus * self.mu
+
+    def get_mean_grad(self):
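+        """Average the uploaded gradients with uniform weight mu / clients_per_round."""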
+        w = self.mu / self.conf.server.clients_per_round
+        weights = [w for _ in range(self.conf.server.clients_per_round)]
+        self.mean_grad = copy.deepcopy(list(self.uploaded_grads.values())[0])
+        self.model_weighted_sum(self.mean_grad, list(self.uploaded_grads.values()), weights)
+
+    def model_weighted_sum(self, target, models, weights):
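+        """In-place weighted sum: target <- sum_i weights[i] * models[i], parameter-wise."""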
+        for p_m in target.parameters():
+            p_m.data.zero_()
+        for w, m_i in zip(weights, models):
+            for p_m, p_i in zip(target.parameters(), m_i.parameters()):
+                p_m.data += p_i.data.clone() * w
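
Note: get_mean_grad above averages the uploaded gradients with uniform weight
mu / K via model_weighted_sum, which overwrites its first argument parameter-wise.
A self-contained toy check of that in-place weighted sum (illustrative, not part
of this commit):

    import copy
    import torch.nn as nn

    def weighted_sum_(target, models, weights):
        # target <- sum_i weights[i] * models[i], parameter-wise, in place
        for p in target.parameters():
            p.data.zero_()
        for w, m in zip(weights, models):
            for p_t, p_m in zip(target.parameters(), m.parameters()):
                p_t.data += p_m.data.clone() * w

    grads = [nn.Linear(3, 1) for _ in range(4)]
    mu, k = 0.5, len(grads)
    mean_grad = copy.deepcopy(grads[0])
    weighted_sum_(mean_grad, grads, [mu / k] * k)  # mirrors get_mean_grad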

+ 2 - 0
easyfl/client/base.py

@@ -74,6 +74,7 @@ class BaseClient(object):
         >>>         pass
     """
     def __init__(self,
+                 id,
                  cid,
                  conf,
                  train_data,
@@ -84,6 +85,7 @@ class BaseClient(object):
                  local_port=23000,
                  server_addr="localhost:22999",
                  tracker_addr="localhost:12666"):
+        self.id = id
         self.cid = cid
         self.conf = conf
         self.train_data = train_data
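
Note: every client now takes an integer id (the index into the server's
alpha_mat) ahead of its string cid. A hypothetical instantiation, with
placeholder conf/data/device values:

    client = FedSSLWithPgFedClient(
        0,              # id: integer index into the server's alpha_mat
        "client_0000",  # cid: hypothetical client identifier
        conf, train_data, None, "cpu",
    )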

+ 2 - 1
easyfl/coordinator.py

@@ -162,7 +162,8 @@ class Coordinator(object):
         if self.conf.test_mode == TEST_IN_SERVER:
             client_test_data = None
 
-        self.clients = [self._client_class(u,
+        self.clients = [self._client_class(i,
+                                           u,
                                            self.conf.client,
                                            self.train_data,
                                            client_test_data,
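
Note: the new first argument i is presumably the position of user u in an
enumeration of the user list (the enclosing loop is outside this hunk). A toy
analogue of that (index, user) pairing:

    users = ["f0001", "f0002", "f0003"]
    clients = [(i, u) for i, u in enumerate(users)]  # (index, user-id) pairs
    assert clients[1] == (1, "f0002")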

+ 1 - 0
easyfl/server/base.py

@@ -290,6 +290,7 @@ class BaseServer(object):
             }
             return test_results
 
+    # Client selection
     def selection(self, clients, clients_per_round):
         """Select a fraction of total clients for training.
         Two selection strategies are implemented: 1. random selection; 2. select the first K clients.
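
A minimal sketch of the two strategies the docstring names (assumed behavior,
not the actual easyfl implementation):

    import random

    def select(clients, clients_per_round, random_selection=True):
        if random_selection:
            return random.sample(clients, clients_per_round)
        return clients[:clients_per_round]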