# generate_medmnist.py (3.2 KB)
  1. import numpy as np
  2. import os
  3. import sys
  4. import random
  5. import torch
  6. import torchvision
  7. import torchvision.transforms as transforms
  8. from utils.dataset_utils import check, separate_data, split_data, save_file
  9. from torchvision.datasets import ImageFolder, DatasetFolder
  10. # medmnist
  11. import medmnist
  12. from medmnist import INFO
  13. random.seed(1)
  14. np.random.seed(1)
  15. num_clients = 25
  16. dir_path = f"organamnist{num_clients}/"
  17. # medmnist
  18. data_flag = ("".join([i for i in dir_path if i.isalpha()])).lower()
  19. # data_flag = 'breastmnist'
  20. download = False
  21. info = INFO[data_flag]
  22. task = info['task']
  23. n_channels = info['n_channels']
  24. num_classes = len(info['label'])
  25. DataClass = getattr(medmnist, info['python_class'])
  26. # Allocate data to users
  27. def generate_dataset(dir_path, num_clients, num_classes, niid, balance, partition):
  28. if not os.path.exists(dir_path):
  29. os.makedirs(dir_path)
  30. # Setup directory for train/test data
  31. config_path = dir_path + "config.json"
  32. train_path = dir_path + "train/"
  33. test_path = dir_path + "test/"
  34. if check(config_path, train_path, test_path, num_clients, num_classes, niid, balance, partition):
  35. return
  36. # Get data
  37. transform = transforms.Compose(
  38. [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
  39. trainset = DataClass(split='train', transform=transform, download=download)
  40. valset = DataClass(split='val', transform=transform, download=download)
  41. testset = DataClass( split='test', transform=transform, download=download)
  42. trainloader = torch.utils.data.DataLoader(
  43. trainset, batch_size=len(trainset), shuffle=True)
  44. valloader = torch.utils.data.DataLoader(
  45. valset, batch_size=len(valset), shuffle=True)
  46. testloader = torch.utils.data.DataLoader(
  47. testset, batch_size=len(testset), shuffle=True)
  48. for _, train_data in enumerate(trainloader, 0):
  49. trainset.data, trainset.targets = train_data
  50. for _, val_data in enumerate(valloader, 0):
  51. valset.data, valset.targets = val_data
  52. for _, test_data in enumerate(testloader, 0):
  53. testset.data, testset.targets = test_data
  54. dataset_image = []
  55. dataset_label = []
  56. dataset_image.extend(trainset.data.cpu().detach().numpy())
  57. dataset_image.extend(valset.data.cpu().detach().numpy())
  58. dataset_image.extend(testset.data.cpu().detach().numpy())
  59. dataset_label.extend(trainset.targets.cpu().detach().numpy())
  60. dataset_label.extend(valset.targets.cpu().detach().numpy())
  61. dataset_label.extend(testset.targets.cpu().detach().numpy())
  62. dataset_image = np.array(dataset_image)
  63. dataset_label = np.array(dataset_label).reshape(-1)
  64. X, y, statistic = separate_data((dataset_image, dataset_label), num_clients, num_classes,
  65. niid, balance, partition)
  66. train_data, test_data = split_data(X, y)
  67. save_file(config_path, train_path, test_path, train_data, test_data, num_clients, num_classes,
  68. statistic, niid, balance, partition)
  69. if __name__ == "__main__":
  70. niid = True if sys.argv[1] == "noniid" else False
  71. balance = True if sys.argv[2] == "balance" else False
  72. partition = sys.argv[3] if sys.argv[3] != "-" else None
  73. generate_dataset(dir_path, num_clients, num_classes, niid, balance, partition)