# lenet.py
  1. from torch import nn
  2. import torch.nn.functional as F
  3. from easyfl.models.model import BaseModel
  4. class Model(BaseModel):
  5. def __init__(self):
  6. super(Model, self).__init__()
  7. self.conv1 = nn.Conv2d(1, 32, 5, padding=(2, 2))
  8. self.conv2 = nn.Conv2d(32, 64, 5, padding=(2, 2))
  9. self.fc1 = nn.Linear(7 * 7 * 64, 2048)
  10. self.fc2 = nn.Linear(2048, 62)
  11. self.init_weights()
  12. def forward(self, x):
  13. x = F.relu(self.conv1(x))
  14. x = F.max_pool2d(x, 2, 2)
  15. x = F.relu(self.conv2(x))
  16. x = F.max_pool2d(x, 2, 2)
  17. x = x.view(-1, 7 * 7 * 64)
  18. x = F.relu(self.fc1(x))
  19. x = self.fc2(x)
  20. return x
  21. def init_weights(self):
  22. init_range = 0.1
  23. self.conv1.weight.data.uniform_(-init_range, init_range)
  24. self.conv1.bias.data.zero_()
  25. self.conv2.weight.data.uniform_(-init_range, init_range)
  26. self.conv2.bias.data.zero_()
  27. self.fc1.weight.data.uniform_(-init_range, init_range)
  28. self.fc1.bias.data.zero_()
  29. self.fc2.weight.data.uniform_(-init_range, init_range)
  30. self.fc2.bias.data.zero_()