Переглянути джерело

feat: add privacy compute function

shellmiao 10 місяців тому
батько
коміт
5136c50f5c

+ 62 - 0
applications/fedssl/client_with_pgfed.py

@@ -15,6 +15,7 @@ import utils
 from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA, DYNAMIC_DAPU, DYNAMIC_EMA_ONLINE, SELECTIVE_EMA
 from easyfl.client.base import BaseClient
 from easyfl.distributed.distributed import CPU
+from sklearn.neighbors import KernelDensity
 
 from client import FedSSLClient
 
@@ -22,6 +23,59 @@ logger = logging.getLogger(__name__)
 
 L2 = "l2"
 
+def compute_all_batch_privacy(model, data_loader, device):
+    model.eval()  # 切换到评估模式
+    batch_privacy_scores = []
+    
+    for batch, _ in data_loader:
+        batch = batch.to(device)
+        with torch.no_grad():
+            output_features = model(batch).cpu().numpy()
+
+        last_weight = list(model.parameters())[-1].detach().cpu().numpy().reshape(-1, 1)
+        
+        kde_output = KernelDensity(kernel='gaussian').fit(output_features)
+        kde_weight = KernelDensity(kernel='gaussian').fit(last_weight)
+        
+        combined_features = np.hstack((output_features, last_weight[:output_features.shape[0], :]))  # 确保形状匹配
+        kde_combined = KernelDensity(kernel='gaussian').fit(combined_features)
+        
+        log_p_x = kde_output.score_samples(output_features)
+        log_p_y = kde_weight.score_samples(last_weight[:output_features.shape[0], :])
+        log_p_xy = kde_combined.score_samples(combined_features)
+        
+        privacy = np.mean(log_p_xy - log_p_x - log_p_y)
+        batch_privacy_scores.append(privacy)
+    
+    # 根据互信息分数进行排序
+    sorted_batches = np.argsort(batch_privacy_scores)
+    
+    return sorted_batches, batch_privacy_scores
+
+def compute_batch_privacy(model, batch, device):
+    model.eval()
+    batch = batch.to(device)
+    
+    with torch.no_grad():
+        output_features = model(batch).cpu().numpy()
+
+    last_weight = list(model.parameters())[-1].detach().cpu().numpy().reshape(-1, 1)
+    
+    kde_output = KernelDensity(kernel='gaussian').fit(output_features)
+    kde_weight = KernelDensity(kernel='gaussian').fit(last_weight)
+    
+    combined_features = np.hstack((output_features, last_weight[:output_features.shape[0], :]))
+    kde_combined = KernelDensity(kernel='gaussian').fit(combined_features)
+    
+    log_p_x = kde_output.score_samples(output_features)
+    log_p_y = kde_weight.score_samples(last_weight[:output_features.shape[0], :])
+    log_p_xy = kde_combined.score_samples(combined_features)
+    
+    privacy = np.mean(log_p_xy - log_p_x - log_p_y)
+    
+    return privacy
+
+
 def model_dot_product(w1, w2, requires_grad=True):
     """ Return the sum of squared difference between two models. """
     dot_product = 0.0
@@ -56,6 +110,9 @@ class FedSSLWithPgFedClient(FedSSLClient):
             self.model.reset_key_encoder()
         self.train_loss = []
         self.model.to(device)
+
+        batch_privacy_scores = []
+
         old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
         for i in range(conf.local_epoch):
             batch_loss = []
@@ -63,6 +120,9 @@ class FedSSLWithPgFedClient(FedSSLClient):
                 x1, x2 = batched_x1.to(device), batched_x2.to(device)
                 optimizer.zero_grad()
 
+                privacy_score = compute_batch_privacy(self.model, x1, device)
+                batch_privacy_scores.append(privacy_score)
+
                 if conf.model in [model.MoCo, model.MoCoV2]:
                     loss = self.model(x1, x2, device)
                 elif conf.model == model.SimCLR:
@@ -88,6 +148,8 @@ class FedSSLWithPgFedClient(FedSSLClient):
             current_epoch_loss = sum(batch_loss) / len(batch_loss)
             self.train_loss.append(float(current_epoch_loss))
 
+        print(f"Batch privacy scores during training: {batch_privacy_scores}")
+
         self.loss_minus = 0.0
         test_num = 0
         optimizer.zero_grad()

+ 1 - 0
applications/fedssl/main.py

@@ -164,6 +164,7 @@ def run():
                                                                args.label_ratio)
         easyfl.register_dataset(train_data, test_data)
 
+    
     model = get_model(args.model, args.encoder_network, args.predictor_network)
     easyfl.register_model(model)
     if args.use_pgfed:

+ 1 - 0
python applicationsfedsslmain.py --.txt

@@ -0,0 +1 @@
+python applications/fedssl/main.py --task_id fedema233 --model simclr --aggregate_encoder online --update_encoder dynamic_ema_online --update_predictor dynamic_dapu --auto_scaler y --use_pgfed True --rounds 10 --gpu 1 --batch_size 16 --auto_scaler_target 0.7 2>&1 | tee log/${task_id}.log