import torch
import torch.nn as nn
import torch.nn.functional as F

# Module-level batch size; not referenced by the classes below.
# NOTE(review): presumably consumed by training code elsewhere — confirm.
batch_size = 16


class LocalModel(nn.Module):
    """Compose a feature extractor and a prediction head.

    ``forward(x)`` returns ``predictor(base(x))``.

    Args:
        base: Module applied first to the input.
        predictor: Module applied to the output of ``base``.
    """

    def __init__(self, base, predictor):
        super().__init__()
        self.base = base
        self.predictor = predictor

    def forward(self, x):
        """Run ``x`` through ``base`` and then ``predictor``."""
        out = self.base(x)
        out = self.predictor(out)
        return out


# https://github.com/katsura-jp/fedavg.pytorch/blob/master/src/models/cnn.py
class FedAvgCNN(nn.Module):
    """Two-layer CNN classifier.

    Architecture:
        conv(5x5, 32) -> ReLU -> maxpool(2)
        conv(5x5, 64) -> ReLU -> maxpool(2)
        flatten -> linear(dim, 512) -> ReLU -> linear(512, num_classes)

    Args:
        in_features: Number of input channels.
        num_classes: Number of output logits.
        dim: Flattened size after the second conv block. It must equal
            ``64 * H' * W'``, where ``H'`` and ``W'`` are the spatial sizes
            after both conv/pool stages. Each stage maps a side ``s`` to
            ``(s - 4) // 2`` (5x5 conv, no padding, then a 2x2 pool). For
            28x28 inputs this gives 4x4, so ``dim = 64 * 16 = 1024``, the
            default. For 32x32 inputs it gives 5x5, so ``dim = 1600``.
    """

    def __init__(self, in_features=1, num_classes=10, dim=1024):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_features, 32, kernel_size=5, padding=0, stride=1, bias=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2)),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, padding=0, stride=1, bias=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2)),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(dim, 512),
            nn.ReLU(inplace=True),
        )
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)``.

        ``x`` is expected to be a 4-D ``(N, in_features, H, W)`` batch whose
        spatial size is consistent with ``dim`` (see the class docstring).
        """
        out = self.conv1(x)
        out = self.conv2(out)
        # Keep the batch dimension; collapse channels and spatial dims.
        out = torch.flatten(out, 1)
        out = self.fc1(out)
        out = self.fc(out)
        return out