end2end.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. from __future__ import absolute_import
  2. from torch import nn
  3. from torch.autograd import Variable
  4. from torch.nn import functional as F
  5. from torch.nn import init
  6. import torch
  7. import torchvision
  8. import math
  9. from .resnet import *
  10. __all__ = ["End2End_AvgPooling"]
  11. class AvgPooling(nn.Module):
  12. def __init__(self, input_feature_size, embedding_fea_size=1024, dropout=0.5):
  13. super(self.__class__, self).__init__()
  14. # embedding
  15. self.embedding_fea_size = embedding_fea_size
  16. self.embedding = nn.Linear(input_feature_size, embedding_fea_size)
  17. self.embedding_bn = nn.BatchNorm1d(embedding_fea_size)
  18. init.kaiming_normal_(self.embedding.weight, mode='fan_out')
  19. init.constant_(self.embedding.bias, 0)
  20. init.constant_(self.embedding_bn.weight, 1)
  21. init.constant_(self.embedding_bn.bias, 0)
  22. self.drop = nn.Dropout(dropout)
  23. def forward(self, inputs):
  24. net = inputs.mean(dim=1)
  25. eval_features = F.normalize(net, p=2, dim=1)
  26. net = self.embedding(net)
  27. net = self.embedding_bn(net)
  28. net = F.normalize(net, p=2, dim=1)
  29. net = self.drop(net)
  30. return net, eval_features
  31. class End2End_AvgPooling(nn.Module):
  32. def __init__(self, dropout=0, embedding_fea_size=1024, fixed_layer=True):
  33. super(self.__class__, self).__init__()
  34. self.CNN = resnet50(dropout=dropout, fixed_layer=fixed_layer)
  35. self.avg_pooling = AvgPooling(input_feature_size=2048, embedding_fea_size=embedding_fea_size, dropout=dropout)
  36. def forward(self, x):
  37. assert len(x.data.shape) == 5
  38. # reshape (batch, samples, ...) ==> (batch * samples, ...)
  39. oriShape = x.data.shape
  40. x = x.view(-1, oriShape[2], oriShape[3], oriShape[4])
  41. # resnet encoding
  42. resnet_feature = self.CNN(x)
  43. # reshape back into (batch, samples, ...)
  44. resnet_feature = resnet_feature.view(oriShape[0], oriShape[1], -1)
  45. # avg pooling
  46. output = self.avg_pooling(resnet_feature)
  47. return output