metrics.py 3.8 KB

  1. """This is code based on https://sudomake.ai/inception-score-explained/."""
  2. import torch
  3. import torchvision
  4. from collections import defaultdict
  5. class InceptionScore(torch.nn.Module):
  6. """Class that manages and returns the inception score of images."""
  7. def __init__(self, batch_size=32, setup=dict(device=torch.device('cpu'), dtype=torch.float)):
  8. """Initialize with setup and target inception batch size."""
  9. super().__init__()
  10. self.preprocessing = torch.nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False)
  11. self.model = torchvision.models.inception_v3(pretrained=True).to(**setup)
  12. self.model.eval()
  13. self.batch_size = batch_size
  14. def forward(self, image_batch):
  15. """Image batch should have dimensions BCHW and should be normalized.
  16. B should be divisible by self.batch_size.
  17. """
  18. B, C, H, W = image_batch.shape
  19. batches = B // self.batch_size
  20. scores = []
  21. for batch in range(batches):
  22. input = self.preprocessing(image_batch[batch * self.batch_size: (batch + 1) * self.batch_size])
  23. scores.append(self.model(input))
  24. prob_yx = torch.nn.functional.softmax(torch.cat(scores, 0), dim=1)
  25. entropy = torch.where(prob_yx > 0, -prob_yx * prob_yx.log(), torch.zeros_like(prob_yx))
  26. return entropy.mean()
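
# Usage sketch (illustrative, not part of the original module): score a batch of
# images that are already normalized for Inception v3. The shapes below are
# placeholders chosen for the example.
#
#     inception_score = InceptionScore(batch_size=32)
#     images = torch.randn(64, 3, 64, 64)        # BCHW, B divisible by batch_size
#     mean_entropy = inception_score(images)     # mean entropy of p(y|x)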


def psnr(img_batch, ref_batch, batched=False, factor=1.0):
    """Standard PSNR."""
    def get_psnr(img_in, img_ref):
        mse = ((img_in - img_ref) ** 2).mean()
        if mse > 0 and torch.isfinite(mse):
            return 10 * torch.log10(factor**2 / mse)
        elif not torch.isfinite(mse):
            return img_batch.new_tensor(float('nan'))
        else:
            return img_batch.new_tensor(float('inf'))

    if batched:
        psnr = get_psnr(img_batch.detach(), ref_batch)
    else:
        [B, C, m, n] = img_batch.shape
        psnrs = []
        for sample in range(B):
            psnrs.append(get_psnr(img_batch.detach()[sample, :, :, :], ref_batch[sample, :, :, :]))
        psnr = torch.stack(psnrs, dim=0).mean()
    return psnr.item()
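
# Usage sketch (illustrative): compare a reconstructed batch against its reference.
# The tensors here are placeholders; any matching pair of BCHW tensors works.
#
#     x_rec = torch.rand(4, 3, 32, 32)
#     x_ref = torch.rand(4, 3, 32, 32)
#     value = psnr(x_rec, x_ref, factor=1.0)     # per-sample PSNR averaged over the batch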


def total_variation(x):
    """Anisotropic TV."""
    dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
    dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))
    return dx + dy
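
# Usage sketch (illustrative): the anisotropic TV term is typically added as a
# smoothness penalty on a reconstructed image x; tv_weight is a hypothetical
# hyperparameter, not something defined in this file.
#
#     tv_weight = 1e-4
#     loss = reconstruction_loss + tv_weight * total_variation(x)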


def activation_errors(model, x1, x2):
    """Compute activation-level error metrics for every module in the network."""
    model.eval()
    device = next(model.parameters()).device
    hooks = []
    data = defaultdict(dict)
    # Concatenate both inputs so a single forward pass sees x1 and x2 side by side.
    inputs = torch.cat((x1, x2), dim=0)
    separator = x1.shape[0]

    def check_activations(self, input, output):
        # The forward hook receives the module as its first argument; recover its registered name.
        module_name = str(*[name for name, mod in model.named_modules() if self is mod])
        try:
            layer_inputs = input[0].detach()
            residual = (layer_inputs[:separator] - layer_inputs[separator:]).pow(2)
            se_error = residual.sum()
            mse_error = residual.mean()
            sim = torch.nn.functional.cosine_similarity(layer_inputs[:separator].flatten(),
                                                        layer_inputs[separator:].flatten(),
                                                        dim=0, eps=1e-8).detach()
            data['se'][module_name] = se_error.item()
            data['mse'][module_name] = mse_error.item()
            data['sim'][module_name] = sim.item()
        except (KeyboardInterrupt, SystemExit):
            raise
        except AttributeError:
            # Modules whose inputs are not tensors raise AttributeError; skip them.
            pass

    for name, module in model.named_modules():
        hooks.append(module.register_forward_hook(check_activations))

    try:
        outputs = model(inputs.to(device))
        for hook in hooks:
            hook.remove()
    except Exception as e:
        # Remove hooks even if the forward pass fails, then re-raise.
        for hook in hooks:
            hook.remove()
        raise

    return data
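
# Usage sketch (illustrative, assumes a torchvision classifier and two input
# batches of identical shape):
#
#     model = torchvision.models.resnet18(pretrained=False)
#     x1 = torch.rand(2, 3, 224, 224)
#     x2 = torch.rand(2, 3, 224, 224)
#     errors = activation_errors(model, x1, x2)
#     # errors['mse'] maps each module name to the mean squared difference of its inputs.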