1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- import ujson
- import numpy as np
- import os
- import torch
- # IMAGE_SIZE = 28
- # IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
- # NUM_CHANNELS = 1
- # IMAGE_SIZE_CIFAR = 32
- # NUM_CHANNELS_CIFAR = 3
def batch_data(data, batch_size):
    """Yield shuffled mini-batches from one client's data.

    Args:
        data: dict with keys ``'x'`` (inputs) and ``'y'`` (labels). Each value
            must support ``len``, slicing and ``np.random.shuffle`` (e.g. a
            numpy array or a list).
        batch_size: maximum number of samples per yielded batch.

    Yields:
        ``(batched_x, batched_y)`` tuples of at most ``batch_size`` samples;
        the final batch holds the remainder. Inputs and labels are shuffled
        with the same permutation, so (x, y) pairs stay aligned.

    Raises:
        ValueError: if ``data['x']`` and ``data['y']`` differ in length
            (raised on the first ``next()``, since this is a generator).

    The caller's ``data['x']`` / ``data['y']`` are not modified: the shuffle is
    applied to shallow copies. The global ``np.random`` state is advanced.
    """
    import copy

    # copy.copy gives a real data copy for ndarrays and a new list for lists,
    # so the in-place shuffle below never touches the caller's containers.
    data_x = copy.copy(data['x'])
    data_y = copy.copy(data['y'])
    if len(data_x) != len(data_y):
        raise ValueError(
            f"data['x'] and data['y'] must have the same length, "
            f"got {len(data_x)} and {len(data_y)}"
        )

    # Shuffle x, then rewind the global RNG and shuffle y: equal-length
    # sequences receive the identical permutation.
    rng_state = np.random.get_state()
    np.random.shuffle(data_x)
    np.random.set_state(rng_state)
    np.random.shuffle(data_y)

    for start in range(0, len(data_x), batch_size):
        end = start + batch_size
        yield (data_x[start:end], data_y[start:end])
def get_random_batch_sample(data_x, data_y, batch_size):
    """Return one randomly chosen contiguous batch (no shuffling).

    The data is viewed as consecutive chunks of ``batch_size`` samples (the
    last chunk may be shorter). One chunk is picked uniformly at random using
    the global ``np.random`` state, and the matching slices of ``data_x`` and
    ``data_y`` are returned.

    Args:
        data_x: sliceable inputs.
        data_y: sliceable labels, aligned with ``data_x``.
        batch_size: chunk size.

    Returns:
        ``(batch_x, batch_y)``. If ``len(data_x) <= batch_size``, the inputs
        themselves are returned unchanged. The returned batch is never empty
        for non-empty input.
    """
    num_samples = len(data_x)
    if num_samples <= batch_size:
        return (data_x, data_y)

    # Ceil-divide so every chunk index maps to a start strictly inside the
    # data; the previous range allowed starts at/after the end (empty batch).
    num_batches = (num_samples + batch_size - 1) // batch_size
    batch_idx = np.random.choice(num_batches)
    start = batch_idx * batch_size
    end = start + batch_size  # slicing clamps the final, shorter chunk
    return (data_x[start:end], data_y[start:end])
def get_batch_sample(data, batch_size):
    """Return a random sample of up to ``batch_size`` aligned (x, y) items.

    Args:
        data: dict with keys ``'x'`` (inputs) and ``'y'`` (labels). Each value
            must support ``len``, slicing and ``np.random.shuffle``.
        batch_size: maximum number of samples to return.

    Returns:
        ``(batched_x, batched_y)``: the first ``batch_size`` entries after
        shuffling x and y with the same permutation.

    Raises:
        ValueError: if ``data['x']`` and ``data['y']`` differ in length.

    The caller's ``data['x']`` / ``data['y']`` are not modified: the shuffle is
    applied to shallow copies. The global ``np.random`` state is advanced.
    """
    import copy

    # Shuffle copies so the caller's arrays/lists keep their original order.
    data_x = copy.copy(data['x'])
    data_y = copy.copy(data['y'])
    if len(data_x) != len(data_y):
        raise ValueError(
            f"data['x'] and data['y'] must have the same length, "
            f"got {len(data_x)} and {len(data_y)}"
        )

    # Rewinding the global RNG between the two shuffles gives both
    # equal-length sequences the same permutation.
    rng_state = np.random.get_state()
    np.random.shuffle(data_x)
    np.random.set_state(rng_state)
    np.random.shuffle(data_y)

    return (data_x[:batch_size], data_y[:batch_size])
def read_data(dataset, idx, is_train=True, data_root='../dataset'):
    """Load one client's split from ``<data_root>/<dataset>/<split>/<idx>.npz``.

    Args:
        dataset: dataset directory name under ``data_root``.
        idx: client index; the file name is ``f"{idx}.npz"``.
        is_train: load the ``train`` split if truthy, otherwise ``test``.
        data_root: directory containing the datasets. Defaults to
            ``'../dataset'`` (resolved against the current working directory),
            which was previously hard-coded.

    Returns:
        The object stored under the archive's ``'data'`` key, converted with
        ``.tolist()``. It is presumably a dict with ``'x'`` and ``'y'`` entries
        (that is how ``read_client_data`` indexes it) — confirm against the
        dataset generator.

    Raises:
        FileNotFoundError: if the ``.npz`` file does not exist.
    """
    split = 'train' if is_train else 'test'
    data_file = os.path.join(data_root, dataset, split, f"{idx}.npz")
    # SECURITY: allow_pickle=True unpickles the stored object, which can execute
    # arbitrary code; only load dataset files from a trusted source.
    with open(data_file, 'rb') as f:
        return np.load(f, allow_pickle=True)['data'].tolist()
def read_client_data(dataset, idx, is_train=True):
    """Load one client's split as a list of ``(x, y)`` tensor pairs.

    Args:
        dataset: dataset name, forwarded to ``read_data``.
        idx: client index, forwarded to ``read_data``.
        is_train: load the train split if truthy, otherwise the test split.

    Returns:
        A list of ``(x, y)`` tuples: ``x`` is one row of ``data['x']`` as a
        float32 tensor and ``y`` the matching entry of ``data['y']`` as an
        int64 tensor.
    """
    data = read_data(dataset, idx, is_train)
    # Go through numpy first: building a tensor directly from a list of
    # ndarrays is very slow. Labels are cast straight to int64; the old
    # torch.Tensor(...).type(torch.int64) detoured through float32, which
    # cannot represent every integer above 2**24.
    x = torch.as_tensor(np.asarray(data['x']), dtype=torch.float32)
    y = torch.as_tensor(np.asarray(data['y']), dtype=torch.int64)
    return list(zip(x, y))
|