dataset.py

import logging
import os

import torchvision
import torchvision.transforms as transforms

from easyfl.datasets import FederatedTensorDataset
from easyfl.datasets.data import CIFAR100
from easyfl.datasets.simulation import data_simulation
from easyfl.datasets.utils.util import save_dict, load_dict

from utils import get_transformation

logger = logging.getLogger(__name__)
def semi_supervised_preprocess(dataset, num_of_client, split_type, weights, alpha, min_size, class_per_client,
                               label_ratio=0.01):
    """Split a CIFAR dataset into a labeled subset, simulated unlabeled client data, and a test set.

    The splits are cached under ./data/<dataset>/<setting> so repeated runs with the
    same configuration skip the download and simulation steps.
    """
    setting = f"{dataset}_{split_type}_{num_of_client}_{min_size}_{class_per_client}_{alpha}_{0}_{label_ratio}"
    data_path = f"./data/{dataset}"
    data_folder = os.path.join(data_path, setting)
    if not os.path.exists(data_folder):
        os.makedirs(data_folder)
    train_path = os.path.join(data_folder, "train")
    test_path = os.path.join(data_folder, "test")
    labeled_path = os.path.join(data_folder, "labeled")

    # Reuse previously simulated splits if they are already cached on disk.
    if os.path.exists(train_path):
        print("Load existing data")
        return load_dict(train_path), load_dict(test_path), load_dict(labeled_path)

    # Download the raw dataset: CIFAR-100 if requested, otherwise CIFAR-10.
    if dataset == CIFAR100:
        train_set = torchvision.datasets.CIFAR100(root=data_path, train=True, download=True)
        test_set = torchvision.datasets.CIFAR100(root=data_path, train=False, download=True)
    else:
        train_set = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True)
        test_set = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True)

    # Keep the first `label_ratio` fraction of the training images as the labeled subset;
    # the remainder becomes the unlabeled pool that is distributed across clients.
    train_size = len(train_set.data)
    label_size = int(train_size * label_ratio)
    labeled_data = {
        'x': train_set.data[:label_size],
        'y': train_set.targets[:label_size],
    }
    train_data = {
        'x': train_set.data[label_size:],
        'y': train_set.targets[label_size:],
    }
    test_data = {
        'x': test_set.data,
        'y': test_set.targets,
    }

    # Partition the unlabeled pool across `num_of_client` clients according to `split_type`.
    print(f"{dataset} data simulation begins")
    _, train_data = data_simulation(train_data['x'],
                                    train_data['y'],
                                    num_of_client,
                                    split_type,
                                    weights,
                                    alpha,
                                    min_size,
                                    class_per_client)
    print(f"{dataset} data simulation is done")

    # Cache all three splits for later runs.
    save_dict(train_data, train_path)
    save_dict(test_data, test_path)
    save_dict(labeled_data, labeled_path)

    return train_data, test_data, labeled_data
def get_semi_supervised_dataset(dataset, num_of_client, split_type, class_per_client, label_ratio=0.01, image_size=32,
                                gaussian=False):
    """Wrap the preprocessed splits in EasyFL FederatedTensorDataset objects."""
    train_data, test_data, labeled_data = semi_supervised_preprocess(dataset, num_of_client, split_type, None, 0.5, 10,
                                                                     class_per_client, label_ratio)

    # Plain resize + tensor conversion used when fine-tuning on the small labeled subset.
    fine_tune_transform = transforms.Compose([
        torchvision.transforms.ToPILImage(mode='RGB'),
        torchvision.transforms.Resize(size=image_size),
        torchvision.transforms.ToTensor(),
    ])

    # Unlabeled client data uses the BYOL-style augmentation pipeline.
    train_data = FederatedTensorDataset(train_data,
                                        simulated=True,
                                        do_simulate=False,
                                        process_x=None,
                                        process_y=None,
                                        transform=get_transformation("byol")(image_size, gaussian))
    # The test set uses the matching evaluation transform of the same pipeline.
    test_data = FederatedTensorDataset(test_data,
                                       simulated=False,
                                       do_simulate=False,
                                       process_x=None,
                                       process_y=None,
                                       transform=get_transformation("byol")(image_size, gaussian).test_transform)
    labeled_data = FederatedTensorDataset(labeled_data,
                                          simulated=False,
                                          do_simulate=False,
                                          process_x=None,
                                          process_y=None,
                                          transform=fine_tune_transform)
    return train_data, test_data, labeled_data
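

# A minimal usage sketch of the helpers above, assuming EasyFL's lowercase string
# dataset identifiers (e.g. "cifar10") and a Dirichlet split type ("dir") accepted by
# data_simulation; these specific values are assumptions, not confirmed by this file.
if __name__ == "__main__":
    train, test, labeled = get_semi_supervised_dataset(dataset="cifar10",  # assumed identifier
                                                       num_of_client=5,
                                                       split_type="dir",   # assumed split type
                                                       class_per_client=2,
                                                       label_ratio=0.01,
                                                       image_size=32,
                                                       gaussian=False)
    print(type(train), type(test), type(labeled))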