utils.py 873 B

123456789101112131415161718192021222324252627282930313233
  1. import torch
  2. import transform
  3. from model import SimSiam, MoCo
  4. def get_transformation(model):
  5. if model == SimSiam:
  6. transformation = transform.SimSiamTransform
  7. elif model == MoCo:
  8. transformation = transform.MoCoTransform
  9. else:
  10. transformation = transform.SimCLRTransform
  11. return transformation
  12. def calculate_model_distance(m1, m2):
  13. distance, count = 0, 0
  14. d1, d2 = m1.state_dict(), m2.state_dict()
  15. for name, param in m1.named_parameters():
  16. if 'conv' in name and 'weight' in name:
  17. distance += torch.dist(d1[name].detach().clone().view(1, -1), d2[name].detach().clone().view(1, -1), 2)
  18. count += 1
  19. return distance / count
  20. def normalize(arr):
  21. maxx = max(arr)
  22. minn = min(arr)
  23. diff = maxx - minn
  24. if diff == 0:
  25. return arr
  26. return [(x - minn) / diff for x in arr]