# generate_cifar10.py

  1. import numpy as np
  2. import os
  3. import sys
  4. import random
  5. import torch
  6. import torchvision
  7. import torchvision.transforms as transforms
  8. from utils.dataset_utils import check, separate_data, split_data, save_file
  9. random.seed(1)
  10. np.random.seed(1)
  11. num_clients = 25
  12. num_classes = 10
  13. dir_path = "cifar10/"
  14. # Allocate data to users
  15. def generate_cifar10(dir_path, num_clients, num_classes, niid, balance, partition):
  16. if not os.path.exists(dir_path):
  17. os.makedirs(dir_path)
  18. # Setup directory for train/test data
  19. config_path = dir_path + "config.json"
  20. train_path = dir_path + "train/"
  21. test_path = dir_path + "test/"
  22. if check(config_path, train_path, test_path, num_clients, num_classes, niid, balance, partition):
  23. return
  24. # Get Cifar10 data
  25. transform = transforms.Compose(
  26. [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  27. trainset = torchvision.datasets.CIFAR10(
  28. root=dir_path+"rawdata", train=True, download=True, transform=transform)
  29. testset = torchvision.datasets.CIFAR10(
  30. root=dir_path+"rawdata", train=False, download=True, transform=transform)
  31. trainloader = torch.utils.data.DataLoader(
  32. trainset, batch_size=len(trainset.data), shuffle=False)
  33. testloader = torch.utils.data.DataLoader(
  34. testset, batch_size=len(testset.data), shuffle=False)
  35. for _, train_data in enumerate(trainloader, 0):
  36. trainset.data, trainset.targets = train_data
  37. for _, test_data in enumerate(testloader, 0):
  38. testset.data, testset.targets = test_data
  39. dataset_image = []
  40. dataset_label = []
  41. dataset_image.extend(trainset.data.cpu().detach().numpy())
  42. dataset_image.extend(testset.data.cpu().detach().numpy())
  43. dataset_label.extend(trainset.targets.cpu().detach().numpy())
  44. dataset_label.extend(testset.targets.cpu().detach().numpy())
  45. dataset_image = np.array(dataset_image)
  46. dataset_label = np.array(dataset_label)
  47. X, y, statistic = separate_data((dataset_image, dataset_label), num_clients, num_classes,
  48. niid, balance, partition)
  49. train_data, test_data = split_data(X, y)
  50. save_file(config_path, train_path, test_path, train_data, test_data, num_clients, num_classes,
  51. statistic, niid, balance, partition)
  52. if __name__ == "__main__":
  53. niid = True if sys.argv[1] == "noniid" else False
  54. balance = True if sys.argv[2] == "balance" else False
  55. partition = sys.argv[3] if sys.argv[3] != "-" else None
  56. generate_cifar10(dir_path, num_clients, num_classes, niid, balance, partition)