"""Helpers for self-supervised federated experiments: augmentation lookup,
model-distance and normalization utilities, and KDE-based mutual-information
probes."""
import numpy as np
import torch
from sklearn.neighbors import KernelDensity
from tqdm import tqdm

import transform
from model import SimSiam, MoCo


def get_transformation(model):
    """Return the augmentation class that matches a model class.

    Args:
        model: The model class to look up. It is compared by identity against
            ``SimSiam`` and ``MoCo``; anything else (including an instance)
            falls back to SimCLR.

    Returns:
        ``transform.SimSiamTransform`` for ``SimSiam``,
        ``transform.MoCoTransform`` for ``MoCo``, otherwise
        ``transform.SimCLRTransform``.
    """
    if model is SimSiam:
        return transform.SimSiamTransform
    if model is MoCo:
        return transform.MoCoTransform
    return transform.SimCLRTransform


def calculate_model_distance(m1, m2):
    """Average L2 distance between the conv weights of two models.

    For every parameter of ``m1`` whose name contains both ``'conv'`` and
    ``'weight'``, the Euclidean norm of the element-wise difference to the
    same-named entry of ``m2.state_dict()`` is computed (both tensors are
    flattened first). The distances are then averaged.

    Args:
        m1: Reference ``torch.nn.Module``; its parameter names select the
            tensors to compare.
        m2: Module whose state dict must contain every selected name.

    Returns:
        A 0-dim tensor with the mean distance, or ``0.0`` when ``m1`` has no
        matching parameters (the mean over zero layers is undefined).
    """
    # state_dict() tensors are already detached and are only read here,
    # so no clone is needed.
    d1, d2 = m1.state_dict(), m2.state_dict()
    distances = [
        torch.dist(d1[name].reshape(-1), d2[name].reshape(-1), p=2)
        for name, _ in m1.named_parameters()
        if 'conv' in name and 'weight' in name
    ]
    if not distances:
        return 0.0
    return sum(distances) / len(distances)


def normalize(arr):
    """Min-max scale a sequence of numbers into ``[0, 1]``.

    Args:
        arr: A sized sequence of numbers (list, tuple, 1-D array/tensor).

    Returns:
        A new list of ``(x - min) / (max - min)`` values. If ``arr`` is empty,
        or all of its values are equal (scaling undefined), ``arr`` itself is
        returned unchanged.
    """
    # len() instead of truthiness so 1-D numpy arrays / tensors still work.
    if len(arr) == 0:
        return arr
    lo, hi = min(arr), max(arr)
    span = hi - lo
    if span == 0:
        return arr
    return [(x - lo) / span for x in arr]


# ---- Added content ----

def hook_fn(module, input, output):
    """Store ``output`` on ``module.output``.

    The ``(module, input, output)`` signature matches a PyTorch forward hook,
    so this is presumably registered via ``register_forward_hook`` to capture
    intermediate activations.
    """
    module.output = output


def _as_2d(a):
    """Return ``a`` as a numpy array, turning a 1-D input into one column."""
    a = np.asarray(a)
    return a.reshape(-1, 1) if a.ndim == 1 else a


def kernel_mi(x, y, band=0.5):
    """Estimate the mutual information I(X; Y) with Gaussian KDEs.

    Gaussian kernel density estimates (shared bandwidth ``band``) are fitted
    to ``x``, to ``y`` and to the joint samples ``[x, y]``. The result is the
    resubstitution estimate

        mean_i [ log p(x_i, y_i) - log p(x_i) - log p(y_i) ]

    evaluated on the same samples used for fitting. Rows of ``x`` and ``y``
    are treated as i.i.d. paired samples. ``score_samples`` returns natural
    log densities, so the value is in nats. It is not guaranteed to be
    non-negative because of estimation error.

    Args:
        x: Array-like of shape ``(n_samples,)`` or ``(n_samples, d_x)``.
        y: Array-like of shape ``(n_samples,)`` or ``(n_samples, d_y)``.
        band: Bandwidth passed to ``KernelDensity``.

    Returns:
        The estimated mutual information (a NumPy float).

    Raises:
        ValueError: If ``x`` and ``y`` have different numbers of samples.
    """
    x = _as_2d(x)
    y = _as_2d(y)
    if x.shape[0] != y.shape[0]:
        raise ValueError(
            "x and y must have the same number of samples, got {} and {}".format(
                x.shape[0], y.shape[0]))

    def _log_density(samples):
        # Fit a Gaussian KDE and evaluate log p at the fitted samples.
        kde = KernelDensity(kernel='gaussian', bandwidth=band)
        kde.fit(samples)
        return kde.score_samples(samples)

    xy = np.column_stack([x, y])
    mi = (_log_density(xy) - _log_density(x) - _log_density(y)).mean()
    return mi


def information(model, train_loader, device, cid, epoch=4):
    """Per-batch KDE mutual information between features and a layer's weights.

    According to the original author's note, ``model`` is the online encoder
    (``model.online_encoder``). Each parameter whose name contains
    ``'fc.net.3.weight'`` is transposed, and its first ``batch_size`` rows are
    paired with the batch's features for ``kernel_mi``.

    NOTE(review): rows of a transposed weight matrix are not samples from the
    data distribution. Confirm this is the intended MI pairing.

    Args:
        model: ``torch.nn.Module`` whose forward output can be moved to CPU
            and converted with ``.numpy()``. The original note says features
            are ``(batch, 2048)``; TODO confirm.
        train_loader: Iterable yielding ``((x1, x2), label)``. Only ``x1`` is
            used.
        device: Device that ``x1`` is moved to before the forward pass.
        cid: Client id, used only in the log message.
        epoch: Epoch label used only in the log message. It defaults to 4, the
            value that was previously hard-coded.

    Returns:
        List of MI values, one per (batch, matching weight) pair, in iteration
        order. The list is empty if the loader is empty or no parameter matches.

    Side effects:
        Prints one line per MI value. The model is switched to eval mode for
        the duration of the call, and every submodule's previous ``training``
        flag is restored afterwards, even on error.
    """
    prev_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    all_mi = []
    try:
        # No gradients or backprop are needed for this evaluation.
        with torch.no_grad():
            # Weights do not change during evaluation, so gather them once
            # instead of re-reading state_dict() for every batch.
            weights = [
                param.detach().t().cpu().numpy()
                for name, param in model.named_parameters()
                if 'fc.net.3.weight' in name
            ]
            for batch_index, ((batched_x1, _), _) in enumerate(train_loader):
                data = batched_x1.to(device)
                features = model(data).cpu().numpy()
                batch_size = data.size(0)
                for weight in weights:
                    # Truncate to batch_size rows so both inputs have the same
                    # sample count. If the weight has fewer rows than the batch,
                    # kernel_mi raises ValueError.
                    mi = kernel_mi(features, weight[:batch_size, :])
                    print("client={}, epoch={} batch={}, the mi = {}".format(
                        cid, epoch, batch_index, mi))
                    all_mi.append(mi)
    finally:
        # Restore each submodule's own mode so frozen (eval) submodules stay
        # frozen and the caller's training mode is not silently changed.
        for module, was_training in prev_modes:
            module.training = was_training
    return all_mi