# data_utils.py
import ujson
import numpy as np
import os
import torch
# IMAGE_SIZE = 28
# IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
# NUM_CHANNELS = 1
# IMAGE_SIZE_CIFAR = 32
# NUM_CHANNELS_CIFAR = 3
  10. def batch_data(data, batch_size):
  11. '''
  12. data is a dict := {'x': [numpy array], 'y': [numpy array]} (on one client)
  13. returns x, y, which are both numpy array of length: batch_size
  14. '''
  15. data_x = data['x']
  16. data_y = data['y']
  17. # randomly shuffle data
  18. ran_state = np.random.get_state()
  19. np.random.shuffle(data_x)
  20. np.random.set_state(ran_state)
  21. np.random.shuffle(data_y)
  22. # loop through mini-batches
  23. for i in range(0, len(data_x), batch_size):
  24. batched_x = data_x[i:i+batch_size]
  25. batched_y = data_y[i:i+batch_size]
  26. yield (batched_x, batched_y)
  27. def get_random_batch_sample(data_x, data_y, batch_size):
  28. num_parts = len(data_x)//batch_size + 1
  29. if(len(data_x) > batch_size):
  30. batch_idx = np.random.choice(list(range(num_parts + 1)))
  31. sample_index = batch_idx*batch_size
  32. if(sample_index + batch_size > len(data_x)):
  33. return (data_x[sample_index:], data_y[sample_index:])
  34. else:
  35. return (data_x[sample_index: sample_index+batch_size], data_y[sample_index: sample_index+batch_size])
  36. else:
  37. return (data_x, data_y)
  38. def get_batch_sample(data, batch_size):
  39. data_x = data['x']
  40. data_y = data['y']
  41. # np.random.seed(100)
  42. ran_state = np.random.get_state()
  43. np.random.shuffle(data_x)
  44. np.random.set_state(ran_state)
  45. np.random.shuffle(data_y)
  46. batched_x = data_x[0:batch_size]
  47. batched_y = data_y[0:batch_size]
  48. return (batched_x, batched_y)
  49. def read_data(dataset, idx, is_train=True):
  50. if is_train:
  51. train_data_dir = os.path.join('../dataset', dataset, 'train/')
  52. train_file = train_data_dir + str(idx) + '.npz'
  53. with open(train_file, 'rb') as f:
  54. train_data = np.load(f, allow_pickle=True)['data'].tolist()
  55. return train_data
  56. else:
  57. test_data_dir = os.path.join('../dataset', dataset, 'test/')
  58. test_file = test_data_dir + str(idx) + '.npz'
  59. with open(test_file, 'rb') as f:
  60. test_data = np.load(f, allow_pickle=True)['data'].tolist()
  61. return test_data
  62. def read_client_data(dataset, idx, is_train=True):
  63. if is_train:
  64. train_data = read_data(dataset, idx, is_train)
  65. X_train = torch.Tensor(train_data['x']).type(torch.float32)
  66. y_train = torch.Tensor(train_data['y']).type(torch.int64)
  67. train_data = [(x, y) for x, y in zip(X_train, y_train)]
  68. return train_data
  69. else:
  70. test_data = read_data(dataset, idx, is_train)
  71. X_test = torch.Tensor(test_data['x']).type(torch.float32)
  72. y_test = torch.Tensor(test_data['y']).type(torch.int64)
  73. test_data = [(x, y) for x, y in zip(X_test, y_test)]
  74. return test_data