dataset_utils.py

import os
import ujson
import numpy as np
import gc
from sklearn.model_selection import train_test_split

batch_size = 10
train_size = 0.75  # merge original training set and test set, then split it manually.
least_samples = batch_size / (1 - train_size)  # least samples for each client (10 / 0.25 = 40 with the defaults above)
alpha = 0.3  # for Dirichlet distribution
def check(config_path, train_path, test_path, num_clients, num_classes, niid=False,
          balance=True, partition=None):
    # check existing dataset
    if os.path.exists(config_path):
        with open(config_path, 'r') as f:
            config = ujson.load(f)
        if config['num_clients'] == num_clients and \
                config['num_classes'] == num_classes and \
                config['non_iid'] == niid and \
                config['balance'] == balance and \
                config['partition'] == partition and \
                config['alpha'] == alpha and \
                config['batch_size'] == batch_size:
            print("\nDataset already generated.\n")
            return True

    dir_path = os.path.dirname(train_path)
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    dir_path = os.path.dirname(test_path)
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    return False
def separate_data(data, num_clients, num_classes, niid=False, balance=False, partition=None, class_per_client=2):
    X = [[] for _ in range(num_clients)]
    y = [[] for _ in range(num_clients)]
    statistic = [[] for _ in range(num_clients)]

    dataset_content, dataset_label = data

    dataidx_map = {}

    if not niid:
        # IID case: every client receives samples from every class
        partition = 'pat'
        class_per_client = num_classes

    if partition == 'pat':
        # pathological partition: each client holds at most `class_per_client` classes
        idxs = np.array(range(len(dataset_label)))
        idx_for_each_class = []
        for i in range(num_classes):
            idx_for_each_class.append(idxs[dataset_label == i])

        class_num_per_client = [class_per_client for _ in range(num_clients)]
        for i in range(num_classes):
            selected_clients = []
            for client in range(num_clients):
                if class_num_per_client[client] > 0:
                    selected_clients.append(client)
            selected_clients = selected_clients[:int(num_clients / num_classes * class_per_client)]

            num_all_samples = len(idx_for_each_class[i])
            num_selected_clients = len(selected_clients)
            num_per = num_all_samples / num_selected_clients
            if balance:
                num_samples = [int(num_per) for _ in range(num_selected_clients - 1)]
            else:
                num_samples = np.random.randint(max(num_per / 10, least_samples / num_classes),
                                                num_per, num_selected_clients - 1).tolist()
            num_samples.append(num_all_samples - sum(num_samples))

            idx = 0
            for client, num_sample in zip(selected_clients, num_samples):
                if client not in dataidx_map.keys():
                    dataidx_map[client] = idx_for_each_class[i][idx:idx + num_sample]
                else:
                    dataidx_map[client] = np.append(dataidx_map[client],
                                                    idx_for_each_class[i][idx:idx + num_sample], axis=0)
                idx += num_sample
                class_num_per_client[client] -= 1

    elif partition == "dir":
        # Dirichlet partition, adapted from
        # https://github.com/IBM/probabilistic-federated-neural-matching/blob/master/experiment.py
        min_size = 0
        K = num_classes
        N = len(dataset_label)

        while min_size < least_samples:
            idx_batch = [[] for _ in range(num_clients)]
            for k in range(K):
                idx_k = np.where(dataset_label == k)[0]
                np.random.shuffle(idx_k)
                proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
                # zero out clients that already hold more than N/num_clients samples
                proportions = np.array([p * (len(idx_j) < N / num_clients) for p, idx_j in zip(proportions, idx_batch)])
                proportions = proportions / proportions.sum()
                proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
                min_size = min([len(idx_j) for idx_j in idx_batch])

        for j in range(num_clients):
            dataidx_map[j] = idx_batch[j]
    else:
        raise NotImplementedError

    # assign data
    for client in range(num_clients):
        idxs = dataidx_map[client]
        X[client] = dataset_content[idxs]
        y[client] = dataset_label[idxs]

        for i in np.unique(y[client]):
            statistic[client].append((int(i), int(sum(y[client] == i))))

    del data
    # gc.collect()

    for client in range(num_clients):
        print(f"Client {client}\t Size of data: {len(X[client])}\t Labels: ", np.unique(y[client]))
        print(f"\t\t Samples of labels: ", [i for i in statistic[client]])
        print("-" * 50)

    return X, y, statistic
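
# Worked example (illustrative only, with made-up Dirichlet draws) of how the
# "dir" branch above turns proportions into index splits for one class:
#   alpha = 0.3, num_clients = 4, len(idx_k) = 100
#   proportions drawn by np.random.dirichlet -> [0.55, 0.05, 0.30, 0.10]
#   np.cumsum(proportions) * len(idx_k)      -> [55, 60, 90, 100]
#   .astype(int)[:-1]                        -> [55, 60, 90]   (split points)
#   np.split(idx_k, [55, 60, 90])            -> chunks of 55, 5, 30 and 10 indices,
#                                               appended to clients 0..3 in idx_batch
# The surrounding while loop redraws all proportions until the smallest client
# ends up with at least `least_samples` indices.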
def split_data(X, y):
    # Split each client's data into a local train set and a local test set
    train_data, test_data = [], []
    num_samples = {'train': [], 'test': []}

    for i in range(len(y)):
        X_train, X_test, y_train, y_test = train_test_split(
            X[i], y[i], train_size=train_size, shuffle=True)

        train_data.append({'x': X_train, 'y': y_train})
        num_samples['train'].append(len(y_train))
        test_data.append({'x': X_test, 'y': y_test})
        num_samples['test'].append(len(y_test))

    print("Total number of samples:", sum(num_samples['train'] + num_samples['test']))
    print("The number of train samples:", num_samples['train'])
    print("The number of test samples:", num_samples['test'])
    print()

    del X, y
    # gc.collect()

    return train_data, test_data
def save_file(config_path, train_path, test_path, train_data, test_data, num_clients,
              num_classes, statistic, niid=False, balance=True, partition=None):
    config = {
        'num_clients': num_clients,
        'num_classes': num_classes,
        'non_iid': niid,
        'balance': balance,
        'partition': partition,
        'Size of samples for labels in clients': statistic,
        'alpha': alpha,
        'batch_size': batch_size,
    }

    # gc.collect()
    print("Saving to disk.\n")

    for idx, train_dict in enumerate(train_data):
        with open(train_path + str(idx) + '.npz', 'wb') as f:
            np.savez_compressed(f, data=train_dict)
    for idx, test_dict in enumerate(test_data):
        with open(test_path + str(idx) + '.npz', 'wb') as f:
            np.savez_compressed(f, data=test_dict)
    with open(config_path, 'w') as f:
        ujson.dump(config, f)

    print("Finish generating dataset.\n")