resnet.py

from __future__ import absolute_import

from torch import nn
from torch.nn import functional as F
from torch.nn import init
import torchvision

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


class ResNet(nn.Module):
    __factory = {
        18: torchvision.models.resnet18,
        34: torchvision.models.resnet34,
        50: torchvision.models.resnet50,
        101: torchvision.models.resnet101,
        152: torchvision.models.resnet152,
    }

    def __init__(self, depth, pretrained=True, cut_at_pooling=False,
                 num_features=0, norm=False, dropout=0, num_classes=0):
        super(ResNet, self).__init__()
        self.depth = depth
        self.pretrained = pretrained
        self.cut_at_pooling = cut_at_pooling

        # Construct the base (pretrained) ResNet from torchvision.
        # Note: the `pretrained` keyword is deprecated in torchvision >= 0.13
        # in favor of `weights=`; it is kept here to match the original API.
        if depth not in ResNet.__factory:
            raise KeyError("Unsupported depth: {}".format(depth))
        self.base = ResNet.__factory[depth](pretrained=pretrained)

        if not self.cut_at_pooling:
            self.num_features = num_features
            self.norm = norm
            self.dropout = dropout
            self.has_embedding = num_features > 0
            self.num_classes = num_classes

            out_planes = self.base.fc.in_features

            # Append new layers on top of the base network
            if self.has_embedding:
                self.feat = nn.Linear(out_planes, self.num_features)
                self.feat_bn = nn.BatchNorm1d(self.num_features)
                init.kaiming_normal_(self.feat.weight, mode='fan_out')
                init.constant_(self.feat.bias, 0)
                init.constant_(self.feat_bn.weight, 1)
                init.constant_(self.feat_bn.bias, 0)
            else:
                # Change num_features to the CNN output channels
                self.num_features = out_planes
            if self.dropout > 0:
                self.drop = nn.Dropout(self.dropout)
            if self.num_classes > 0:
                self.classifier = nn.Linear(self.num_features,
                                            self.num_classes)
                init.normal_(self.classifier.weight, std=0.001)
                init.constant_(self.classifier.bias, 0)

        if not self.pretrained:
            self.reset_params()

    def forward(self, x):
        # Run the base network up to (but not including) its average pooling
        for name, module in self.base._modules.items():
            if name == 'avgpool':
                break
            x = module(x)
        if self.cut_at_pooling:
            return x
        # Global average pooling, then flatten to (batch, channels)
        x = F.avg_pool2d(x, x.size()[2:])
        x = x.view(x.size(0), -1)
        # Apply the embedding, normalization, dropout, and classifier heads
        if self.has_embedding:
            x = self.feat(x)
            x = self.feat_bn(x)
        if self.norm:
            x = F.normalize(x)
        elif self.has_embedding:
            x = F.relu(x)
        if self.dropout > 0:
            x = self.drop(x)
        if self.num_classes > 0:
            x = self.classifier(x)
        return x

    def reset_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


def resnet18(**kwargs):
    return ResNet(18, **kwargs)


def resnet34(**kwargs):
    return ResNet(34, **kwargs)


def resnet50(**kwargs):
    return ResNet(50, **kwargs)


def resnet101(**kwargs):
    return ResNet(101, **kwargs)


def resnet152(**kwargs):
    return ResNet(152, **kwargs)
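

if __name__ == '__main__':
    # Minimal usage sketch, kept as executable code so the file stays a valid
    # module. The feature dimension (128), class count (751), and input size
    # (256x128) are illustrative assumptions, not values this file prescribes.
    import torch

    model = resnet50(pretrained=True, num_features=128,
                     dropout=0.5, num_classes=751)
    model.eval()  # eval mode: freezes BatchNorm stats, disables dropout
    with torch.no_grad():
        images = torch.randn(4, 3, 256, 128)  # dummy batch of 4 RGB crops
        logits = model(images)
    print(logits.shape)  # torch.Size([4, 751])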