|
@@ -16,7 +16,7 @@ from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA,
|
|
from easyfl.client.base import BaseClient
|
|
from easyfl.client.base import BaseClient
|
|
from easyfl.distributed.distributed import CPU
|
|
from easyfl.distributed.distributed import CPU
|
|
|
|
|
|
-logger = logging.getLogger(__name__)
|
|
|
|
|
|
+logger = logging.getLogger(__name__) # 创建日志记录器
|
|
|
|
|
|
L2 = "l2"
|
|
L2 = "l2"
|
|
|
|
|
|
@@ -30,6 +30,8 @@ class FedSSLClient(BaseClient):
|
|
self.encoder_distances = []
|
|
self.encoder_distances = []
|
|
self.previous_trained_round = -1
|
|
self.previous_trained_round = -1
|
|
self.weight_scaler = None
|
|
self.weight_scaler = None
|
|
|
|
+ self.batch_id_privacy = []
|
|
|
|
+ self.batch_id = []
|
|
|
|
|
|
def decompression(self):
|
|
def decompression(self):
|
|
if self.model is None:
|
|
if self.model is None:
|
|
@@ -83,7 +85,7 @@ class FedSSLClient(BaseClient):
|
|
weight = min(1, self.weight_scaler * weight)
|
|
weight = min(1, self.weight_scaler * weight)
|
|
weight = 1 - weight
|
|
weight = 1 - weight
|
|
self.compressed_model = self.compressed_model.cpu()
|
|
self.compressed_model = self.compressed_model.cpu()
|
|
- online_encoder = self.compressed_model.online_encoder
|
|
|
|
|
|
+ online_encoder = self.compressed_model.online_encoder # 聚合的在线编码器
|
|
target_encoder = self._local_model.target_encoder
|
|
target_encoder = self._local_model.target_encoder
|
|
ema_updater = model.EMA(weight)
|
|
ema_updater = model.EMA(weight)
|
|
model.update_moving_average(ema_updater, online_encoder, self._local_model.online_encoder)
|
|
model.update_moving_average(ema_updater, online_encoder, self._local_model.online_encoder)
|
|
@@ -198,11 +200,14 @@ class FedSSLClient(BaseClient):
|
|
if conf.model in [model.MoCo, model.MoCoV2]:
|
|
if conf.model in [model.MoCo, model.MoCoV2]:
|
|
self.model.reset_key_encoder()
|
|
self.model.reset_key_encoder()
|
|
self.train_loss = []
|
|
self.train_loss = []
|
|
|
|
+
|
|
self.model.to(device)
|
|
self.model.to(device)
|
|
old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
|
|
old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
|
|
for i in range(conf.local_epoch):
|
|
for i in range(conf.local_epoch):
|
|
batch_loss = []
|
|
batch_loss = []
|
|
- for (batched_x1, batched_x2), _ in self.train_loader:
|
|
|
|
|
|
+ for batch_index, ((batched_x1, batched_x2), _) in enumerate(self.train_loader):
|
|
|
|
+ # if conf.round_id > 0 and batch_index not in self.batch_id:
|
|
|
|
+ # continue
|
|
x1, x2 = batched_x1.to(device), batched_x2.to(device)
|
|
x1, x2 = batched_x1.to(device), batched_x2.to(device)
|
|
optimizer.zero_grad()
|
|
optimizer.zero_grad()
|
|
|
|
|
|
@@ -222,9 +227,17 @@ class FedSSLClient(BaseClient):
|
|
|
|
|
|
if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor] and conf.momentum_update:
|
|
if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor] and conf.momentum_update:
|
|
self.model.update_moving_average()
|
|
self.model.update_moving_average()
|
|
-
|
|
|
|
current_epoch_loss = sum(batch_loss) / len(batch_loss)
|
|
current_epoch_loss = sum(batch_loss) / len(batch_loss)
|
|
self.train_loss.append(float(current_epoch_loss))
|
|
self.train_loss.append(float(current_epoch_loss))
|
|
|
|
+ if conf.round_id == 0 and i == 4:
|
|
|
|
+ all_mi = utils.information(self.model.online_encoder, self.train_loader, device, self.cid)
|
|
|
|
+ # 使用enumerate获取元素和对应的序号,并使用sorted进行排序
|
|
|
|
+ sorted_ls_with_indices = sorted(enumerate(all_mi), key=lambda x: x[1])
|
|
|
|
+ # 提取排序后的序号
|
|
|
|
+ sorted_indices = [index for index, element in sorted_ls_with_indices]
|
|
|
|
+ self.batch_id_privacy = sorted_indices # ---------------batch_privacy--------
|
|
|
|
+ self.batch_id = self.batch_id_privacy[:int(len(self.batch_id_privacy) * 0.9)]
|
|
|
|
+
|
|
self.train_time = time.time() - start_time
|
|
self.train_time = time.time() - start_time
|
|
|
|
|
|
# store trained model locally
|
|
# store trained model locally
|
|
@@ -252,6 +265,7 @@ class FedSSLClient(BaseClient):
|
|
old_dict = old_model.state_dict()
|
|
old_dict = old_model.state_dict()
|
|
new_dict = new_model.state_dict()
|
|
new_dict = new_model.state_dict()
|
|
for name, param in old_model.named_parameters():
|
|
for name, param in old_model.named_parameters():
|
|
|
|
+ print(name)
|
|
if 'conv' in name and 'weight' in name:
|
|
if 'conv' in name and 'weight' in name:
|
|
total_distance += self._calculate_distance(old_dict[name].detach().clone().view(1, -1),
|
|
total_distance += self._calculate_distance(old_dict[name].detach().clone().view(1, -1),
|
|
new_dict[name].detach().clone().view(1, -1),
|
|
new_dict[name].detach().clone().view(1, -1),
|