123456789101112131415161718192021222324252627282930313233343536 |
- from torch import nn
- import torch.nn.functional as F
- from easyfl.models.model import BaseModel
class Model(BaseModel):
    """Two-layer CNN classifier for single-channel 28x28 images.

    Architecture:
        conv(5x5, 32) -> ReLU -> maxpool(2)
        -> conv(5x5, 64) -> ReLU -> maxpool(2)
        -> fc(7*7*64 -> 2048) -> ReLU -> fc(2048 -> num_classes)

    The 5x5 convolutions use padding 2 (stride 1), so they preserve spatial
    size; the two 2x2 poolings reduce 28x28 to 7x7, which is exactly what
    ``fc1`` is sized for.

    Args:
        num_classes: Number of output logits. Defaults to 62.
        in_channels: Number of input image channels. Defaults to 1.
    """

    def __init__(self, num_classes=62, in_channels=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, 5, padding=(2, 2))
        self.conv2 = nn.Conv2d(32, 64, 5, padding=(2, 2))
        self.fc1 = nn.Linear(7 * 7 * 64, 2048)
        self.fc2 = nn.Linear(2048, num_classes)
        self.init_weights()

    def forward(self, x):
        """Return unnormalized class logits (no softmax is applied).

        Args:
            x: Image batch for ``nn.Conv2d`` (assumed (N, C, H, W) with
                H = W = 28 — other spatial sizes do not match ``fc1``).

        Returns:
            Logits with one row per sample and ``num_classes`` columns.

        Raises:
            ValueError: If the feature map entering ``fc1`` is not 64x7x7,
                i.e. the input images are not 28x28.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        # Without this check, view(-1, 3136) silently re-batches any input
        # whose feature count is a multiple of 3136 (e.g. 56x56 images turn
        # one sample into four rows of garbage logits).
        if tuple(x.shape[-3:]) != (64, 7, 7):
            raise ValueError(
                "Model expects 28x28 inputs (a 64x7x7 feature map before fc1), "
                f"got a feature map of shape {tuple(x.shape)}"
            )
        # reshape (not view) so non-contiguous feature maps, e.g. from
        # channels_last inputs, are handled instead of raising.
        x = x.reshape(-1, 7 * 7 * 64)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def init_weights(self, init_range=0.1):
        """Re-initialize every layer: weights ~ U(-init_range, init_range), biases = 0.

        Args:
            init_range: Half-width of the uniform weight distribution.
                Defaults to 0.1.
        """
        # nn.init.* run under torch.no_grad(), replacing direct .data
        # mutation; layers are visited in the same order, so the RNG draws
        # match the previous per-layer initialization.
        for layer in (self.conv1, self.conv2, self.fc1, self.fc2):
            nn.init.uniform_(layer.weight, -init_range, init_range)
            nn.init.zeros_(layer.bias)
|