simple_cnn.py

import torch
import torch.nn.functional as F
from torch import nn

from easyfl.models import BaseModel


class Model(BaseModel):
    def __init__(self, channels=32):
        super(Model, self).__init__()
        self.num_channels = channels
        # 3 convolution layers (3x3 kernels, stride 1, no padding)
        self.conv1 = nn.Conv2d(3, self.num_channels, 3, stride=1)
        self.conv2 = nn.Conv2d(self.num_channels, self.num_channels * 2, 3, stride=1)
        self.conv3 = nn.Conv2d(self.num_channels * 2, self.num_channels * 2, 3, stride=1)

        # 2 fully connected layers to transform the output of the convolution layers to the final output
        self.fc1 = nn.Linear(4 * 4 * self.num_channels * 2, self.num_channels * 2)
        self.fc2 = nn.Linear(self.num_channels * 2, 10)

    def forward(self, s):
        # shapes below assume 3 x 32 x 32 inputs (e.g. CIFAR-10)
        s = self.conv1(s)               # batch_size x num_channels x 30 x 30
        s = F.relu(F.max_pool2d(s, 2))  # batch_size x num_channels x 15 x 15
        s = self.conv2(s)               # batch_size x num_channels*2 x 13 x 13
        s = F.relu(F.max_pool2d(s, 2))  # batch_size x num_channels*2 x 6 x 6
        s = self.conv3(s)               # batch_size x num_channels*2 x 4 x 4
        # s = F.relu(F.max_pool2d(s, 2))  # disabled: the feature map is already 4 x 4 here

        # flatten the output for each image
        s = s.view(-1, 4 * 4 * self.num_channels * 2)  # batch_size x 4*4*num_channels*2

        # apply 2 fully connected layers
        s = F.relu(self.fc1(s))         # batch_size x num_channels*2
        s = self.fc2(s)                 # batch_size x 10

        return s
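

# Minimal usage sketch, not part of the original file: it assumes 32 x 32 RGB inputs
# (e.g. CIFAR-10) and only checks that a forward pass on a dummy batch produces a
# batch_size x 10 logit tensor.
if __name__ == "__main__":
    model = Model(channels=32)
    dummy_batch = torch.randn(4, 3, 32, 32)  # batch of 4 fake 32x32 RGB images
    logits = model(dummy_batch)
    print(logits.shape)                      # expected: torch.Size([4, 10])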