12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
# Module-level mini-batch size. Not referenced by the definitions in this
# chunk; presumably consumed by training/data-loading code elsewhere (verify).
batch_size: int = 16
class LocalModel(nn.Module):
    """Chain two modules so that the output is ``predictor(base(x))``.

    Args:
        base: Module applied first to the raw input.
        predictor: Module applied to whatever ``base`` returns.

    Both are registered as submodules under the attribute names ``base`` and
    ``predictor`` (in that order); those names become the prefixes of the
    keys in ``state_dict()``.
    """

    def __init__(self, base, predictor):
        super().__init__()
        self.base = base
        self.predictor = predictor

    def forward(self, x):
        """Feed ``x`` through ``base``, then pass the result to ``predictor``."""
        features = self.base(x)
        return self.predictor(features)
-
# Adapted from https://github.com/katsura-jp/fedavg.pytorch/blob/master/src/models/cnn.py
class FedAvgCNN(nn.Module):
    """Small CNN classifier: two (conv5x5 -> ReLU -> maxpool2x2) stages, then two linear layers.

    Args:
        in_features: Number of input *channels* (e.g. 1 for grayscale, 3 for
            RGB). Despite the name, this is passed to ``nn.Conv2d`` as
            ``in_channels``.
        num_classes: Number of output logits produced by ``fc``.
        dim: Length of the flattened ``conv2`` output, i.e. ``64 * H' * W'``.
            It is NOT inferred from the input, so it must match the image size
            in use. Each 5x5 unpadded conv shrinks a side by 4 and each pool
            halves it (floor): 28x28 -> 24 -> 12 -> 8 -> 4 gives the default
            1024 (= 64 * 4 * 4); 32x32 -> 28 -> 14 -> 10 -> 5 needs 1600.
        hidden_dim: Width of the hidden fully connected layer ``fc1``.
            Defaults to 512, the value this model previously hard-coded.
    """

    def __init__(self, in_features=1, num_classes=10, dim=1024, *, hidden_dim=512):
        super().__init__()
        # Construction order (conv1, conv2, fc1, fc) and the Sequential layouts
        # are kept fixed: they determine both parameter initialisation order
        # under a given seed and the state_dict keys (conv1.0.weight, ...).
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_features, 32, kernel_size=5, padding=0, stride=1, bias=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2)),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, padding=0, stride=1, bias=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2)),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(inplace=True),
        )
        # Final classification layer, registered under its own attribute name
        # ``fc``; renaming it would change the state_dict keys.
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``.

        ``x`` should be a 4-D batch ``(batch, in_features, H, W)`` whose spatial
        size yields a flattened conv output of length ``dim``; otherwise the
        first linear layer raises a shape-mismatch error.
        """
        out = self.conv1(x)
        out = self.conv2(out)
        out = torch.flatten(out, 1)  # keep the batch axis, flatten (C, H', W')
        out = self.fc1(out)
        return self.fc(out)
|