vgg9.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import torch
  2. import torch.nn as nn
  3. import math
  4. from easyfl.models import BaseModel
  5. cfg = {
  6. 'VGG9': [32, 64, 'M', 128, 128, 'M', 256, 256, 'M'],
  7. }
  8. def make_layers(cfg, batch_norm):
  9. layers = []
  10. in_channels = 3
  11. for v in cfg:
  12. if v == 'M':
  13. layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
  14. else:
  15. conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
  16. if batch_norm:
  17. layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
  18. else:
  19. layers += [conv2d, nn.ReLU(inplace=True)]
  20. in_channels = v
  21. return nn.Sequential(*layers)
  22. class Model(BaseModel):
  23. def __init__(self, features=make_layers(cfg['VGG9'], batch_norm=False), num_classes=10):
  24. super(Model, self).__init__()
  25. self.features = features
  26. self.classifier = nn.Sequential(
  27. nn.Dropout(p=0.1),
  28. nn.Linear(4096, 512),
  29. nn.ReLU(True),
  30. nn.Dropout(p=0.1),
  31. nn.Linear(512, 512),
  32. nn.ReLU(True),
  33. nn.Linear(512, num_classes),
  34. )
  35. self._initialize_weights()
  36. def forward(self, x):
  37. x = self.features(x)
  38. x = x.view(x.size(0), -1)
  39. x = self.classifier(x)
  40. return x
  41. def _initialize_weights(self):
  42. for m in self.modules():
  43. if isinstance(m, nn.Conv2d):
  44. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  45. m.weight.data.normal_(0, math.sqrt(2. / n))
  46. if m.bias is not None:
  47. m.bias.data.zero_()
  48. elif isinstance(m, nn.BatchNorm2d):
  49. m.reset_parameters()
  50. elif isinstance(m, nn.Linear):
  51. m.weight.data.normal_(0, 0.01)
  52. m.bias.data.zero_()
  53. def VGG9(batch_norm=False, **kwargs):
  54. model = Model(make_layers(cfg['VGG9'], batch_norm), **kwargs)
  55. return model