generate_office_home.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # -*- coding: utf-8 -*-
  2. # @Author: Jun Luo
  3. # @Date: 2022-02-28 10:40:03
  4. # @Last Modified by: Jun Luo
  5. # @Last Modified time: 2022-02-28 10:40:03
  6. import numpy as np
  7. from sklearn.model_selection import train_test_split
  8. import os
  9. import sys
  10. import glob
  11. import torch
  12. import torchvision
  13. import torchvision.transforms as transforms
  14. from PIL import Image
  15. import shutil
  16. from torch.utils.data import Dataset, DataLoader
# Concentration parameter handed to dirichletSplit; every client shares the same value.
ALPHA = 1.0
# Total number of clients; the script requires it to be a multiple of 4 (one group per style).
N_CLIENTS = 20
# Fraction of each client's samples passed as test_size to train_test_split.
TEST_PORTION = 0.15
# Seed for np.random and for train_test_split's random_state.
SEED = 42
# SET_THRESHOLD / TEST_PORTION is the minimum expected per-client sample count
# a partition must reach before it is accepted by the partitioning loop.
SET_THRESHOLD = 20
# Images are resized to IMAGE_SIZE x IMAGE_SIZE before being stored.
IMAGE_SIZE = 32
# Number of class indices iterated when collecting files and drawing partitions.
N_CLASSES = 65
# Root folder that holds one sub-folder per style.
IMAGE_SRC = "./Office-home-raw/"
# Output folder; "train/" and "test/" sub-folders are created inside it.
SAVE_FOLDER = f"./Office-home{N_CLIENTS}/"
  26. class ImageDatasetFromFileNames(Dataset):
  27. def __init__(self, fns, labels, transform=None, target_transform=None):
  28. self.fns = fns
  29. self.labels = labels
  30. self.transform = transform
  31. self.target_transform = target_transform
  32. def __getitem__(self, index):
  33. x = Image.open(self.fns[index])
  34. y = self.labels[index]
  35. if self.transform is not None:
  36. x = self.transform(x)
  37. if self.target_transform is not None:
  38. y = self.target_transform(y)
  39. return x, y
  40. def __len__(self):
  41. return len(self.labels)
  42. def dirichletSplit(alpha=10, n_clients=10, n_classes=10):
  43. return np.random.dirichlet(n_clients * [alpha], n_classes)
  44. def isNegligible(partitions, counts, THRESHOLD=2):
  45. s = np.matmul(partitions.T, counts)
  46. return (s < THRESHOLD).any()
  47. def split2clientsofficehome(x_fns, ys, stats, partitions, client_idx_offset=0, verbose=False):
  48. print("==> splitting dataset into clients' own datasets")
  49. n_classes, n_clients = partitions.shape
  50. splits = [] # n_classes * n_clients
  51. for i in range(n_classes):
  52. indices = np.where(ys == i)[0]
  53. np.random.shuffle(indices)
  54. cuts = np.cumsum(np.round_(partitions[i] * stats[str(i)]).astype(int))
  55. cuts = np.clip(cuts, 0, stats[str(i)])
  56. cuts[-1] = stats[str(i)]
  57. splits.append(np.split(indices, cuts))
  58. clients = []
  59. for i in range(n_clients):
  60. indices = np.concatenate([splits[j][i] for j in range(n_classes)], axis=0)
  61. dset = [x_fns[indices], ys[indices]]
  62. clients.append(dset)
  63. if verbose:
  64. print("\tclient %03d has" % (client_idx_offset+i+1), len(dset[0]), "images")
  65. return clients
  66. def get_immediate_subdirectories(a_dir):
  67. return [name for name in os.listdir(a_dir)
  68. if os.path.isdir(os.path.join(a_dir, name))]
  69. if __name__ == "__main__":
  70. np.random.seed(SEED)
  71. styles = ["Art", "Clipart", "Product", "Real World"]
  72. assert N_CLIENTS % 4 == 0, "### For Office-Home dataset, N_CLIENTS must be a multiple of 4...\nPlease change N_CLIENTS..."
  73. N_CLIENTS_PER_STYLE = N_CLIENTS // len(styles)
  74. cls_names = []
  75. for fn in get_immediate_subdirectories(IMAGE_SRC + styles[0]):
  76. cls_names.append(os.path.split(fn)[1])
  77. idx2clsname = {i: name for i, name in enumerate(cls_names)}
  78. get_cls_folder = lambda style, cls_n: os.path.join(IMAGE_SRC, style, cls_n)
  79. def get_dataset(dir, style):
  80. x_fns = []
  81. ys = []
  82. stats_dict = {}
  83. stats_list = []
  84. for i in range(N_CLASSES):
  85. cls_name = idx2clsname[i]
  86. x_for_cls = list(glob.glob(os.path.join(dir, style, cls_name, "*.jpg")))
  87. x_fns += x_for_cls
  88. ys += [i for _ in range(len(x_for_cls))]
  89. stats_dict[str(i)] = len(x_for_cls)
  90. stats_list.append(len(x_for_cls))
  91. return np.array(x_fns), np.array(ys), stats_dict, np.array(stats_list)
  92. clients = []
  93. for style_idx, style in enumerate(styles):
  94. dataset_style_fns, dataset_style_labels, dataset_stats_dict, dataset_stats_list = get_dataset(IMAGE_SRC, style)
  95. # print(len(dataset_style_fns), len(dataset_style_labels), np.sum(list(dataset_stats.values())))
  96. partitions = np.zeros((N_CLASSES, N_CLIENTS_PER_STYLE))
  97. i = 0
  98. while isNegligible(partitions, dataset_stats_list, SET_THRESHOLD/TEST_PORTION):
  99. partitions = dirichletSplit(alpha=ALPHA, n_clients=N_CLIENTS_PER_STYLE, n_classes=N_CLASSES)
  100. i += 1
  101. print(f"==> partitioning for the {i}th time (client dataset size >= {SET_THRESHOLD})")
  102. clients += split2clientsofficehome(dataset_style_fns,
  103. dataset_style_labels,
  104. dataset_stats_dict,
  105. partitions,
  106. client_idx_offset=style_idx*N_CLIENTS_PER_STYLE,
  107. verbose=True)
  108. # print()
  109. # print(np.sum([len(c[0]) for c in clients]))
  110. transform = transforms.Compose(
  111. [transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  112. if not os.path.exists(f"{SAVE_FOLDER}train/"):
  113. os.makedirs(f"{SAVE_FOLDER}train/")
  114. if not os.path.exists(f"{SAVE_FOLDER}test/"):
  115. os.makedirs(f"{SAVE_FOLDER}test/")
  116. for client_idx, (clt_x_fns, clt_ys) in enumerate(clients):
  117. print("==> saving (to %s) for client [%3d/%3d]" % (SAVE_FOLDER, client_idx+1, N_CLIENTS))
  118. # split train, val, test
  119. try:
  120. X_train_fns, X_test_fns, y_train, y_test = train_test_split(
  121. clt_x_fns, clt_ys, test_size=TEST_PORTION, random_state=SEED, stratify=clt_ys)
  122. except ValueError:
  123. X_train_fns, X_test_fns, y_train, y_test = train_test_split(
  124. clt_x_fns, clt_ys, test_size=TEST_PORTION, random_state=SEED)
  125. trainset = ImageDatasetFromFileNames(X_train_fns, y_train, transform=transform)
  126. testset = ImageDatasetFromFileNames(X_test_fns, y_test, transform=transform)
  127. trainloader = torch.utils.data.DataLoader(
  128. trainset, batch_size=len(trainset), shuffle=False)
  129. testloader = torch.utils.data.DataLoader(
  130. testset, batch_size=len(testset), shuffle=False)
  131. xs_train, ys_train = next(iter(trainloader))
  132. xs_test, ys_test = next(iter(testloader))
  133. train_dict = {"x": xs_train.numpy(), "y": ys_train.numpy()}
  134. test_dict = {"x": xs_test.numpy(), "y": ys_test.numpy()}
  135. # save
  136. for data_dict, npz_fn in [(train_dict, SAVE_FOLDER+f"train/{client_idx}.npz"), (test_dict, SAVE_FOLDER+f"test/{client_idx}.npz")]:
  137. with open(npz_fn, "wb") as f:
  138. np.savez_compressed(f, data=data_dict)
  139. print("\n==> finished saving all npz images.")