cnn.py 882 B

123456789101112131415161718192021222324252627282930
  1. from __future__ import absolute_import
  2. from collections import OrderedDict
  3. from torch.autograd import Variable
  4. from ..utils import to_torch
  5. import torch
  6. def extract_cnn_feature(model, inputs, modules=None):
  7. with torch.no_grad():
  8. model.eval()
  9. inputs = to_torch(inputs)
  10. inputs = Variable(inputs)
  11. if modules is None:
  12. fcs, pool5s = model(inputs)
  13. fcs = fcs.data.cpu()
  14. pool5s = pool5s.data.cpu()
  15. return fcs, pool5s
  16. # Register forward hook for each module
  17. outputs = OrderedDict()
  18. handles = []
  19. for m in modules:
  20. outputs[id(m)] = None
  21. def func(m, i, o): outputs[id(m)] = o.data.cpu()
  22. handles.append(m.register_forward_hook(func))
  23. model(inputs)
  24. for h in handles:
  25. h.remove()
  26. return list(outputs.values())