modules.py

  1. """For monkey-patching into meta-learning frameworks."""
  2. import torch
  3. import torch.nn.functional as F
  4. from collections import OrderedDict
  5. from functools import partial
  6. import warnings
  7. from .consts import BENCHMARK
  8. torch.backends.cudnn.benchmark = BENCHMARK
  9. DEBUG = False # Emit warning messages when patching. Use this to bootstrap new architectures.


class MetaMonkey(torch.nn.Module):
    """Trace a network and then replace its module calls with functional calls.

    This allows for backpropagation w.r.t. weights of "normal" PyTorch networks.
    See the usage sketch at the end of this file.
    """

    def __init__(self, net):
        """Init with network."""
        super().__init__()
        self.net = net
        self.parameters = OrderedDict(net.named_parameters())
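        # Note: this assignment shadows nn.Module.parameters(), so MetaMonkey
        # exposes an OrderedDict of named parameters instead of the usual
        # generator-returning method.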

    def forward(self, inputs, parameters=None):
        """Live-patch the network: run a forward pass using the externally supplied parameters."""
        # If no parameter dictionary is given, everything is normal:
        if parameters is None:
            return self.net(inputs)

        # But if not ...
        param_gen = iter(parameters.values())
        method_pile = []  # stack of original bound forward methods, restored after the pass
        for name, module in self.net.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                ext_weight = next(param_gen)
                if module.bias is not None:
                    ext_bias = next(param_gen)
                else:
                    ext_bias = None

                method_pile.append(module.forward)
                module.forward = partial(F.conv2d, weight=ext_weight, bias=ext_bias, stride=module.stride,
                                         padding=module.padding, dilation=module.dilation, groups=module.groups)
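            # The same patch/restore trick is used for every layer type below:
            # assigning to module.forward works because nn.Module.__call__ looks
            # up self.forward at call time, and an instance attribute shadows the
            # class method. The partial routes module(x) to F.conv2d(x, weight=..., ...).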
            elif isinstance(module, torch.nn.BatchNorm2d):
                # Replicate BatchNorm2d's own momentum bookkeeping:
                if module.momentum is None:
                    exponential_average_factor = 0.0
                else:
                    exponential_average_factor = module.momentum
                if module.training and module.track_running_stats:
                    if module.num_batches_tracked is not None:
                        module.num_batches_tracked += 1
                        if module.momentum is None:  # use cumulative moving average
                            exponential_average_factor = 1.0 / float(module.num_batches_tracked)
                        else:  # use exponential moving average
                            exponential_average_factor = module.momentum

                ext_weight = next(param_gen)
                ext_bias = next(param_gen)
                method_pile.append(module.forward)
                module.forward = partial(F.batch_norm, running_mean=module.running_mean, running_var=module.running_var,
                                         weight=ext_weight, bias=ext_bias,
                                         training=module.training or not module.track_running_stats,
                                         momentum=exponential_average_factor, eps=module.eps)
            elif isinstance(module, torch.nn.Linear):
                lin_weights = next(param_gen)
                lin_bias = next(param_gen)
                method_pile.append(module.forward)
                module.forward = partial(F.linear, weight=lin_weights, bias=lin_bias)
            elif next(module.parameters(), None) is None:
                # Pass over modules that do not contain parameters
                pass
            elif isinstance(module, torch.nn.Sequential):
                # Pass over containers
                pass
            else:
                # Warn about any other module that holds parameters but is not patched
                if DEBUG:
                    warnings.warn(f'Patching for module {module.__class__} is not implemented.')

        output = self.net(inputs)

        # Undo the patching:
        for name, module in self.net.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                module.forward = method_pile.pop(0)
            elif isinstance(module, torch.nn.BatchNorm2d):
                module.forward = method_pile.pop(0)
            elif isinstance(module, torch.nn.Linear):
                module.forward = method_pile.pop(0)

        return output
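

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; not part of the original module).
    # It assumes a toy Conv2d -> ReLU -> Flatten -> Linear network and shows the
    # point of MetaMonkey: the loss can be differentiated w.r.t. an *external*
    # OrderedDict of parameters, as meta-learning and gradient-inversion code needs.
    # Because of the relative import of BENCHMARK above, run this as a module
    # inside its package rather than as a standalone script.
    import torch.nn as nn

    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                        nn.ReLU(),
                        nn.Flatten(),
                        nn.Linear(8 * 8 * 8, 10))
    monkey = MetaMonkey(net)

    inputs = torch.randn(2, 3, 8, 8)
    labels = torch.randint(0, 10, (2,))

    # External parameter copies that are leaves of a fresh graph:
    params = OrderedDict((name, p.detach().clone().requires_grad_(True))
                         for name, p in monkey.parameters.items())

    loss = F.cross_entropy(monkey(inputs, parameters=params), labels)
    grads = torch.autograd.grad(loss, tuple(params.values()))
    print({name: g.shape for name, g in zip(params, grads)})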