# models.py (1.6 KB)
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. batch_size = 16
  5. class LocalModel(nn.Module):
  6. def __init__(self, base, predictor):
  7. super(LocalModel, self).__init__()
  8. self.base = base
  9. self.predictor = predictor
  10. def forward(self, x):
  11. out = self.base(x)
  12. out = self.predictor(out)
  13. return out
  14. # https://github.com/katsura-jp/fedavg.pytorch/blob/master/src/models/cnn.py
  15. class FedAvgCNN(nn.Module):
  16. def __init__(self, in_features=1, num_classes=10, dim=1024):
  17. super().__init__()
  18. self.conv1 = nn.Sequential(
  19. nn.Conv2d(in_features,
  20. 32,
  21. kernel_size=5,
  22. padding=0,
  23. stride=1,
  24. bias=True),
  25. nn.ReLU(inplace=True),
  26. nn.MaxPool2d(kernel_size=(2, 2))
  27. )
  28. self.conv2 = nn.Sequential(
  29. nn.Conv2d(32,
  30. 64,
  31. kernel_size=5,
  32. padding=0,
  33. stride=1,
  34. bias=True),
  35. nn.ReLU(inplace=True),
  36. nn.MaxPool2d(kernel_size=(2, 2))
  37. )
  38. self.fc1 = nn.Sequential(
  39. nn.Linear(dim, 512),
  40. nn.ReLU(inplace=True)
  41. )
  42. self.fc = nn.Linear(512, num_classes)
  43. def forward(self, x):
  44. out = self.conv1(x)
  45. out = self.conv2(out)
  46. out = torch.flatten(out, 1)
  47. out = self.fc1(out)
  48. out = self.fc(out)
  49. return out