utils.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import torch
  2. import transform
  3. from model import SimSiam, MoCo
  4. from sklearn.neighbors import KernelDensity
  5. from tqdm import tqdm
  6. import numpy as np
  7. def get_transformation(model):
  8. if model == SimSiam:
  9. transformation = transform.SimSiamTransform
  10. elif model == MoCo:
  11. transformation = transform.MoCoTransform
  12. else:
  13. transformation = transform.SimCLRTransform
  14. return transformation
  15. def calculate_model_distance(m1, m2):
  16. distance, count = 0, 0
  17. d1, d2 = m1.state_dict(), m2.state_dict()
  18. for name, param in m1.named_parameters():
  19. if 'conv' in name and 'weight' in name:
  20. distance += torch.dist(d1[name].detach().clone().view(1, -1), d2[name].detach().clone().view(1, -1), 2)
  21. count += 1
  22. return distance / count
  23. def normalize(arr):
  24. maxx = max(arr)
  25. minn = min(arr)
  26. diff = maxx - minn
  27. if diff == 0:
  28. return arr
  29. return [(x - minn) / diff for x in arr]
  30. """ 添加的内容 """
  31. # 自定义钩子函数
  32. def hook_fn(module, input, output):
  33. module.output = output
  34. def kernel_mi(x, y, band=0.5):
  35. # 使用KDE估计变量x的概率密度函数
  36. kde_x = KernelDensity(kernel='gaussian', bandwidth=band)
  37. kde_x.fit(x)
  38. # 使用KDE估计变量y的概率密度函数
  39. kde_y = KernelDensity(kernel='gaussian', bandwidth=band)
  40. kde_y.fit(y)
  41. # 使用估计的概率密度函数计算联合概率密度函数
  42. xy = np.column_stack([x, y])
  43. kde_xy = KernelDensity(kernel='gaussian', bandwidth=band)
  44. kde_xy.fit(xy)
  45. # 计算互信息
  46. log_p_xy = kde_xy.score_samples(xy)
  47. log_p_x = kde_x.score_samples(x)
  48. log_p_y = kde_y.score_samples(y)
  49. mi = (log_p_xy - log_p_x - log_p_y).mean() # 假设样本与样本之间是独立同分布的
  50. return mi
  51. # the model is model.online_encoder
  52. def information(model, train_loader, device, cid):
  53. # 评估模式
  54. model.eval()
  55. # all batch 互信息
  56. all_mi = []
  57. # 不会进行计算梯度,也不会进行反向传播
  58. with torch.no_grad():
  59. for batch_index, ((batched_x1, batched_x2), _) in enumerate(train_loader):
  60. # 部署到device上
  61. data = batched_x1.to(device)
  62. # the feature batch*2048
  63. feature = model(data)
  64. batch_feature = feature.detach().clone()
  65. state_dict = model.state_dict()
  66. # 权重的处理
  67. for name, param in model.named_parameters():
  68. if 'fc.net.3.weight' in name:
  69. # feature_weight = state_dict[name].detach().clone()
  70. # feature_weight = feature_weight.reshape(-1, batch_feature.size(1))
  71. feature_weight = state_dict[name].detach().clone()
  72. feature_weight = feature_weight.t()
  73. feature_weight = feature_weight[:data.size(0), :]
  74. mi = kernel_mi(batch_feature.cpu().numpy(), feature_weight.cpu().numpy())
  75. print("client={}, epoch=4 batch={}, the mi = {}".format(cid, batch_index, mi))
  76. all_mi.append(mi)
  77. return all_mi