""" These codes are adopted from LEAF with some modifications. Splits data into train and test sets. """ import json import logging import os import random import sys import time from collections import OrderedDict from easyfl.datasets.utils.constants import SEED_FILES logger = logging.getLogger(__name__) def create_jsons_for(dir, setting_folder, user_files, which_set, max_users, include_hierarchy, subdir, arg_label): """Used in split-by-user case""" user_count = 0 json_index = 0 users = [] if include_hierarchy: hierarchies = [] else: hierarchies = None num_samples = [] user_data = {} for (i, t) in enumerate(user_files): if include_hierarchy: (u, h, ns, f) = t else: (u, ns, f) = t file_dir = os.path.join(subdir, f) with open(file_dir, 'r') as inf: data = json.load(inf) users.append(u) if include_hierarchy: hierarchies.append(h) num_samples.append(ns) user_data[u] = data['user_data'][u] user_count += 1 if (user_count == max_users) or (i == len(user_files) - 1): all_data = {} all_data['users'] = users if include_hierarchy: all_data['hierarchies'] = hierarchies all_data['num_samples'] = num_samples all_data['user_data'] = user_data data_i = f.find('data') num_i = data_i + 5 num_to_end = f[num_i:] param_i = num_to_end.find('_') param_to_end = '.json' if param_i != -1: param_to_end = num_to_end[param_i:] nf = "{}_{}{}".format(f[:(num_i - 1)], json_index, param_to_end) file_name = '{}_{}_{}.json'.format((nf[:-5]), which_set, arg_label) ouf_dir = os.path.join(dir, setting_folder, which_set, file_name) logger.info('writing {}'.format(file_name)) with open(ouf_dir, 'w') as outfile: json.dump(all_data, outfile) user_count = 0 json_index += 1 users = [] if include_hierarchy: hierarchies = [] num_samples = [] user_data = {} def split_train_test(setting_folder, metafile, name, user, frac, seed): logger.info("------------------------------") logger.info("generating training and test sets") parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) dir = os.path.join(parent_path, name, 'data') subdir = os.path.join(dir, setting_folder, 'rem_user_data') files = [] if os.path.exists(subdir): files = os.listdir(subdir) if len(files) == 0: subdir = os.path.join(dir, setting_folder, 'sampled_data') if os.path.exists(subdir): files = os.listdir(subdir) if len(files) == 0: subdir = os.path.join(dir, 'all_data') files = os.listdir(subdir) files = [f for f in files if f.endswith('.json')] rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time())) rng = random.Random(rng_seed) if metafile is not None: seed_fname = os.path.join(metafile, SEED_FILES['split']) with open(seed_fname, 'w+') as f: f.write("# split_seed used by sampling script - supply as " "--spltseed to preprocess.sh or --seed to utils/split_data.py\n") f.write(str(rng_seed)) logger.info("- random seed written out to {file}".format(file=seed_fname)) else: logger.info("- using random seed '{seed}' for sampling".format(seed=rng_seed)) arg_label = str(frac) arg_label = arg_label[2:] # check if data contains information on hierarchies file_dir = os.path.join(subdir, files[0]) with open(file_dir, 'r') as inf: data = json.load(inf) include_hierarchy = 'hierarchies' in data if (user): logger.info("splitting data by user") # 1 pass through all the json files to instantiate arr # containing all possible (user, .json file name) tuples user_files = [] for f in files: file_dir = os.path.join(subdir, f) with open(file_dir, 'r') as inf: # Load data into an OrderedDict, to prevent ordering changes # and enable reproducibility data = 
def split_train_test(setting_folder, metafile, name, user, frac, seed):
    logger.info("------------------------------")
    logger.info("generating training and test sets")

    parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    dir = os.path.join(parent_path, name, 'data')

    # Prefer 'rem_user_data', then 'sampled_data', then fall back to 'all_data'.
    subdir = os.path.join(dir, setting_folder, 'rem_user_data')
    files = []
    if os.path.exists(subdir):
        files = os.listdir(subdir)
    if len(files) == 0:
        subdir = os.path.join(dir, setting_folder, 'sampled_data')
        if os.path.exists(subdir):
            files = os.listdir(subdir)
    if len(files) == 0:
        subdir = os.path.join(dir, 'all_data')
        files = os.listdir(subdir)
    files = [f for f in files if f.endswith('.json')]

    rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
    rng = random.Random(rng_seed)
    if metafile is not None:
        seed_fname = os.path.join(metafile, SEED_FILES['split'])
        with open(seed_fname, 'w+') as f:
            f.write("# split_seed used by sampling script - supply as "
                    "--spltseed to preprocess.sh or --seed to utils/split_data.py\n")
            f.write(str(rng_seed))
        logger.info("- random seed written out to {file}".format(file=seed_fname))
    else:
        logger.info("- using random seed '{seed}' for sampling".format(seed=rng_seed))

    arg_label = str(frac)[2:]  # e.g. frac=0.9 -> '9'; used to tag output file names

    # Check whether the data contains information on hierarchies.
    file_dir = os.path.join(subdir, files[0])
    with open(file_dir, 'r') as inf:
        data = json.load(inf)
    include_hierarchy = 'hierarchies' in data

    if user:
        logger.info("splitting data by user")

        # One pass through all the .json files to build the list of all
        # (user, [hierarchy,] num_samples, file name) tuples.
        user_files = []
        for f in files:
            file_dir = os.path.join(subdir, f)
            with open(file_dir, 'r') as inf:
                # Load data into an OrderedDict, to prevent ordering changes
                # and enable reproducibility.
                data = json.load(inf, object_pairs_hook=OrderedDict)
            if include_hierarchy:
                user_files.extend(
                    [(u, h, ns, f) for (u, h, ns) in
                     zip(data['users'], data['hierarchies'], data['num_samples'])])
            else:
                user_files.extend(
                    [(u, ns, f) for (u, ns) in zip(data['users'], data['num_samples'])])

        # Randomly sample from user_files to pick the training set users.
        num_users = len(user_files)
        num_train_users = int(frac * num_users)
        indices = list(range(num_users))
        train_indices = rng.sample(indices, num_train_users)
        train_blist = [False] * num_users
        train_user_files = []
        test_user_files = []

        for i in train_indices:
            train_blist[i] = True
        for i in range(num_users):
            if train_blist[i]:
                train_user_files.append(user_files[i])
            else:
                test_user_files.append(user_files[i])

        max_users = sys.maxsize
        if name == 'femnist':
            max_users = 50  # max number of users per .json file
        create_jsons_for(dir, setting_folder, train_user_files, 'train', max_users,
                         include_hierarchy, subdir, arg_label)
        create_jsons_for(dir, setting_folder, test_user_files, 'test', max_users,
                         include_hierarchy, subdir, arg_label)
    else:
        logger.info("splitting data by sample")

        for f in files:
            file_dir = os.path.join(subdir, f)
            with open(file_dir, 'r') as inf:
                # Load data into an OrderedDict, to prevent ordering changes
                # and enable reproducibility.
                data = json.load(inf, object_pairs_hook=OrderedDict)

            num_samples_train = []
            user_data_train = {}
            num_samples_test = []
            user_data_test = {}
            # Indices of users in data['users'] that are not dropped
            # (i.e., that have at least 2 samples).
            user_indices = []

            for i, u in enumerate(data['users']):
                user_data_train[u] = {'x': [], 'y': []}
                user_data_test[u] = {'x': [], 'y': []}
                curr_num_samples = len(data['user_data'][u]['y'])
                if curr_num_samples >= 2:
                    user_indices.append(i)

                    # Ensure the numbers of train and test samples are both >= 1.
                    num_train_samples = max(1, int(frac * curr_num_samples))
                    if curr_num_samples == 2:
                        num_train_samples = 1
                    num_test_samples = curr_num_samples - num_train_samples
                    num_samples_train.append(num_train_samples)
                    num_samples_test.append(num_test_samples)

                    indices = list(range(curr_num_samples))
                    train_indices = rng.sample(indices, num_train_samples)
                    train_blist = [False] * curr_num_samples
                    for j in train_indices:
                        train_blist[j] = True

                    for j in range(curr_num_samples):
                        if train_blist[j]:
                            user_data_train[u]['x'].append(data['user_data'][u]['x'][j])
                            user_data_train[u]['y'].append(data['user_data'][u]['y'][j])
                        else:
                            user_data_test[u]['x'].append(data['user_data'][u]['x'][j])
                            user_data_test[u]['y'].append(data['user_data'][u]['y'][j])

            users = [data['users'][i] for i in user_indices]

            all_data_train = {}
            all_data_train['users'] = users
            all_data_train['num_samples'] = num_samples_train
            all_data_train['user_data'] = user_data_train
            all_data_test = {}
            all_data_test['users'] = users
            all_data_test['num_samples'] = num_samples_test
            all_data_test['user_data'] = user_data_test

            if include_hierarchy:
                all_data_train['hierarchies'] = data['hierarchies']
                all_data_test['hierarchies'] = data['hierarchies']

            file_name_train = '{}_train_{}.json'.format(f[:-5], arg_label)
            file_name_test = '{}_test_{}.json'.format(f[:-5], arg_label)
            ouf_dir_train = os.path.join(dir, setting_folder, 'train', file_name_train)
            ouf_dir_test = os.path.join(dir, setting_folder, 'test', file_name_test)

            logger.info("writing {}".format(file_name_train))
            with open(ouf_dir_train, 'w') as outfile:
                json.dump(all_data_train, outfile)
            logger.info("writing {}".format(file_name_test))
            with open(ouf_dir_test, 'w') as outfile:
                json.dump(all_data_test, outfile)
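

# A minimal usage sketch (not part of the original LEAF script): the setting
# folder name and argument values below are hypothetical, and the call assumes
# a LEAF-style layout <package>/<name>/data/<setting_folder>/... whose 'train'
# and 'test' output directories were already created by the preprocessing
# pipeline.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    split_train_test(
        setting_folder='sample_0.9_keep_0',  # hypothetical setting folder name
        metafile=None,                       # no meta dir: the seed is only logged
        name='femnist',
        user=True,                           # split by user rather than by sample
        frac=0.9,                            # 90% train / 10% test
        seed=42,
    )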