"""
These codes are adopted from LEAF with some modifications.

Splits data into train and test sets.
"""

import json
import logging
import os
import random
import sys
import time
from collections import OrderedDict

from easyfl.datasets.utils.constants import SEED_FILES

logger = logging.getLogger(__name__)


def create_jsons_for(dir, setting_folder, user_files, which_set, max_users, include_hierarchy, subdir, arg_label):
    """Write the given user tuples out as a series of batched JSON files.

    Used in the split-by-user case. Users are accumulated in order; every
    time `max_users` users have been collected (or the input is exhausted)
    one output JSON file is written under
    `<dir>/<setting_folder>/<which_set>/`.

    Each entry of `user_files` is `(user, num_samples, source_file)` or,
    when `include_hierarchy` is set, `(user, hierarchy, num_samples,
    source_file)`; `source_file` is a JSON file inside `subdir` holding
    that user's data.
    """
    json_index = 0
    batch_users = []
    batch_hierarchies = [] if include_hierarchy else None
    batch_num_samples = []
    batch_user_data = {}

    total = len(user_files)
    for idx, entry in enumerate(user_files):
        if include_hierarchy:
            user, hierarchy, n_samples, src_file = entry
        else:
            user, n_samples, src_file = entry
            hierarchy = None

        with open(os.path.join(subdir, src_file), 'r') as inf:
            loaded = json.load(inf)

        batch_users.append(user)
        if include_hierarchy:
            batch_hierarchies.append(hierarchy)
        batch_num_samples.append(n_samples)
        batch_user_data[user] = loaded['user_data'][user]

        batch_full = len(batch_users) == max_users
        last_entry = idx == total - 1
        if not (batch_full or last_entry):
            continue

        # flush the current batch to disk
        all_data = {'users': batch_users}
        if include_hierarchy:
            all_data['hierarchies'] = batch_hierarchies
        all_data['num_samples'] = batch_num_samples
        all_data['user_data'] = batch_user_data

        # Derive the output name from the last source file seen: the number
        # that follows 'data' is replaced by the running json_index while any
        # trailing '_<param>' suffix (or plain '.json') is kept.
        data_pos = src_file.find('data')
        num_pos = data_pos + 5
        tail = src_file[num_pos:]
        underscore = tail.find('_')
        suffix = tail[underscore:] if underscore != -1 else '.json'
        renumbered = "{}_{}{}".format(src_file[:num_pos - 1], json_index, suffix)
        file_name = '{}_{}_{}.json'.format(renumbered[:-5], which_set, arg_label)
        out_path = os.path.join(dir, setting_folder, which_set, file_name)

        logger.info('writing {}'.format(file_name))
        with open(out_path, 'w') as outfile:
            json.dump(all_data, outfile)

        # reset accumulators for the next batch
        json_index += 1
        batch_users = []
        if include_hierarchy:
            batch_hierarchies = []
        batch_num_samples = []
        batch_user_data = {}


def _find_input_files(data_dir, setting_folder):
    """Pick the most-processed data folder available and list its JSON files.

    Preference order: rem_user_data, then sampled_data, then all_data.
    Returns (chosen_subdir, json_file_names).
    """
    subdir = os.path.join(data_dir, setting_folder, 'rem_user_data')
    files = os.listdir(subdir) if os.path.exists(subdir) else []
    if len(files) == 0:
        subdir = os.path.join(data_dir, setting_folder, 'sampled_data')
        files = os.listdir(subdir) if os.path.exists(subdir) else []
    if len(files) == 0:
        subdir = os.path.join(data_dir, 'all_data')
        files = os.listdir(subdir)
    return subdir, [f for f in files if f.endswith('.json')]


def _seeded_rng(metafile, seed):
    """Build a random.Random seeded from `seed` (or the clock) and record the seed.

    A None or negative seed is replaced by the current time. If `metafile`
    is given, the effective seed is written there for reproducibility;
    otherwise it is only logged.
    """
    rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
    rng = random.Random(rng_seed)
    if metafile is not None:
        seed_fname = os.path.join(metafile, SEED_FILES['split'])
        with open(seed_fname, 'w+') as f:
            f.write("# split_seed used by sampling script - supply as "
                    "--spltseed to preprocess.sh or --seed to utils/split_data.py\n")
            f.write(str(rng_seed))
        logger.info("- random seed written out to {file}".format(file=seed_fname))
    else:
        logger.info("- using random seed '{seed}' for sampling".format(seed=rng_seed))
    return rng


def _split_by_user(data_dir, setting_folder, subdir, files, name, frac, rng,
                   include_hierarchy, arg_label):
    """Assign whole users to train or test, then write batched JSON files."""
    logger.info("splitting data by user")

    # One pass through all json files to collect every
    # (user[, hierarchy], num_samples, source_file) tuple.
    user_files = []
    for f in files:
        with open(os.path.join(subdir, f), 'r') as inf:
            # OrderedDict keeps key order stable for reproducibility
            data = json.load(inf, object_pairs_hook=OrderedDict)
        if include_hierarchy:
            user_files.extend([(u, h, ns, f) for (u, h, ns) in
                               zip(data['users'], data['hierarchies'], data['num_samples'])])
        else:
            user_files.extend([(u, ns, f) for (u, ns) in
                               zip(data['users'], data['num_samples'])])

    # randomly sample from user_files to pick training-set users
    num_users = len(user_files)
    num_train_users = int(frac * num_users)
    train_indices = set(rng.sample(range(num_users), num_train_users))
    train_user_files = [t for i, t in enumerate(user_files) if i in train_indices]
    test_user_files = [t for i, t in enumerate(user_files) if i not in train_indices]

    max_users = sys.maxsize
    if name == 'femnist':
        max_users = 50  # max number of users per json file
    create_jsons_for(data_dir, setting_folder, train_user_files, 'train', max_users,
                     include_hierarchy, subdir, arg_label)
    create_jsons_for(data_dir, setting_folder, test_user_files, 'test', max_users,
                     include_hierarchy, subdir, arg_label)


def _split_by_sample(data_dir, setting_folder, subdir, files, frac, rng,
                     include_hierarchy, arg_label):
    """Split each user's samples between train and test.

    Users with fewer than two samples cannot contribute at least one sample
    to each side of the split, so they are dropped entirely (from 'users',
    'num_samples', 'user_data' and, when present, 'hierarchies').
    """
    logger.info("splitting data by sample")

    for f in files:
        with open(os.path.join(subdir, f), 'r') as inf:
            # OrderedDict keeps key order stable for reproducibility
            data = json.load(inf, object_pairs_hook=OrderedDict)

        num_samples_train = []
        user_data_train = {}
        num_samples_test = []
        user_data_test = {}
        user_indices = []  # indices of users in data['users'] that are kept

        for i, u in enumerate(data['users']):
            curr_num_samples = len(data['user_data'][u]['y'])
            if curr_num_samples < 2:
                # BUGFIX: previously dropped users still received orphan
                # empty {'x': [], 'y': []} entries in user_data, leaving
                # user_data inconsistent with the filtered 'users' list.
                continue

            user_indices.append(i)
            user_data_train[u] = {'x': [], 'y': []}
            user_data_test[u] = {'x': [], 'y': []}

            # ensures number of train and test samples both >= 1:
            # int(frac * n) < n for frac < 1, plus the explicit n == 2 case
            num_train_samples = max(1, int(frac * curr_num_samples))
            if curr_num_samples == 2:
                num_train_samples = 1
            num_test_samples = curr_num_samples - num_train_samples
            num_samples_train.append(num_train_samples)
            num_samples_test.append(num_test_samples)

            train_blist = [False] * curr_num_samples
            for j in rng.sample(range(curr_num_samples), num_train_samples):
                train_blist[j] = True

            for j in range(curr_num_samples):
                target = user_data_train if train_blist[j] else user_data_test
                target[u]['x'].append(data['user_data'][u]['x'][j])
                target[u]['y'].append(data['user_data'][u]['y'][j])

        users = [data['users'][i] for i in user_indices]

        all_data_train = {}
        all_data_train['users'] = users
        all_data_train['num_samples'] = num_samples_train
        all_data_train['user_data'] = user_data_train
        all_data_test = {}
        all_data_test['users'] = users
        all_data_test['num_samples'] = num_samples_test
        all_data_test['user_data'] = user_data_test

        if include_hierarchy:
            # BUGFIX: filter hierarchies to the kept users so the parallel
            # arrays ('users', 'hierarchies', 'num_samples') stay aligned;
            # previously the full, unfiltered list was written out.
            hierarchies = [data['hierarchies'][i] for i in user_indices]
            all_data_train['hierarchies'] = hierarchies
            all_data_test['hierarchies'] = hierarchies

        file_name_train = '{}_train_{}.json'.format((f[:-5]), arg_label)
        file_name_test = '{}_test_{}.json'.format((f[:-5]), arg_label)
        ouf_dir_train = os.path.join(data_dir, setting_folder, 'train', file_name_train)
        ouf_dir_test = os.path.join(data_dir, setting_folder, 'test', file_name_test)
        logger.info("writing {}".format(file_name_train))
        with open(ouf_dir_train, 'w') as outfile:
            json.dump(all_data_train, outfile)
        logger.info("writing {}".format(file_name_test))
        with open(ouf_dir_test, 'w') as outfile:
            json.dump(all_data_test, outfile)


def split_train_test(setting_folder, metafile, name, user, frac, seed):
    """Split a dataset's JSON files into train and test sets.

    Args:
        setting_folder: folder under `<name>/data` holding the processed
            data; output goes to its 'train'/'test' subfolders.
        metafile: directory to record the random seed in, or None to
            only log it.
        name: dataset name (e.g. 'femnist'); data is read relative to
            this package's parent directory.
        user: if truthy, split by user (whole users go to either train or
            test); otherwise split each user's samples between the sets.
        frac: fraction of users (or of each user's samples) for training.
        seed: RNG seed; None or a negative value falls back to the clock.
    """
    logger.info("------------------------------")
    logger.info("generating training and test sets")

    parent_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    data_dir = os.path.join(parent_path, name, 'data')
    subdir, files = _find_input_files(data_dir, setting_folder)

    rng = _seeded_rng(metafile, seed)

    # e.g. frac=0.9 -> '9'; used to tag output file names
    arg_label = str(frac)[2:]

    # check if data contains information on hierarchies
    with open(os.path.join(subdir, files[0]), 'r') as inf:
        include_hierarchy = 'hierarchies' in json.load(inf)

    if user:
        _split_by_user(data_dir, setting_folder, subdir, files, name, frac,
                       rng, include_hierarchy, arg_label)
    else:
        _split_by_sample(data_dir, setting_folder, subdir, files, frac, rng,
                         include_hierarchy, arg_label)