- """
- These codes are adopted from LEAF with some modifications.
- Splits data into train and test sets.
- """
import json
import logging
import os
import random
import sys
import time
from collections import OrderedDict

from easyfl.datasets.utils.constants import SEED_FILES

logger = logging.getLogger(__name__)
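
# Each input JSON file (produced by the earlier preprocessing steps) follows
# the LEAF layout implied by the accesses below:
#   {
#       "users": [...],                                   # list of user ids
#       "hierarchies": [...],                             # optional, parallel to "users"
#       "num_samples": [...],                             # parallel to "users"
#       "user_data": {user_id: {"x": [...], "y": [...]}}
#   }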


def create_jsons_for(dir, setting_folder, user_files, which_set, max_users, include_hierarchy, subdir, arg_label):
    """Write the given users' data into a series of JSON files.

    Used in the split-by-user case; at most `max_users` users are written
    into each output file.
    """
    user_count = 0
    json_index = 0
    users = []
    if include_hierarchy:
        hierarchies = []
    else:
        hierarchies = None
    num_samples = []
    user_data = {}
    for (i, t) in enumerate(user_files):
        if include_hierarchy:
            (u, h, ns, f) = t
        else:
            (u, ns, f) = t
        file_dir = os.path.join(subdir, f)
        with open(file_dir, 'r') as inf:
            data = json.load(inf)

        users.append(u)
        if include_hierarchy:
            hierarchies.append(h)
        num_samples.append(ns)
        user_data[u] = data['user_data'][u]
        user_count += 1

        if (user_count == max_users) or (i == len(user_files) - 1):
            all_data = {}
            all_data['users'] = users
            if include_hierarchy:
                all_data['hierarchies'] = hierarchies
            all_data['num_samples'] = num_samples
            all_data['user_data'] = user_data
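
            # The block below derives the output file name from the source file
            # name: it keeps everything up to the 'data' token, swaps the index
            # that follows it for this chunk's json_index, and preserves any
            # trailing parameters. For a hypothetical input 'all_data_3_niid.json'
            # and json_index 0, this yields 'all_data_0_niid.json' before the
            # set name and arg_label are appended below.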
            data_i = f.find('data')
            num_i = data_i + 5
            num_to_end = f[num_i:]
            param_i = num_to_end.find('_')
            param_to_end = '.json'
            if param_i != -1:
                param_to_end = num_to_end[param_i:]
            nf = "{}_{}{}".format(f[:(num_i - 1)], json_index, param_to_end)
            file_name = '{}_{}_{}.json'.format(nf[:-5], which_set, arg_label)
            ouf_dir = os.path.join(dir, setting_folder, which_set, file_name)
            logger.info('writing {}'.format(file_name))
            with open(ouf_dir, 'w') as outfile:
                json.dump(all_data, outfile)

            user_count = 0
            json_index += 1
            users = []
            if include_hierarchy:
                hierarchies = []
            num_samples = []
            user_data = {}


def split_train_test(setting_folder, metafile, name, user, frac, seed):
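    """Split the dataset's JSON files into train and test sets.

    Args:
        setting_folder (str): name of the folder for this preprocessing setting.
        metafile (str): directory for metadata; when given, the split seed is
            recorded there so the split can be reproduced.
        name (str): dataset name, e.g. 'femnist'.
        user (bool): if True, split by user; otherwise split each user's samples.
        frac (float): fraction of users (or of each user's samples) for training.
        seed (int): random seed; a time-based seed is used when None or negative.
    """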
- logger.info("------------------------------")
- logger.info("generating training and test sets")
- parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
- dir = os.path.join(parent_path, name, 'data')
- subdir = os.path.join(dir, setting_folder, 'rem_user_data')
- files = []
- if os.path.exists(subdir):
- files = os.listdir(subdir)
- if len(files) == 0:
- subdir = os.path.join(dir, setting_folder, 'sampled_data')
- if os.path.exists(subdir):
- files = os.listdir(subdir)
- if len(files) == 0:
- subdir = os.path.join(dir, 'all_data')
- files = os.listdir(subdir)
- files = [f for f in files if f.endswith('.json')]

    rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
    rng = random.Random(rng_seed)
    if metafile is not None:
        seed_fname = os.path.join(metafile, SEED_FILES['split'])
        with open(seed_fname, 'w+') as f:
            f.write("# split_seed used by sampling script - supply as "
                    "--spltseed to preprocess.sh or --seed to utils/split_data.py\n")
            f.write(str(rng_seed))
        logger.info("- random seed written out to {file}".format(file=seed_fname))
    else:
        logger.info("- using random seed '{seed}' for splitting".format(seed=rng_seed))

    arg_label = str(frac)
    arg_label = arg_label[2:]  # drop the leading '0.', e.g. 0.9 -> '9'

    # check if data contains information on hierarchies
    file_dir = os.path.join(subdir, files[0])
    with open(file_dir, 'r') as inf:
        data = json.load(inf)
    include_hierarchy = 'hierarchies' in data

    if user:
        logger.info("splitting data by user")

        # one pass through all the json files to build a list of all
        # possible (user, ..., json file name) tuples
        user_files = []
        for f in files:
            file_dir = os.path.join(subdir, f)
            with open(file_dir, 'r') as inf:
                # Load data into an OrderedDict, to prevent ordering changes
                # and enable reproducibility
                data = json.load(inf, object_pairs_hook=OrderedDict)
            if include_hierarchy:
                user_files.extend([(u, h, ns, f) for (u, h, ns) in
                                   zip(data['users'], data['hierarchies'], data['num_samples'])])
            else:
                user_files.extend([(u, ns, f) for (u, ns) in
                                   zip(data['users'], data['num_samples'])])

        # randomly sample from user_files to pick training-set users
        num_users = len(user_files)
        num_train_users = int(frac * num_users)
        train_indices = rng.sample(range(num_users), num_train_users)
        train_blist = [False] * num_users
        for i in train_indices:
            train_blist[i] = True
        train_user_files = []
        test_user_files = []
        for i in range(num_users):
            if train_blist[i]:
                train_user_files.append(user_files[i])
            else:
                test_user_files.append(user_files[i])

        max_users = sys.maxsize
        if name == 'femnist':
            max_users = 50  # max number of users per json file
        create_jsons_for(dir, setting_folder, train_user_files, 'train', max_users, include_hierarchy, subdir, arg_label)
        create_jsons_for(dir, setting_folder, test_user_files, 'test', max_users, include_hierarchy, subdir, arg_label)
    else:
        logger.info("splitting data by sample")
        for f in files:
            file_dir = os.path.join(subdir, f)
            with open(file_dir, 'r') as inf:
                # Load data into an OrderedDict, to prevent ordering changes
                # and enable reproducibility
                data = json.load(inf, object_pairs_hook=OrderedDict)

            num_samples_train = []
            user_data_train = {}
            num_samples_test = []
            user_data_test = {}
            user_indices = []  # indices of users in data['users'] that have >= 2 samples and are kept

            for i, u in enumerate(data['users']):
                user_data_train[u] = {'x': [], 'y': []}
                user_data_test[u] = {'x': [], 'y': []}
                curr_num_samples = len(data['user_data'][u]['y'])
                if curr_num_samples >= 2:
                    user_indices.append(i)
                    # ensures number of train and test samples both >= 1
                    num_train_samples = max(1, int(frac * curr_num_samples))
                    if curr_num_samples == 2:
                        num_train_samples = 1
                    num_test_samples = curr_num_samples - num_train_samples
                    num_samples_train.append(num_train_samples)
                    num_samples_test.append(num_test_samples)

                    train_indices = rng.sample(range(curr_num_samples), num_train_samples)
                    train_blist = [False] * curr_num_samples
                    for j in train_indices:
                        train_blist[j] = True

                    for j in range(curr_num_samples):
                        if train_blist[j]:
                            user_data_train[u]['x'].append(data['user_data'][u]['x'][j])
                            user_data_train[u]['y'].append(data['user_data'][u]['y'][j])
                        else:
                            user_data_test[u]['x'].append(data['user_data'][u]['x'][j])
                            user_data_test[u]['y'].append(data['user_data'][u]['y'][j])

            users = [data['users'][i] for i in user_indices]

            all_data_train = {}
            all_data_train['users'] = users
            all_data_train['num_samples'] = num_samples_train
            all_data_train['user_data'] = user_data_train
            all_data_test = {}
            all_data_test['users'] = users
            all_data_test['num_samples'] = num_samples_test
            all_data_test['user_data'] = user_data_test
            if include_hierarchy:
                # keep hierarchies aligned with the filtered user list
                hierarchies = [data['hierarchies'][i] for i in user_indices]
                all_data_train['hierarchies'] = hierarchies
                all_data_test['hierarchies'] = hierarchies

            file_name_train = '{}_train_{}.json'.format(f[:-5], arg_label)
            file_name_test = '{}_test_{}.json'.format(f[:-5], arg_label)
            ouf_dir_train = os.path.join(dir, setting_folder, 'train', file_name_train)
            ouf_dir_test = os.path.join(dir, setting_folder, 'test', file_name_test)
            logger.info("writing {}".format(file_name_train))
            with open(ouf_dir_train, 'w') as outfile:
                json.dump(all_data_train, outfile)
            logger.info("writing {}".format(file_name_test))
            with open(ouf_dir_test, 'w') as outfile:
                json.dump(all_data_test, outfile)
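

# A minimal usage sketch (an assumption, not part of the original pipeline):
# normally split_train_test is invoked by the dataset preprocessing scripts,
# which also create the expected directory layout under <dataset>/data.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # Hypothetical arguments: a 'femnist' dataset preprocessed into the
    # setting folder 'sample_0.9', split by sample with a fixed seed.
    split_train_test(
        setting_folder='sample_0.9',  # hypothetical setting-folder name
        metafile=None,                # do not record the seed to a metafile
        name='femnist',
        user=False,                   # split each user's samples, not users
        frac=0.9,                     # 90% of each user's samples for training
        seed=42,
    )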
|