123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- import torch
- import transform
- from model import SimSiam, MoCo
- from sklearn.neighbors import KernelDensity
- from tqdm import tqdm
- import numpy as np
def get_transformation(model):
    """Pick the augmentation transform that pairs with a given model.

    ``model`` is compared with ``==`` against the imported ``SimSiam`` and
    ``MoCo`` names, so it is presumably the model class itself rather than an
    instance -- confirm at the call sites.

    Returns:
        ``transform.SimSiamTransform`` for ``SimSiam``,
        ``transform.MoCoTransform`` for ``MoCo``, and
        ``transform.SimCLRTransform`` for anything else. The attribute is
        returned as-is (it is not instantiated here).
    """
    if model == SimSiam:
        return transform.SimSiamTransform
    if model == MoCo:
        return transform.MoCoTransform
    # Fallback for every other model.
    return transform.SimCLRTransform
def calculate_model_distance(m1, m2):
    """Return the average L2 distance between the conv weights of two models.

    Every parameter of ``m1`` whose name contains both ``'conv'`` and
    ``'weight'`` is compared with the entry of the same name in ``m2``'s
    state dict. Both tensors are flattened and the Euclidean norm of their
    difference is taken; the result is the mean over all selected layers.

    Args:
        m1: Reference model; its parameter names decide which layers are compared.
        m2: Model whose state dict contains the same layer names.

    Returns:
        A 0-d tensor holding the mean distance.

    Raises:
        ValueError: If ``m1`` has no parameter whose name contains both
            ``'conv'`` and ``'weight'`` (the mean would be a division by zero).
        KeyError: If ``m2``'s state dict lacks one of the selected names.
    """
    d1, d2 = m1.state_dict(), m2.state_dict()
    # flatten() -- unlike view() -- also works on non-contiguous tensors such as
    # channels_last conv weights. State-dict tensors are already detached, and
    # nothing here mutates them, so no defensive copy is needed.
    distances = [
        torch.dist(d1[name].flatten(), d2[name].flatten(), 2)
        for name, _ in m1.named_parameters()
        if 'conv' in name and 'weight' in name
    ]
    if not distances:
        raise ValueError(
            "calculate_model_distance: m1 has no parameter whose name "
            "contains both 'conv' and 'weight'")
    return sum(distances) / len(distances)
def normalize(arr):
    """Min-max scale numbers into the range [0, 1].

    Each value ``x`` becomes ``(x - min) / (max - min)``.

    Args:
        arr: Iterable of numeric values. It is read exactly once, so
            generators are accepted.

    Returns:
        A new list of scaled values. An empty input gives ``[]``. If every
        value is equal (zero range), the values are returned unscaled, as a
        new list, instead of dividing by zero.
    """
    # Materialize once: max() and min() would otherwise each consume a generator.
    values = list(arr)
    if not values:
        return []
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        return values
    return [(v - lo) / span for v in values]
- """ 添加的内容 """
# Custom hook function.
def hook_fn(module, input, output):
    """Stash ``output`` on ``module`` as the attribute ``module.output``.

    The ``(module, input, output)`` signature matches a PyTorch forward hook;
    where (or whether) it is registered is not visible here. ``input`` is
    ignored and the function returns None.
    """
    setattr(module, 'output', output)
def kernel_mi(x, y, band=0.5):
    """Estimate the mutual information between ``x`` and ``y`` with Gaussian KDEs.

    Three Gaussian kernel density estimates are fitted on the same samples:
    one on ``x``, one on ``y`` and one on the joint ``[x, y]`` (columns
    concatenated). The estimate is the mean over the samples of
    ``log p(x, y) - log p(x) - log p(y)``, evaluated at the fitted points
    themselves, which treats the rows as i.i.d. samples.

    Args:
        x: Array-like of shape ``(n_samples,)`` or ``(n_samples, n_features_x)``.
        y: Array-like of shape ``(n_samples,)`` or ``(n_samples, n_features_y)``.
        band: Bandwidth shared by all three Gaussian kernels.

    Returns:
        The estimate as a NumPy float.

    Raises:
        ValueError: If ``x`` and ``y`` have different numbers of samples.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # KernelDensity expects 2-D (n_samples, n_features) input; treat a 1-D
    # array as n_samples observations of a single feature.
    if x.ndim == 1:
        x = x.reshape(-1, 1)
    if y.ndim == 1:
        y = y.reshape(-1, 1)
    if x.shape[0] != y.shape[0]:
        raise ValueError(
            "kernel_mi: x and y must have the same number of samples, "
            "got {} and {}".format(x.shape[0], y.shape[0]))

    def log_density(samples):
        # Fit a Gaussian KDE on `samples` and score those same points (log density).
        kde = KernelDensity(kernel='gaussian', bandwidth=band)
        kde.fit(samples)
        return kde.score_samples(samples)

    log_p_x = log_density(x)
    log_p_y = log_density(y)
    # The joint density gets its own KDE, fitted on the column-concatenated samples.
    log_p_xy = log_density(np.column_stack([x, y]))
    return (log_p_xy - log_p_x - log_p_y).mean()
# `model` is expected to be the online encoder (the caller passes `model.online_encoder`).
def information(model, train_loader, device, cid, epoch=4):
    """Estimate, batch by batch, the KDE mutual information between features and FC weights.

    For every batch of ``train_loader`` the first view is encoded by
    ``model``. The features are compared with ``kernel_mi`` against the
    transposed weight of every parameter whose name contains
    ``'fc.net.3.weight'``, truncated to the batch size. Each estimate is
    printed and collected. The model's train/eval mode is restored on exit.

    Args:
        model: Encoder module to evaluate.
        train_loader: Iterable yielding ``((view1, view2), labels)`` batches;
            only ``view1`` is used.
        device: Device the inputs are moved to before encoding.
        cid: Client identifier, used only in the log line.
        epoch: Epoch number shown in the log line. Defaults to 4, the value
            that was previously hard-coded, so existing output is unchanged.

    Returns:
        List of MI estimates, one per (batch, matching weight) pair, in
        iteration order.
    """
    # Remember the caller's mode so this evaluation does not leave the model in eval mode.
    was_training = model.training
    model.eval()
    try:
        # Parameters do not change during this pass, so extract the matching
        # weights once instead of re-reading the state dict for every batch.
        # `.t()` transposes; it requires the weight to be at most 2-D.
        weights_t = [
            param.detach().t().cpu().numpy()
            for name, param in model.named_parameters()
            if 'fc.net.3.weight' in name
        ]
        all_mi = []
        # Inference only: no gradient tracking, no backward pass.
        with torch.no_grad():
            for batch_index, ((batched_x1, _view2), _) in enumerate(train_loader):
                data = batched_x1.to(device)
                batch_size = data.size(0)
                # Presumably (batch, 2048) per the original note -- confirm for other encoders.
                features = model(data).detach().cpu().numpy()
                for weight_t in weights_t:
                    # NOTE(review): pairs the first `batch_size` rows of W.t() with the
                    # batch features so both have batch_size rows; if W.t() has fewer
                    # rows than batch_size, kernel_mi fails on the mismatched sample
                    # counts. Confirm this pairing is the intended quantity.
                    mi = kernel_mi(features, weight_t[:batch_size, :])
                    print("client={}, epoch={} batch={}, the mi = {}".format(cid, epoch, batch_index, mi))
                    all_mi.append(mi)
        return all_mi
    finally:
        model.train(was_training)