12345678910111213141516171819202122232425262728293031323334353637383940 |
- import torch
- import torch.nn.functional as F
- from torch import nn
- from easyfl.models import BaseModel
class Model(BaseModel):
    """Small CNN classifier: three 3x3 conv layers followed by two dense layers.

    The fully connected sizes assume 32x32 input images. With 3x3 valid
    (unpadded, stride 1) convolutions and 2x2 max-pooling, the spatial size
    goes 32 -> 30 -> 15 -> 13 -> 6 -> 4. That is where the ``4 * 4`` in
    ``fc1`` comes from.

    Args:
        channels: Output width of ``conv1``. ``conv2``, ``conv3`` and the
            hidden dense layer use ``channels * 2`` units.
        num_classes: Number of output scores produced by ``fc2``. The default
            is 10, the previously hard-coded value.
        in_channels: Number of channels of the input images. The default is 3,
            the previously hard-coded value.
    """

    def __init__(self, channels=32, *, num_classes=10, in_channels=3):
        super().__init__()
        self.num_channels = channels
        self.conv1 = nn.Conv2d(in_channels, self.num_channels, 3, stride=1)
        self.conv2 = nn.Conv2d(self.num_channels, self.num_channels * 2, 3, stride=1)
        self.conv3 = nn.Conv2d(self.num_channels * 2, self.num_channels * 2, 3, stride=1)

        # conv3 leaves a (num_channels * 2) x 4 x 4 feature map for 32x32 input.
        self.fc1 = nn.Linear(4 * 4 * self.num_channels * 2, self.num_channels * 2)
        self.fc2 = nn.Linear(self.num_channels * 2, num_classes)

    def forward(self, s):
        """Return unnormalised class scores for a batch of images.

        Args:
            s: Input batch, presumably shaped (batch, in_channels, 32, 32).
                The ``fc1`` input size only fits that spatial size. TODO:
                confirm against the data pipeline.

        Returns:
            Tensor of shape (batch, num_classes). No softmax is applied.
        """
        # Pool-then-ReLU equals ReLU-then-pool, because ReLU is monotone.
        s = F.relu(F.max_pool2d(self.conv1(s), 2))
        s = F.relu(F.max_pool2d(self.conv2(s), 2))
        # Without an activation here, conv3 and fc1 would be two affine maps
        # in a row, which collapse into a single one.
        s = F.relu(self.conv3(s))

        # Flatten per sample while keeping the batch dimension fixed.
        # view(-1, ...) could silently regroup a mis-sized input into rows
        # that mix samples; this way a size mismatch fails loudly in fc1.
        s = torch.flatten(s, 1)

        s = F.relu(self.fc1(s))
        return self.fc2(s)
|