# model.py
  1. import torch.nn as nn
  2. from torch.nn import init
  3. from torchvision import models
  4. from easyfl.models import BaseModel
  5. def weights_init_kaiming(m):
  6. classname = m.__class__.__name__
  7. if classname.find('Conv') != -1:
  8. init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') # For old pytorch, you may use kaiming_normal.
  9. elif classname.find('Linear') != -1:
  10. init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
  11. init.constant_(m.bias.data, 0.0)
  12. elif classname.find('BatchNorm1d') != -1:
  13. init.normal_(m.weight.data, 1.0, 0.02)
  14. init.constant_(m.bias.data, 0.0)
  15. def weights_init_classifier(m):
  16. classname = m.__class__.__name__
  17. if classname.find('Linear') != -1:
  18. init.normal_(m.weight.data, std=0.001)
  19. init.constant_(m.bias.data, 0.0)
  20. class ClassBlock(nn.Module):
  21. def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, num_bottleneck=512, linear=True,
  22. return_f=False):
  23. super(ClassBlock, self).__init__()
  24. self.return_f = return_f
  25. add_block = []
  26. if linear:
  27. add_block += [nn.Linear(input_dim, num_bottleneck)]
  28. else:
  29. num_bottleneck = input_dim
  30. if bnorm:
  31. add_block += [nn.BatchNorm1d(num_bottleneck)]
  32. if relu:
  33. add_block += [nn.LeakyReLU(0.1)]
  34. if droprate > 0:
  35. add_block += [nn.Dropout(p=droprate)]
  36. add_block = nn.Sequential(*add_block)
  37. add_block.apply(weights_init_kaiming)
  38. classifier = []
  39. classifier += [nn.Linear(num_bottleneck, class_num)]
  40. classifier = nn.Sequential(*classifier)
  41. classifier.apply(weights_init_classifier)
  42. self.add_block = add_block
  43. self.classifier = classifier
  44. def forward(self, x):
  45. x = self.add_block(x)
  46. if self.return_f:
  47. f = x
  48. x = self.classifier(x)
  49. return x, f
  50. else:
  51. x = self.classifier(x)
  52. return x
  53. # Define the ResNet50-based Model
  54. class Model(BaseModel):
  55. def __init__(self, class_num=0, droprate=0.5, stride=2):
  56. super(Model, self).__init__()
  57. model_ft = models.resnet50(pretrained=True)
  58. self.class_num = class_num
  59. if stride == 1:
  60. model_ft.layer4[0].downsample[0].stride = (1, 1)
  61. model_ft.layer4[0].conv2.stride = (1, 1)
  62. model_ft.avgpool = nn.AdaptiveAvgPool2d((1, 1))
  63. self.model = model_ft
  64. if self.class_num != 0:
  65. self.classifier = ClassBlock(2048, class_num, droprate)
  66. else:
  67. self.classifier = ClassBlock(2048, 10, droprate) # 10 is not effective because classifier is replaced below
  68. self.classifier.classifier = nn.Sequential()
  69. def forward(self, x):
  70. x = self.model.conv1(x)
  71. x = self.model.bn1(x)
  72. x = self.model.relu(x)
  73. x = self.model.maxpool(x)
  74. x = self.model.layer1(x)
  75. x = self.model.layer2(x)
  76. x = self.model.layer3(x)
  77. x = self.model.layer4(x)
  78. x = self.model.avgpool(x)
  79. x = x.view(x.size(0), x.size(1))
  80. x = self.classifier(x)
  81. return x
  82. def get_classifier(class_num, num_bottleneck=512):
  83. classifier = []
  84. classifier += [nn.Linear(num_bottleneck, class_num)]
  85. classifier = nn.Sequential(*classifier)
  86. classifier.apply(weights_init_classifier)
  87. return classifier