tensor_utils.py 770 B

12345678910111213141516171819202122
  1. from pkg_resources import require
  2. import torch
  3. def l2_squared_diff(w1, w2, requires_grad=True):
  4. """ Return the sum of squared difference between two models. """
  5. diff = 0.0
  6. for p1, p2 in zip(w1.parameters(), w2.parameters()):
  7. if requires_grad:
  8. diff += torch.sum(torch.pow(p1-p2, 2))
  9. else:
  10. diff += torch.sum(torch.pow(p1.data-p2.data, 2))
  11. return diff
  12. def model_dot_product(w1, w2, requires_grad=True):
  13. """ Return the sum of squared difference between two models. """
  14. dot_product = 0.0
  15. for p1, p2 in zip(w1.parameters(), w2.parameters()):
  16. if requires_grad:
  17. dot_product += torch.sum(p1 * p2)
  18. else:
  19. dot_product += torch.sum(p1.data * p2.data)
  20. return dot_product