# dataset_registration.py
  1. import torch.nn.functional as F
  2. from torch import nn
  3. from torchvision import transforms
  4. import easyfl
  5. from easyfl.datasets import FederatedImageDataset
  6. from easyfl.models import BaseModel
  7. class TestModel(BaseModel):
  8. def __init__(self):
  9. super(TestModel, self).__init__()
  10. self.conv1 = nn.Conv2d(3, 32, 224, padding=(2, 2))
  11. self.conv2 = nn.Conv2d(32, 64, 5, padding=(2, 2))
  12. self.fc1 = nn.Linear(64, 128)
  13. self.fc2 = nn.Linear(128, 42)
  14. def forward(self, x):
  15. x = F.relu(self.conv1(x))
  16. x = F.max_pool2d(x, 2, 2)
  17. x = F.relu(self.conv2(x))
  18. x = F.max_pool2d(x, 2, 2)
  19. x = x.view(-1, 64)
  20. x = F.relu(self.fc1(x))
  21. x = self.fc2(x)
  22. return x
  23. default_transfrom = transforms.Compose([
  24. # images are of different size, reshape them to same size. Up to you to decide what size to use.
  25. transforms.Resize((224, 224)),
  26. transforms.ToTensor(), # convert image to torch tensor
  27. transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) # normalizing images helps improve convergence
# Root directory of the raw image dataset on disk.
# NOTE(review): machine-specific absolute path — parameterize (env var or
# CLI argument) before sharing/deploying.
root = "/Users/sg0014000170m/Downloads/shopee-product-detection-student/train/train/train"

# Training split wrapped as a federated dataset over 100 clients.
# NOTE(review): per the flag names, simulated=False + do_simulate=True
# presumably tells easyfl to partition the plain image folder across the
# clients itself — verify against the FederatedImageDataset docs.
train_data = FederatedImageDataset(root,
                                   simulated=False,
                                   do_simulate=True,
                                   transform=default_transfrom,
                                   num_of_clients=100)
# Test split: same root but do_simulate=False (presumably left
# unpartitioned).
# NOTE(review): train and test read the SAME directory — confirm this
# overlap is intended rather than pointing test at a held-out folder.
test_data = FederatedImageDataset(root,
                                  simulated=False,
                                  do_simulate=False,
                                  transform=default_transfrom,
                                  num_of_clients=100)

# Register the custom model class and the dataset pair with easyfl, then
# initialize and launch the federated training session with default config.
easyfl.register_model(TestModel)
easyfl.register_dataset(train_data, test_data)
easyfl.init()
easyfl.run()