Explorar el Código

feat: add kde

Shellmiao hace 5 meses
padre
commit
e3f8a68ac9
Se han modificado 2 ficheros con 77 adiciones y 6 borrados
  1. 18 4
      applications/fedssl/client.py
  2. 59 2
      applications/fedssl/utils.py

+ 18 - 4
applications/fedssl/client.py

@@ -16,7 +16,7 @@ from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA,
 from easyfl.client.base import BaseClient
 from easyfl.distributed.distributed import CPU
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger(__name__)  # 创建日志记录器
 
 L2 = "l2"
 
@@ -30,6 +30,8 @@ class FedSSLClient(BaseClient):
         self.encoder_distances = []
         self.previous_trained_round = -1
         self.weight_scaler = None
+        self.batch_id_privacy = []
+        self.batch_id = []
 
     def decompression(self):
         if self.model is None:
@@ -83,7 +85,7 @@ class FedSSLClient(BaseClient):
                 weight = min(1, self.weight_scaler * weight)
                 weight = 1 - weight
                 self.compressed_model = self.compressed_model.cpu()
-                online_encoder = self.compressed_model.online_encoder
+                online_encoder = self.compressed_model.online_encoder  # 聚合的在线编码器
                 target_encoder = self._local_model.target_encoder
                 ema_updater = model.EMA(weight)
                 model.update_moving_average(ema_updater, online_encoder, self._local_model.online_encoder)
@@ -198,11 +200,14 @@ class FedSSLClient(BaseClient):
         if conf.model in [model.MoCo, model.MoCoV2]:
             self.model.reset_key_encoder()
         self.train_loss = []
+
         self.model.to(device)
         old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
         for i in range(conf.local_epoch):
             batch_loss = []
-            for (batched_x1, batched_x2), _ in self.train_loader:
+            for batch_index, ((batched_x1, batched_x2), _) in enumerate(self.train_loader):
+                # if conf.round_id > 0 and batch_index not in self.batch_id:
+                #     continue
                 x1, x2 = batched_x1.to(device), batched_x2.to(device)
                 optimizer.zero_grad()
 
@@ -222,9 +227,17 @@ class FedSSLClient(BaseClient):
 
                 if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor] and conf.momentum_update:
                     self.model.update_moving_average()
-
             current_epoch_loss = sum(batch_loss) / len(batch_loss)
             self.train_loss.append(float(current_epoch_loss))
+            if conf.round_id == 0 and i == 4:
+                all_mi = utils.information(self.model.online_encoder, self.train_loader, device, self.cid)
+                # 使用enumerate获取元素和对应的序号,并使用sorted进行排序
+                sorted_ls_with_indices = sorted(enumerate(all_mi), key=lambda x: x[1])
+                # 提取排序后的序号
+                sorted_indices = [index for index, element in sorted_ls_with_indices]
+                self.batch_id_privacy = sorted_indices  # ---------------batch_privacy--------
+                self.batch_id = self.batch_id_privacy[:int(len(self.batch_id_privacy) * 0.9)]
+
         self.train_time = time.time() - start_time
 
         # store trained model locally
@@ -252,6 +265,7 @@ class FedSSLClient(BaseClient):
         old_dict = old_model.state_dict()
         new_dict = new_model.state_dict()
         for name, param in old_model.named_parameters():
+            print(name)
             if 'conv' in name and 'weight' in name:
                 total_distance += self._calculate_distance(old_dict[name].detach().clone().view(1, -1),
                                                            new_dict[name].detach().clone().view(1, -1),

+ 59 - 2
applications/fedssl/utils.py

@@ -1,8 +1,9 @@
 import torch
-
 import transform
 from model import SimSiam, MoCo
-
+from sklearn.neighbors import KernelDensity
+from tqdm import tqdm
+import numpy as np
 
 def get_transformation(model):
     if model == SimSiam:
@@ -31,3 +32,59 @@ def normalize(arr):
     if diff == 0:
         return arr
     return [(x - minn) / diff for x in arr]
+
+""" 添加的内容 """
+# 自定义钩子函数
+def hook_fn(module, input, output):
+    module.output = output
+
+def kernel_mi(x, y, band=0.5):
+    # 使用KDE估计变量x的概率密度函数
+    kde_x = KernelDensity(kernel='gaussian', bandwidth=band)
+    kde_x.fit(x)
+
+    # 使用KDE估计变量y的概率密度函数
+    kde_y = KernelDensity(kernel='gaussian', bandwidth=band)
+    kde_y.fit(y)
+
+    # 使用估计的概率密度函数计算联合概率密度函数
+    xy = np.column_stack([x, y])
+    kde_xy = KernelDensity(kernel='gaussian', bandwidth=band)
+    kde_xy.fit(xy)
+
+    # 计算互信息
+    log_p_xy = kde_xy.score_samples(xy)
+    log_p_x = kde_x.score_samples(x)
+    log_p_y = kde_y.score_samples(y)
+    mi = (log_p_xy - log_p_x - log_p_y).mean()  # 假设样本与样本之间是独立同分布的
+
+    return mi
+
+# the model is model.online_encoder
+def information(model, train_loader, device, cid):
+    # 评估模式
+    model.eval()
+    # all batch 互信息
+    all_mi = []
+    # 不会进行计算梯度,也不会进行反向传播
+    with torch.no_grad():
+        for batch_index, ((batched_x1, batched_x2), _) in enumerate(train_loader):
+            # 部署到device上
+            data = batched_x1.to(device)
+            # the feature batch*2048
+            feature = model(data)
+            batch_feature = feature.detach().clone()
+            state_dict = model.state_dict()
+            # 权重的处理
+            for name, param in model.named_parameters():
+                if 'fc.net.3.weight' in name:
+                    # feature_weight = state_dict[name].detach().clone()
+                    # feature_weight = feature_weight.reshape(-1, batch_feature.size(1))
+                    feature_weight = state_dict[name].detach().clone()
+                    feature_weight = feature_weight.t()
+            feature_weight = feature_weight[:data.size(0), :]
+            mi = kernel_mi(batch_feature.cpu().numpy(), feature_weight.cpu().numpy())
+            print("client={}, epoch=4 batch={}, the mi = {}".format(cid, batch_index, mi))
+            all_mi.append(mi)
+    return all_mi
+