123456789101112131415161718192021222324252627282930 |
- from __future__ import absolute_import
- from collections import OrderedDict
- from torch.autograd import Variable
- from ..utils import to_torch
- import torch
def extract_cnn_feature(model, inputs, modules=None):
    """Run ``model`` on ``inputs`` without autograd and return features on CPU.

    Side effect: calls ``model.eval()`` and does not restore the previous
    training mode, so the model is left in eval mode afterwards.

    Args:
        model: a ``torch.nn.Module``. When ``modules`` is None, its forward
            must return exactly two tensors (they are unpacked as
            ``fcs, pool5s``).
        inputs: anything ``to_torch`` accepts; it is converted to a tensor
            before the forward pass.
        modules: optional iterable of sub-modules of ``model``. When given, a
            forward hook is attached to each one for the duration of a single
            forward pass, and their outputs are returned instead of the
            model's own return value.

    Returns:
        If ``modules`` is None: a tuple ``(fcs, pool5s)`` of CPU tensors.
        Otherwise: a list with one entry per distinct module (deduplicated
        by ``id``), in ``modules`` order. Each entry is that module's forward
        output moved to CPU, or ``None`` if the module never ran during the
        forward pass. If a module runs more than once, the last output wins.
        NOTE(review): assumes every hooked module returns a single tensor;
        a tuple output would fail on ``.detach()``.
    """
    with torch.no_grad():
        model.eval()
        inputs = to_torch(inputs)

        if modules is None:
            fcs, pool5s = model(inputs)
            return fcs.detach().cpu(), pool5s.detach().cpu()

        # Materialise once so a one-shot iterable (e.g. a generator) can be
        # walked both to pre-fill `outputs` and to register the hooks.
        modules = list(modules)

        # Keyed by id(); pre-filled in `modules` order so the result order is
        # deterministic and modules never reached by forward() report None.
        outputs = OrderedDict((id(m), None) for m in modules)

        def _store_output(module, _inputs, output):
            # Forward-hook signature is (module, inputs, output).
            outputs[id(module)] = output.detach().cpu()

        handles = []
        try:
            for m in modules:
                handles.append(m.register_forward_hook(_store_output))
            model(inputs)
        finally:
            # Always detach the hooks, even if registration or forward()
            # raised; otherwise they would stay attached and fire on every
            # later call of the model.
            for handle in handles:
                handle.remove()

        return list(outputs.values())
|