12345678910111213141516171819202122 |
- from pkg_resources import require
- import torch
def l2_squared_diff(w1, w2, requires_grad=True):
    """Return the sum of squared element-wise differences between two models.

    Parameters are paired positionally in the order yielded by
    ``.parameters()``, so both models are expected to share the same
    architecture (same number, order and broadcast-compatible shapes of
    parameters).

    Args:
        w1: First model; any object exposing ``.parameters()``
            (e.g. a ``torch.nn.Module``).
        w2: Second model, paired parameter-by-parameter with ``w1``.
        requires_grad: If True, the result stays attached to the autograd
            graph so gradients can flow back into both models' parameters.
            If False, parameters are detached first and the result carries
            no gradient history.

    Returns:
        A scalar tensor with the summed squared difference, or the float
        ``0.0`` when the models have no parameters.

    Raises:
        ValueError: If the two models yield a different number of parameters.
    """
    params1 = list(w1.parameters())
    params2 = list(w2.parameters())
    # zip() would silently stop at the shorter list and return a wrong total.
    if len(params1) != len(params2):
        raise ValueError(
            "models have different numbers of parameters: "
            f"{len(params1)} vs {len(params2)}"
        )
    diff = 0.0
    for p1, p2 in zip(params1, params2):
        if not requires_grad:
            # detach() is the documented replacement for .data: same values,
            # no grad history, but autograd's in-place version checks still apply.
            p1, p2 = p1.detach(), p2.detach()
        diff += torch.sum(torch.pow(p1 - p2, 2))
    return diff
def model_dot_product(w1, w2, requires_grad=True):
    """Return the dot product of two models' parameters.

    Each pair of corresponding parameters is multiplied element-wise and
    summed; the per-parameter sums are then added together. Parameters are
    paired positionally in the order yielded by ``.parameters()``, so both
    models are expected to share the same architecture.

    Args:
        w1: First model; any object exposing ``.parameters()``
            (e.g. a ``torch.nn.Module``).
        w2: Second model, paired parameter-by-parameter with ``w1``.
        requires_grad: If True, the result stays attached to the autograd
            graph so gradients can flow back into both models' parameters.
            If False, parameters are detached first and the result carries
            no gradient history.

    Returns:
        A scalar tensor with the dot product, or the float ``0.0`` when the
        models have no parameters.

    Raises:
        ValueError: If the two models yield a different number of parameters.
    """
    params1 = list(w1.parameters())
    params2 = list(w2.parameters())
    # zip() would silently stop at the shorter list and return a wrong total.
    if len(params1) != len(params2):
        raise ValueError(
            "models have different numbers of parameters: "
            f"{len(params1)} vs {len(params2)}"
        )
    dot_product = 0.0
    for p1, p2 in zip(params1, params2):
        if not requires_grad:
            # detach() is the documented replacement for .data: same values,
            # no grad history, but autograd's in-place version checks still apply.
            p1, p2 = p1.detach(), p2.detach()
        dot_product += torch.sum(p1 * p2)
    return dot_product
|