Browse source

initial upload

Jun Luo 1 year ago
parent
commit
762712859b
40 changed files with 3269 additions and 0 deletions
  1. +19 -0    .gitignore copy
  2. +71 -0    dataset/generate_cifar10.py
  3. +76 -0    dataset/generate_cifar100.py
  4. +91 -0    dataset/generate_medmnist.py
  5. +163 -0   dataset/generate_office_home.py
  6. +179 -0   dataset/utils/dataset_utils.py
  7. +4 -0     requirements.txt
  8. +68 -0    system/flcore/clients/clientapfl.py
  9. +35 -0    system/flcore/clients/clientavg.py
  10. +85 -0   system/flcore/clients/clientbabu.py
  11. +132 -0  system/flcore/clients/clientbase.py
  12. +60 -0   system/flcore/clients/clientdyn.py
  13. +151 -0  system/flcore/clients/clientfomo.py
  14. +53 -0   system/flcore/clients/clientlgfedavg.py
  15. +39 -0   system/flcore/clients/clientper.py
  16. +85 -0   system/flcore/clients/clientperavg.py
  17. +67 -0   system/flcore/clients/clientpfedme.py
  18. +88 -0   system/flcore/clients/clientpgfed.py
  19. +67 -0   system/flcore/clients/clientrep.py
  20. +87 -0   system/flcore/clients/clientrod.py
  21. +21 -0   system/flcore/optimizers/fedoptimizer.py
  22. +44 -0   system/flcore/servers/serverapfl.py
  23. +55 -0   system/flcore/servers/serveravg.py
  24. +92 -0   system/flcore/servers/serverbabu.py
  25. +273 -0  system/flcore/servers/serverbase.py
  26. +87 -0   system/flcore/servers/serverdyn.py
  27. +97 -0   system/flcore/servers/serverfomo.py
  28. +74 -0   system/flcore/servers/serverlgfedavg.py
  29. +40 -0   system/flcore/servers/serverlocal.py
  30. +72 -0   system/flcore/servers/serverper.py
  31. +42 -0   system/flcore/servers/serverperavg.py
  32. +56 -0   system/flcore/servers/serverpfedme.py
  33. +129 -0  system/flcore/servers/serverpgfed.py
  34. +90 -0   system/flcore/servers/serverrep.py
  35. +45 -0   system/flcore/servers/serverrod.py
  36. +58 -0   system/flcore/trainmodel/models.py
  37. +212 -0  system/main.py
  38. +48 -0   system/traincifar10_25clt_example.sh
  39. +92 -0   system/utils/data_utils.py
  40. +22 -0   system/utils/tensor_utils.py

+ 19 - 0
.gitignore copy

@@ -0,0 +1,19 @@
+# python
+*__pycache__*
+
+# results
+logs/
+checkpoints/
+results/
+*.h5
+*.log
+
+# dataset
+dataset/cifar10/
+dataset/cifar100/
+dataset/organamnist25/
+dataset/organamnist50/
+dataset/organamnist100/
+dataset/Office-home20/
+dataset/Office-home-raw/
+

+ 71 - 0
dataset/generate_cifar10.py

@@ -0,0 +1,71 @@
+import numpy as np
+import os
+import sys
+import random
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from utils.dataset_utils import check, separate_data, split_data, save_file
+
+
+random.seed(1)
+np.random.seed(1)
+num_clients = 25
+num_classes = 10
+dir_path = "cifar10/"
+
+
+# Allocate data to users
+def generate_cifar10(dir_path, num_clients, num_classes, niid, balance, partition):
+    if not os.path.exists(dir_path):
+        os.makedirs(dir_path)
+        
+    # Setup directory for train/test data
+    config_path = dir_path + "config.json"
+    train_path = dir_path + "train/"
+    test_path = dir_path + "test/"
+
+    if check(config_path, train_path, test_path, num_clients, num_classes, niid, balance, partition):
+        return
+        
+    # Get Cifar10 data
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+    trainset = torchvision.datasets.CIFAR10(
+        root=dir_path+"rawdata", train=True, download=True, transform=transform)
+    testset = torchvision.datasets.CIFAR10(
+        root=dir_path+"rawdata", train=False, download=True, transform=transform)
+    trainloader = torch.utils.data.DataLoader(
+        trainset, batch_size=len(trainset.data), shuffle=False)
+    testloader = torch.utils.data.DataLoader(
+        testset, batch_size=len(testset.data), shuffle=False)
+
+    # each loader yields a single full-size batch, so these loops run exactly once
+    for _, train_data in enumerate(trainloader, 0):
+        trainset.data, trainset.targets = train_data
+    for _, test_data in enumerate(testloader, 0):
+        testset.data, testset.targets = test_data
+
+    dataset_image = []
+    dataset_label = []
+
+    dataset_image.extend(trainset.data.cpu().detach().numpy())
+    dataset_image.extend(testset.data.cpu().detach().numpy())
+    dataset_label.extend(trainset.targets.cpu().detach().numpy())
+    dataset_label.extend(testset.targets.cpu().detach().numpy())
+    dataset_image = np.array(dataset_image)
+    dataset_label = np.array(dataset_label)
+
+    X, y, statistic = separate_data((dataset_image, dataset_label), num_clients, num_classes, 
+                                    niid, balance, partition)
+    train_data, test_data = split_data(X, y)
+    save_file(config_path, train_path, test_path, train_data, test_data, num_clients, num_classes, 
+        statistic, niid, balance, partition)
+
+
+if __name__ == "__main__":
+    niid = sys.argv[1] == "noniid"
+    balance = sys.argv[2] == "balance"
+    partition = sys.argv[3] if sys.argv[3] != "-" else None
+
+    generate_cifar10(dir_path, num_clients, num_classes, niid, balance, partition)
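
For reference, the `__main__` block above maps three positional arguments onto the split configuration; a minimal sketch of how a hypothetical invocation such as `python generate_cifar10.py noniid - dir` would be parsed (the argument values here are invented):

    # mirrors the parsing in __main__ above
    argv = ["generate_cifar10.py", "noniid", "-", "dir"]
    niid = argv[1] == "noniid"                        # True  -> non-IID split
    balance = argv[2] == "balance"                    # False -> unbalanced client sizes
    partition = argv[3] if argv[3] != "-" else None   # "dir" -> Dirichlet partition
    print(niid, balance, partition)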

+ 76 - 0
dataset/generate_cifar100.py

@@ -0,0 +1,76 @@
+import numpy as np
+import os
+import sys
+import random
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from utils.dataset_utils import check, separate_data, split_data, save_file
+
+
+random.seed(1)
+np.random.seed(1)
+num_clients = 25
+num_classes = 100
+dir_path = "cifar100/"
+
+
+# Allocate data to users
+def generate_cifar100(dir_path, num_clients, num_classes, niid, balance, partition):
+    if not os.path.exists(dir_path):
+        os.makedirs(dir_path)
+        
+    # Setup directory for train/test data
+    config_path = dir_path + "config.json"
+    train_path = dir_path + "train/"
+    test_path = dir_path + "test/"
+
+    if check(config_path, train_path, test_path, num_clients, num_classes, niid, balance, partition):
+        return
+        
+    # Get Cifar100 data
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+    trainset = torchvision.datasets.CIFAR100(
+        root=dir_path+"rawdata", train=True, download=True, transform=transform)
+    testset = torchvision.datasets.CIFAR100(
+        root=dir_path+"rawdata", train=False, download=True, transform=transform)
+    trainloader = torch.utils.data.DataLoader(
+        trainset, batch_size=len(trainset.data), shuffle=False)
+    testloader = torch.utils.data.DataLoader(
+        testset, batch_size=len(testset.data), shuffle=False)
+
+    for _, train_data in enumerate(trainloader, 0):
+        trainset.data, trainset.targets = train_data
+    for _, test_data in enumerate(testloader, 0):
+        testset.data, testset.targets = test_data
+
+    dataset_image = []
+    dataset_label = []
+
+    dataset_image.extend(trainset.data.cpu().detach().numpy())
+    dataset_image.extend(testset.data.cpu().detach().numpy())
+    dataset_label.extend(trainset.targets.cpu().detach().numpy())
+    dataset_label.extend(testset.targets.cpu().detach().numpy())
+    dataset_image = np.array(dataset_image)
+    dataset_label = np.array(dataset_label)
+
+    # dataset = []
+    # for i in range(num_classes):
+    #     idx = dataset_label == i
+    #     dataset.append(dataset_image[idx])
+
+    X, y, statistic = separate_data((dataset_image, dataset_label), num_clients, num_classes, 
+                                    niid, balance, partition, class_per_client=20)
+    train_data, test_data = split_data(X, y)
+    save_file(config_path, train_path, test_path, train_data, test_data, num_clients, num_classes, 
+        statistic, niid, balance, partition)
+
+
+if __name__ == "__main__":
+    niid = sys.argv[1] == "noniid"
+    balance = sys.argv[2] == "balance"
+    partition = sys.argv[3] if sys.argv[3] != "-" else None
+
+    generate_cifar100(dir_path, num_clients, num_classes, niid, balance, partition)

+ 91 - 0
dataset/generate_medmnist.py

@@ -0,0 +1,91 @@
+import numpy as np
+import os
+import sys
+import random
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from utils.dataset_utils import check, separate_data, split_data, save_file
+from torchvision.datasets import ImageFolder, DatasetFolder
+
+# medmnist
+import medmnist
+from medmnist import INFO
+
+random.seed(1)
+np.random.seed(1)
+num_clients = 25
+dir_path = f"organamnist{num_clients}/"
+
+
+# medmnist
+data_flag = ("".join([i for i in dir_path if i.isalpha()])).lower()
+# data_flag = 'breastmnist'
+download = False
+info = INFO[data_flag]
+task = info['task']
+n_channels = info['n_channels']
+num_classes = len(info['label'])
+DataClass = getattr(medmnist, info['python_class'])
+
+# Allocate data to users
+def generate_dataset(dir_path, num_clients, num_classes, niid, balance, partition):
+    if not os.path.exists(dir_path):
+        os.makedirs(dir_path)
+        
+    # Setup directory for train/test data
+    config_path = dir_path + "config.json"
+    train_path = dir_path + "train/"
+    test_path = dir_path + "test/"
+
+    if check(config_path, train_path, test_path, num_clients, num_classes, niid, balance, partition):
+        return
+
+    # Get data
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
+
+    trainset = DataClass(split='train', transform=transform, download=download)
+    valset = DataClass(split='val', transform=transform, download=download)
+    testset = DataClass(split='test', transform=transform, download=download)
+    
+    trainloader = torch.utils.data.DataLoader(
+        trainset, batch_size=len(trainset), shuffle=True)
+    valloader = torch.utils.data.DataLoader(
+        valset, batch_size=len(valset), shuffle=True)
+    testloader = torch.utils.data.DataLoader(
+        testset, batch_size=len(testset), shuffle=True)
+
+    for _, train_data in enumerate(trainloader, 0):
+        trainset.data, trainset.targets = train_data
+    for _, val_data in enumerate(valloader, 0):
+        valset.data, valset.targets = val_data
+    for _, test_data in enumerate(testloader, 0):
+        testset.data, testset.targets = test_data
+
+    dataset_image = []
+    dataset_label = []
+
+    dataset_image.extend(trainset.data.cpu().detach().numpy())
+    dataset_image.extend(valset.data.cpu().detach().numpy())
+    dataset_image.extend(testset.data.cpu().detach().numpy())
+    dataset_label.extend(trainset.targets.cpu().detach().numpy())
+    dataset_label.extend(valset.targets.cpu().detach().numpy())
+    dataset_label.extend(testset.targets.cpu().detach().numpy())
+
+    dataset_image = np.array(dataset_image)
+    dataset_label = np.array(dataset_label).reshape(-1)
+
+    X, y, statistic = separate_data((dataset_image, dataset_label), num_clients, num_classes, 
+                                    niid, balance, partition)
+    train_data, test_data = split_data(X, y)
+    save_file(config_path, train_path, test_path, train_data, test_data, num_clients, num_classes, 
+        statistic, niid, balance, partition)
+
+
+if __name__ == "__main__":
+    niid = sys.argv[1] == "noniid"
+    balance = sys.argv[2] == "balance"
+    partition = sys.argv[3] if sys.argv[3] != "-" else None
+
+    generate_dataset(dir_path, num_clients, num_classes, niid, balance, partition)
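
The `data_flag` assignment above derives the MedMNIST dataset key by stripping everything but letters from `dir_path` and lowercasing; a quick standalone check of that exact expression, using the script's own `dir_path` value:

    dir_path = "organamnist25/"
    data_flag = ("".join([i for i in dir_path if i.isalpha()])).lower()
    print(data_flag)  # -> "organamnist", the key looked up in medmnist.INFO

Note that `download = False`, so the raw MedMNIST `.npz` file must already be present in the library's default root before running this script.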

+ 163 - 0
dataset/generate_office_home.py

@@ -0,0 +1,163 @@
+# -*- coding: utf-8 -*-
+# @Author: Jun Luo
+# @Date:   2022-02-28 10:40:03
+# @Last Modified by:   Jun Luo
+# @Last Modified time: 2022-02-28 10:40:03
+
+import numpy as np
+from sklearn.model_selection import train_test_split
+
+import os
+import sys
+import glob
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from PIL import Image
+import shutil
+from torch.utils.data import Dataset, DataLoader
+
+ALPHA = 1.0
+N_CLIENTS = 20
+TEST_PORTION = 0.15
+SEED = 42
+SET_THRESHOLD = 20
+
+IMAGE_SIZE = 32
+N_CLASSES = 65
+
+IMAGE_SRC = "./Office-home-raw/"
+SAVE_FOLDER = f"./Office-home{N_CLIENTS}/"
+
+class ImageDatasetFromFileNames(Dataset):
+    def __init__(self, fns, labels, transform=None, target_transform=None):
+        self.fns = fns
+        self.labels = labels
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __getitem__(self, index):
+        x = Image.open(self.fns[index])
+        y = self.labels[index]
+        if self.transform is not None:
+            x = self.transform(x)
+        if self.target_transform is not None:
+            y = self.target_transform(y)
+        return x, y
+
+    def __len__(self):
+        return len(self.labels)
+
+def dirichletSplit(alpha=10, n_clients=10, n_classes=10):
+    return np.random.dirichlet(n_clients * [alpha], n_classes)
+
+def isNegligible(partitions, counts, THRESHOLD=2):
+    s = np.matmul(partitions.T, counts)
+    return (s < THRESHOLD).any()
+
+def split2clientsofficehome(x_fns, ys, stats, partitions, client_idx_offset=0, verbose=False):
+    print("==> splitting dataset into clients' own datasets")
+    n_classes, n_clients = partitions.shape
+    splits = [] # n_classes * n_clients
+    for i in range(n_classes):
+        indices = np.where(ys == i)[0]
+        np.random.shuffle(indices)
+        cuts = np.cumsum(np.round_(partitions[i] * stats[str(i)]).astype(int))
+        cuts = np.clip(cuts, 0, stats[str(i)])
+        cuts[-1] = stats[str(i)]
+        splits.append(np.split(indices, cuts))
+        
+    clients = []
+    for i in range(n_clients):
+        indices = np.concatenate([splits[j][i] for j in range(n_classes)], axis=0)
+        dset = [x_fns[indices], ys[indices]]
+        clients.append(dset)
+        if verbose:
+            print("\tclient %03d has" % (client_idx_offset+i+1), len(dset[0]), "images")
+    return clients
+
+def get_immediate_subdirectories(a_dir):
+    return [name for name in os.listdir(a_dir)
+            if os.path.isdir(os.path.join(a_dir, name))]
+
+if __name__ == "__main__":
+    np.random.seed(SEED)
+    styles = ["Art", "Clipart", "Product", "Real World"]
+    assert N_CLIENTS % 4 == 0, "### For Office-Home dataset, N_CLIENTS must be a multiple of 4...\nPlease change N_CLIENTS..."
+    N_CLIENTS_PER_STYLE = N_CLIENTS // len(styles)
+    
+    cls_names = []
+    for fn in get_immediate_subdirectories(IMAGE_SRC + styles[0]):
+        cls_names.append(os.path.split(fn)[1])
+    idx2clsname = {i: name for i, name in enumerate(cls_names)}
+    get_cls_folder = lambda style, cls_n: os.path.join(IMAGE_SRC, style, cls_n)
+
+    def get_dataset(dir, style):
+        x_fns = []
+        ys = []
+        stats_dict = {}
+        stats_list = []
+        for i in range(N_CLASSES):
+            cls_name = idx2clsname[i]
+            x_for_cls = list(glob.glob(os.path.join(dir, style, cls_name, "*.jpg")))
+            x_fns += x_for_cls
+            ys += [i for _ in range(len(x_for_cls))]
+            stats_dict[str(i)] = len(x_for_cls)
+            stats_list.append(len(x_for_cls))
+        return np.array(x_fns), np.array(ys), stats_dict, np.array(stats_list)
+
+    clients = []
+    for style_idx, style in enumerate(styles):
+        dataset_style_fns, dataset_style_labels, dataset_stats_dict, dataset_stats_list = get_dataset(IMAGE_SRC, style)
+        # print(len(dataset_style_fns), len(dataset_style_labels), np.sum(list(dataset_stats.values())))
+        partitions = np.zeros((N_CLASSES, N_CLIENTS_PER_STYLE))
+        i = 0
+        while isNegligible(partitions, dataset_stats_list, SET_THRESHOLD/TEST_PORTION):
+            partitions = dirichletSplit(alpha=ALPHA, n_clients=N_CLIENTS_PER_STYLE, n_classes=N_CLASSES)
+            i += 1
+            print(f"==> partitioning for the {i}th time (client dataset size >= {SET_THRESHOLD})")
+        clients += split2clientsofficehome(dataset_style_fns,
+                                           dataset_style_labels,
+                                           dataset_stats_dict,
+                                           partitions,
+                                           client_idx_offset=style_idx*N_CLIENTS_PER_STYLE,
+                                           verbose=True)
+    # print()
+    # print(np.sum([len(c[0]) for c in clients]))
+    transform = transforms.Compose(
+        [transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+    
+    if not os.path.exists(f"{SAVE_FOLDER}train/"):
+        os.makedirs(f"{SAVE_FOLDER}train/")
+    if not os.path.exists(f"{SAVE_FOLDER}test/"):
+        os.makedirs(f"{SAVE_FOLDER}test/")
+
+    for client_idx, (clt_x_fns, clt_ys) in enumerate(clients):
+        print("==> saving (to %s) for client [%3d/%3d]" % (SAVE_FOLDER, client_idx+1, N_CLIENTS))
+        # split train, val, test
+        try:
+            X_train_fns, X_test_fns, y_train, y_test = train_test_split(
+                clt_x_fns, clt_ys, test_size=TEST_PORTION, random_state=SEED, stratify=clt_ys)
+        except ValueError:
+            X_train_fns, X_test_fns, y_train, y_test = train_test_split(
+                clt_x_fns, clt_ys, test_size=TEST_PORTION, random_state=SEED)
+
+        trainset = ImageDatasetFromFileNames(X_train_fns, y_train, transform=transform)
+        testset = ImageDatasetFromFileNames(X_test_fns, y_test, transform=transform)
+
+        trainloader = torch.utils.data.DataLoader(
+            trainset, batch_size=len(trainset), shuffle=False)
+        testloader = torch.utils.data.DataLoader(
+            testset, batch_size=len(testset), shuffle=False)
+
+        xs_train, ys_train = next(iter(trainloader))
+        xs_test, ys_test = next(iter(testloader))
+        train_dict = {"x": xs_train.numpy(), "y": ys_train.numpy()}
+        test_dict = {"x": xs_test.numpy(), "y": ys_test.numpy()}
+
+        # save
+        for data_dict, npz_fn in [(train_dict, SAVE_FOLDER+f"train/{client_idx}.npz"), (test_dict, SAVE_FOLDER+f"test/{client_idx}.npz")]:
+            with open(npz_fn, "wb") as f:
+                np.savez_compressed(f, data=data_dict)
+        
+    print("\n==> finished saving all npz images.")
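
The `while isNegligible(...)` loop above is rejection sampling: it redraws the Dirichlet partition until every client's expected share of images clears `SET_THRESHOLD / TEST_PORTION`. A minimal self-contained sketch of that acceptance test, with invented class counts:

    import numpy as np
    np.random.seed(0)
    counts = np.array([80, 120, 60])                 # images per class (toy numbers)
    partitions = np.random.dirichlet(4 * [1.0], 3)   # 3 classes x 4 clients, alpha = 1.0
    expected = np.matmul(partitions.T, counts)       # expected images per client
    print(expected, (expected < 20).any())           # True would trigger another redraw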

+ 179 - 0
dataset/utils/dataset_utils.py

@@ -0,0 +1,179 @@
+import os
+import ujson
+import numpy as np
+import gc
+from sklearn.model_selection import train_test_split
+
+batch_size = 10
+train_size = 0.75 # merge original training set and test set, then split it manually. 
+least_samples = batch_size / (1-train_size) # minimum number of samples per client
+alpha = 0.3 # for Dirichlet distribution
+
+def check(config_path, train_path, test_path, num_clients, num_classes, niid=False, 
+        balance=True, partition=None):
+    # check existing dataset
+    if os.path.exists(config_path):
+        with open(config_path, 'r') as f:
+            config = ujson.load(f)
+        if config['num_clients'] == num_clients and \
+            config['num_classes'] == num_classes and \
+            config['non_iid'] == niid and \
+            config['balance'] == balance and \
+            config['partition'] == partition and \
+            config['alpha'] == alpha and \
+            config['batch_size'] == batch_size:
+            print("\nDataset already generated.\n")
+            return True
+
+    dir_path = os.path.dirname(train_path)
+    if not os.path.exists(dir_path):
+        os.makedirs(dir_path)
+    dir_path = os.path.dirname(test_path)
+    if not os.path.exists(dir_path):
+        os.makedirs(dir_path)
+
+    return False
+
+def separate_data(data, num_clients, num_classes, niid=False, balance=False, partition=None, class_per_client=2):
+    X = [[] for _ in range(num_clients)]
+    y = [[] for _ in range(num_clients)]
+    statistic = [[] for _ in range(num_clients)]
+
+    dataset_content, dataset_label = data
+
+    dataidx_map = {}
+
+    if not niid:
+        partition = 'pat'
+        class_per_client = num_classes
+
+    if partition == 'pat':
+        idxs = np.array(range(len(dataset_label)))
+        idx_for_each_class = []
+        for i in range(num_classes):
+            idx_for_each_class.append(idxs[dataset_label == i])
+
+        class_num_per_client = [class_per_client for _ in range(num_clients)]
+        for i in range(num_classes):
+            selected_clients = []
+            for client in range(num_clients):
+                if class_num_per_client[client] > 0:
+                    selected_clients.append(client)
+            # keep only the first num_clients/num_classes*class_per_client eligible clients
+            selected_clients = selected_clients[:int(num_clients/num_classes*class_per_client)]
+
+            num_all_samples = len(idx_for_each_class[i])
+            num_selected_clients = len(selected_clients)
+            num_per = num_all_samples / num_selected_clients
+            if balance:
+                num_samples = [int(num_per) for _ in range(num_selected_clients-1)]
+            else:
+                num_samples = np.random.randint(max(num_per/10, least_samples/num_classes), num_per, num_selected_clients-1).tolist()
+            num_samples.append(num_all_samples-sum(num_samples))
+
+            idx = 0
+            for client, num_sample in zip(selected_clients, num_samples):
+                if client not in dataidx_map.keys():
+                    dataidx_map[client] = idx_for_each_class[i][idx:idx+num_sample]
+                else:
+                    dataidx_map[client] = np.append(dataidx_map[client], idx_for_each_class[i][idx:idx+num_sample], axis=0)
+                idx += num_sample
+                class_num_per_client[client] -= 1
+
+    elif partition == "dir":
+        # https://github.com/IBM/probabilistic-federated-neural-matching/blob/master/experiment.py
+        min_size = 0
+        K = num_classes
+        N = len(dataset_label)
+
+        while min_size < least_samples:
+            idx_batch = [[] for _ in range(num_clients)]
+            for k in range(K):
+                idx_k = np.where(dataset_label == k)[0]
+                np.random.shuffle(idx_k)
+                proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
+                proportions = np.array([p*(len(idx_j)<N/num_clients) for p,idx_j in zip(proportions,idx_batch)])
+                proportions = proportions/proportions.sum()
+                proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]
+                idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))]
+                min_size = min([len(idx_j) for idx_j in idx_batch])
+
+        for j in range(num_clients):
+            dataidx_map[j] = idx_batch[j]
+    else:
+        raise NotImplementedError
+
+    # assign data
+    for client in range(num_clients):
+        idxs = dataidx_map[client]
+        X[client] = dataset_content[idxs]
+        y[client] = dataset_label[idxs]
+
+        for i in np.unique(y[client]):
+            statistic[client].append((int(i), int(sum(y[client]==i))))
+            
+
+    del data
+    # gc.collect()
+
+    for client in range(num_clients):
+        print(f"Client {client}\t Size of data: {len(X[client])}\t Labels: ", np.unique(y[client]))
+        print(f"\t\t Samples of labels: ", [i for i in statistic[client]])
+        print("-" * 50)
+
+    return X, y, statistic
+
+
+def split_data(X, y):
+    # Split dataset
+    train_data, test_data = [], []
+    num_samples = {'train':[], 'test':[]}
+
+    for i in range(len(y)):
+        unique, count = np.unique(y[i], return_counts=True)
+        # stratify only when every class has at least two samples
+        if min(count) > 1:
+            X_train, X_test, y_train, y_test = train_test_split(
+                X[i], y[i], train_size=train_size, shuffle=True, stratify=y[i])
+        else:
+            X_train, X_test, y_train, y_test = train_test_split(
+                X[i], y[i], train_size=train_size, shuffle=True)
+
+        train_data.append({'x': X_train, 'y': y_train})
+        num_samples['train'].append(len(y_train))
+        test_data.append({'x': X_test, 'y': y_test})
+        num_samples['test'].append(len(y_test))
+
+    print("Total number of samples:", sum(num_samples['train'] + num_samples['test']))
+    print("The number of train samples:", num_samples['train'])
+    print("The number of test samples:", num_samples['test'])
+    print()
+    del X, y
+    # gc.collect()
+
+    return train_data, test_data
+
+def save_file(config_path, train_path, test_path, train_data, test_data, num_clients, 
+                num_classes, statistic, niid=False, balance=True, partition=None):
+    config = {
+        'num_clients': num_clients, 
+        'num_classes': num_classes, 
+        'non_iid': niid, 
+        'balance': balance, 
+        'partition': partition, 
+        'Size of samples for labels in clients': statistic, 
+        'alpha': alpha, 
+        'batch_size': batch_size, 
+    }
+
+    # gc.collect()
+    print("Saving to disk.\n")
+
+    for idx, train_dict in enumerate(train_data):
+        with open(train_path + str(idx) + '.npz', 'wb') as f:
+            np.savez_compressed(f, data=train_dict)
+    for idx, test_dict in enumerate(test_data):
+        with open(test_path + str(idx) + '.npz', 'wb') as f:
+            np.savez_compressed(f, data=test_dict)
+    with open(config_path, 'w') as f:
+        ujson.dump(config, f)
+
+    print("Finish generating dataset.\n")
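
`save_file` stores each client's split as a single pickled dict under the `data` key of a compressed `.npz`. A hedged sketch of reading one back (the path is illustrative, and this presumably mirrors what `system/utils/data_utils.py` does):

    import numpy as np
    with open("cifar10/train/0.npz", "rb") as f:  # hypothetical client-0 file
        train_dict = np.load(f, allow_pickle=True)["data"].tolist()
    print(train_dict["x"].shape, train_dict["y"].shape)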

+ 4 - 0
requirements.txt

@@ -0,0 +1,4 @@
+scikit_learn==1.1.1
+torch==1.4.0
+torchvision==0.5.0
+medmnist==2.1.0

+ 68 - 0
system/flcore/clients/clientapfl.py

@@ -0,0 +1,68 @@
+import copy
+import torch
+import torch.nn as nn
+import numpy as np
+from flcore.clients.clientbase import Client
+
+class clientAPFL(Client):
+    def __init__(self, args, id, train_samples, test_samples, **kwargs):
+        super().__init__(args, id, train_samples, test_samples, **kwargs)
+        
+        self.criterion = nn.CrossEntropyLoss()
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
+
+        self.alpha = args.alpha
+        self.model_local = copy.deepcopy(self.model)
+        self.optimizer_local = torch.optim.SGD(self.model_local.parameters(), lr=self.learning_rate)
+        self.model_per = copy.deepcopy(self.model)
+        self.optimizer_per = torch.optim.SGD(self.model_per.parameters(), lr=self.learning_rate)
+
+    def set_parameters(self, model):
+        for new_param, old_param, param_l, param_p in zip(model.parameters(), self.model.parameters(),
+            self.model_local.parameters(), self.model_per.parameters()):
+            old_param.data = new_param.data.clone()
+            param_p.data = self.alpha * param_l.data + (1 - self.alpha) * new_param.data
+
+    def train(self):
+        trainloader = self.load_train_data()
+        self.model.train()
+
+        max_local_steps = self.local_steps
+
+        # self.model_per: personalized model (v_bar); self.model: global model (w)
+        for step in range(max_local_steps):
+            for i, (x, y) in enumerate(trainloader):
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+
+                # update global model (self.model)
+                self.optimizer.zero_grad()
+                output = self.model(x)
+                loss = self.criterion(output, y)
+                loss.backward()
+                self.optimizer.step()
+
+                # compute grad(v_bar) on the personalized model; model_local is then updated with alpha * grad(v_bar)
+                self.optimizer_per.zero_grad()
+                output_per = self.model_per(x)
+                loss_per = self.criterion(output_per, y)
+                loss_per.backward() # update (by gradient) model_local before updating model_per (by interpolation)
+
+                # update model_local by gradient (gradient is alpha * grad(model_per))
+                # see https://github.com/lgcollins/FedRep/blob/main/models/Update.py#L410 and the algorithm in paper
+                self.optimizer_local.zero_grad()
+                for p_l, p_p in zip(self.model_local.parameters(), self.model_per.parameters()):
+                    if p_l.grad is None:
+                        p_l.grad = self.alpha * p_p.grad.data.clone()
+                    else:
+                        p_l.grad.data = self.alpha * p_p.grad.data.clone()
+                self.optimizer_local.step()
+                
+                # update model_per by interpolation
+                for p_p, p_g, p_l in zip(self.model_per.parameters(), self.model.parameters(), self.model_local.parameters()):
+                    p_p.data = self.alpha * p_l.data + (1 - self.alpha) * p_g.data
+
+
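
The closing loop in `train` is APFL's interpolation v_bar = alpha * v + (1 - alpha) * w applied parameter by parameter; a one-step numeric sketch with invented tensors:

    import torch
    alpha = 0.25
    v = torch.tensor([1.0, 2.0])   # a local-model parameter (toy)
    w = torch.tensor([3.0, 0.0])   # the matching global-model parameter (toy)
    v_bar = alpha * v + (1 - alpha) * w
    print(v_bar)                   # tensor([2.5000, 0.5000])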

+ 35 - 0
system/flcore/clients/clientavg.py

@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import copy
+import sys
+from flcore.clients.clientbase import Client
+
+
+class clientAVG(Client):
+    def __init__(self, args, id, train_samples, test_samples, **kwargs):
+        super().__init__(args, id, train_samples, test_samples, **kwargs)
+        
+        self.criterion = nn.CrossEntropyLoss()
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
+
+
+    def train(self):
+        trainloader = self.load_train_data()
+        self.model.train()
+
+        max_local_steps = self.local_steps
+
+        for step in range(max_local_steps):
+            for i, (x, y) in enumerate(trainloader):
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+                self.optimizer.zero_grad()
+                output = self.model(x)
+                loss = self.criterion(output, y)
+                loss.backward()
+                self.optimizer.step()
+
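
`clientAVG` does nothing beyond plain local SGD; the FedAvg step lives on the server side (`serveravg.py` in the file list above, outside this excerpt). A generic sample-weighted averaging sketch, with assumed names rather than the committed implementation:

    import torch

    def fedavg_aggregate(param_lists, num_samples):
        """Weighted average of clients' parameter lists (illustrative only)."""
        total = float(sum(num_samples))
        avg = [torch.zeros_like(p) for p in param_lists[0]]
        for params, n in zip(param_lists, num_samples):
            for a, p in zip(avg, params):
                a += (n / total) * p.data
        return avg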

+ 85 - 0
system/flcore/clients/clientbabu.py

@@ -0,0 +1,85 @@
+import copy
+import torch
+import torch.nn as nn
+import numpy as np
+from flcore.clients.clientbase import Client
+
+class clientBABU(Client):
+    def __init__(self, args, id, train_samples, test_samples, **kwargs):
+        super().__init__(args, id, train_samples, test_samples, **kwargs)
+        self.criterion = nn.CrossEntropyLoss()
+
+        self.fine_tuning_steps = args.fine_tuning_steps
+        self.alpha = args.alpha  # learning rate used for fine-tuning
+
+        for param in self.model.predictor.parameters():
+            param.requires_grad = False
+
+    def train_one_iter(self, x, y, optimizer):
+        optimizer.zero_grad()
+        output = self.model(x)
+        loss = self.criterion(output, y)
+        loss.backward()
+        optimizer.step()
+
+    def get_training_optimizer(self, **kwargs):
+        return torch.optim.SGD(self.model.base.parameters(), lr=self.learning_rate, momentum=0.9)
+
+    def get_fine_tuning_optimizer(self, **kwargs):
+        return torch.optim.SGD(self.model.parameters(), lr=self.alpha, momentum=0.9)
+        
+    def prepare_training(self, **kwargs):
+        pass
+
+    def prepare_fine_tuning(self, **kwargs):
+        pass
+
+    def train(self):
+        trainloader = self.load_train_data()
+        # self.model.to(self.device)
+        self.model.train()
+        optimizer = self.get_training_optimizer()
+        self.prepare_training() # prepare_training after getting optimizer
+
+        max_local_steps = self.local_steps
+
+        for step in range(max_local_steps):
+            for i, (x, y) in enumerate(trainloader):
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+                self.train_one_iter(x, y, optimizer)
+
+        # self.model.cpu()
+    def set_parameters(self, base):
+        for new_param, old_param in zip(base.parameters(), self.model.base.parameters()):
+            old_param.data = new_param.data.clone()
+
+    def set_fine_tune_parameters(self, model):
+        for new_param, old_param in zip(model.parameters(), self.model.parameters()):
+            old_param.data = new_param.data.clone()
+
+    def fine_tune(self, which_module=['base', 'predictor']):
+        trainloader = self.load_train_data()
+        self.model.train()
+        self.prepare_fine_tuning() # prepare_fine_tuning before getting optimizer
+        optimizer = self.get_fine_tuning_optimizer()
+
+        if 'predictor' in which_module:
+            for param in self.model.predictor.parameters():
+                param.requires_grad = True
+
+        if 'base' not in which_module:
+            for param in self.model.base.parameters():
+                param.requires_grad = False
+            
+        for step in range(self.fine_tuning_steps):
+            for i, (x, y) in enumerate(trainloader):
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+                self.train_one_iter(x, y, optimizer)
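
FedBABU trains only the body: the predictor is frozen in `__init__` and re-enabled only inside `fine_tune`. A toy check of the freezing pattern (the two-part model here is invented; the real one comes from `trainmodel/models.py`):

    import torch.nn as nn
    base, predictor = nn.Linear(4, 8), nn.Linear(8, 2)
    for p in predictor.parameters():
        p.requires_grad = False
    trainable = sum(p.numel()
                    for p in list(base.parameters()) + list(predictor.parameters())
                    if p.requires_grad)
    print(trainable)  # 40 = 4*8 weights + 8 biases; only the base still trains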

+ 132 - 0
system/flcore/clients/clientbase.py

@@ -0,0 +1,132 @@
+import copy
+import torch
+import torch.nn as nn
+import numpy as np
+import os
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from sklearn.preprocessing import label_binarize
+from sklearn import metrics
+from utils.data_utils import read_client_data
+
+
+class Client(object):
+    """
+    Base class for clients in federated learning.
+    """
+
+    def __init__(self, args, id, train_samples, test_samples, **kwargs):
+        self.model = copy.deepcopy(args.model)
+        self.dataset = args.dataset
+        self.device = args.device
+        self.id = id  # integer
+
+        self.num_classes = args.num_classes
+        self.train_samples = train_samples
+        self.test_samples = test_samples
+        self.batch_size = args.batch_size
+        self.learning_rate = args.local_learning_rate
+        self.local_steps = args.local_steps
+
+        # check BatchNorm
+        self.has_BatchNorm = False
+        for layer in self.model.children():
+            if isinstance(layer, nn.BatchNorm2d):
+                self.has_BatchNorm = True
+                break
+            
+        self.sample_rate = self.batch_size / self.train_samples
+
+
+    def load_train_data(self, batch_size=None):
+        if batch_size is None:
+            batch_size = self.batch_size
+        train_data = read_client_data(self.dataset, self.id, is_train=True)
+        batch_size = min(batch_size, len(train_data))
+        return DataLoader(train_data, batch_size, drop_last=True, shuffle=True)
+
+    def load_test_data(self, batch_size=None):
+        if batch_size is None:
+            batch_size = self.batch_size
+        test_data = read_client_data(self.dataset, self.id, is_train=False)
+        batch_size = min(batch_size, len(test_data))
+        return DataLoader(test_data, batch_size, drop_last=False, shuffle=True)
+        
+    def set_parameters(self, model):
+        for new_param, old_param in zip(model.parameters(), self.model.parameters()):
+            old_param.data = new_param.data.clone()
+
+    def clone_model(self, model, target):
+        for param, target_param in zip(model.parameters(), target.parameters()):
+            target_param.data = param.data.clone()
+            # target_param.grad = param.grad.clone()
+
+    def update_parameters(self, model, new_params):
+        for param, new_param in zip(model.parameters(), new_params):
+            param.data = new_param.data.clone()
+
+
+    def get_eval_model(self, temp_model=None):
+        model = self.model_per if hasattr(self, "model_per") else self.model
+        return model
+
+    def standard_train(self):
+        trainloader = self.load_train_data()
+        self.model.train()
+        for p in self.model.parameters():
+            p.requires_grad = True
+        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
+        # 1 epoch
+        for i, (x, y) in enumerate(trainloader):
+            if type(x) == type([]):
+                x[0] = x[0].to(self.device)
+            else:
+                x = x.to(self.device)
+            y = y.to(self.device)
+            optimizer.zero_grad()
+            output = self.model(x)
+            loss = self.criterion(output, y)
+            loss.backward()
+            optimizer.step()
+
+    def test_metrics(self, temp_model=None):
+        testloaderfull = self.load_test_data()
+        # self.model = self.load_model('model')
+        # self.model.to(self.device)
+        model = self.get_eval_model(temp_model)
+        model.eval()
+        
+        test_correct = 0
+        test_num = 0
+        test_loss = 0.0
+        y_prob = []
+        y_true = []
+        
+        with torch.no_grad():
+            for x, y in testloaderfull:
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+                output = model(x)
+                test_loss += (self.criterion(output, y.long()) * y.shape[0]).item() # sum up batch loss
+
+                test_correct += (torch.sum(torch.argmax(output, dim=1) == y)).item()
+                test_num += y.shape[0]
+
+                y_prob.append(output.detach().cpu().numpy())
+                y_true.append(label_binarize(y.detach().cpu().numpy(), classes=np.arange(self.num_classes)))
+
+        # self.model.cpu()
+        # self.save_model(self.model, 'model')
+
+        y_prob = np.concatenate(y_prob, axis=0)
+        y_true = np.concatenate(y_true, axis=0)
+        try:
+            test_auc = metrics.roc_auc_score(y_true, y_prob, average='micro')
+            test_loss /= test_num
+        except ValueError:
+            test_auc, test_loss = 0.0, 0.0
+        test_acc = test_correct / test_num
+        return test_acc, test_auc, test_loss, test_num
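
`test_metrics` binarizes the integer labels so `roc_auc_score` can micro-average the AUC across classes; a small self-contained run of the same pattern with made-up scores:

    import numpy as np
    from sklearn.preprocessing import label_binarize
    from sklearn import metrics

    y = np.array([0, 2, 1, 2])
    y_true = label_binarize(y, classes=np.arange(3))  # one-hot, shape (4, 3)
    y_prob = np.array([[0.8, 0.1, 0.1],
                       [0.1, 0.2, 0.7],
                       [0.2, 0.6, 0.2],
                       [0.3, 0.3, 0.4]])
    print(metrics.roc_auc_score(y_true, y_prob, average="micro"))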

+ 60 - 0
system/flcore/clients/clientdyn.py

@@ -0,0 +1,60 @@
+import copy
+import torch
+import torch.nn as nn
+import numpy as np
+from flcore.clients.clientbase import Client
+from utils.tensor_utils import l2_squared_diff, model_dot_product
+
+class clientDyn(Client):
+    def __init__(self, args, id, train_samples, test_samples, **kwargs):
+        super().__init__(args, id, train_samples, test_samples, **kwargs)
+        
+        self.criterion = nn.CrossEntropyLoss()
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
+
+        self.alpha = args.alpha
+
+        self.untrained_global_model = None
+        self.old_grad = copy.deepcopy(self.model)
+        for p in self.old_grad.parameters():
+            p.requires_grad = False
+            p.data.zero_()
+        
+    def train(self):
+        trainloader = self.load_train_data()
+
+        # self.model.to(self.device)
+        self.model.train()
+
+        max_local_steps = self.local_steps
+
+        for step in range(max_local_steps):
+            for i, (x, y) in enumerate(trainloader):
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+                self.optimizer.zero_grad()
+                output = self.model(x)
+                loss = self.criterion(output, y)
+
+                if self.untrained_global_model is not None:
+                    loss += self.alpha/2 * l2_squared_diff(self.model, self.untrained_global_model)
+                    loss -= model_dot_product(self.model, self.old_grad)
+
+                loss.backward()
+                self.optimizer.step()
+
+        if self.untrained_global_model is not None:
+            for p_old_grad, p_cur, p_broadcast in zip(self.old_grad.parameters(), self.model.parameters(), self.untrained_global_model.parameters()):
+                p_old_grad.data -= self.alpha * (p_cur.data - p_broadcast.data)
+
+        # self.model.cpu()
+            
+    def set_parameters(self, model):
+        for new_param, old_param in zip(model.parameters(), self.model.parameters()):
+            old_param.data = new_param.data.clone()
+        self.untrained_global_model = copy.deepcopy(model)
+        for p in self.untrained_global_model.parameters():
+            p.requires_grad = False
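
The FedDyn objective above adds alpha/2 * ||theta - theta_g||^2 and subtracts <theta, old_grad>. The helpers come from `system/utils/tensor_utils.py` (listed above but outside this excerpt), so the following is a sketch of plausible implementations, not the committed ones:

    import torch

    def l2_squared_diff(model_a, model_b):
        # sum of squared parameter differences (assumed behavior)
        return sum(torch.sum((pa - pb) ** 2)
                   for pa, pb in zip(model_a.parameters(), model_b.parameters()))

    def model_dot_product(model_a, model_b, requires_grad=True):
        # parameter-wise dot product (assumed behavior)
        dot = sum(torch.sum(pa * pb)
                  for pa, pb in zip(model_a.parameters(), model_b.parameters()))
        return dot if requires_grad else dot.detach()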

+ 151 - 0
system/flcore/clients/clientfomo.py

@@ -0,0 +1,151 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import copy
+from flcore.clients.clientbase import Client
+from torch.utils.data import DataLoader
+from utils.data_utils import read_client_data
+
+
+class clientFomo(Client):
+    def __init__(self, args, id, train_samples, test_samples, **kwargs):
+        super().__init__(args, id, train_samples, test_samples, **kwargs)
+        
+        self.num_clients = args.num_clients
+        self.old_model = copy.deepcopy(self.model)
+        self.received_ids = []
+        self.received_models = []
+        self.weight_vector = torch.zeros(self.num_clients, device=self.device)
+
+        self.criterion = nn.CrossEntropyLoss()
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
+
+        self.val_ratio = 0.2
+        self.train_samples = self.train_samples * (1-self.val_ratio)
+
+
+    def train(self):
+        trainloader, val_loader = self.load_train_data()
+
+        self.aggregate_parameters(val_loader)
+        self.clone_model(self.model, self.old_model)
+
+        # self.model.to(self.device)
+        self.model.train()
+        
+        max_local_steps = self.local_steps
+
+        for step in range(max_local_steps):
+            for x, y in trainloader:
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+                self.optimizer.zero_grad()
+                output = self.model(x)
+                loss = self.criterion(output, y)
+                loss.backward()
+                self.optimizer.step()
+
+        # self.model.cpu()
+
+
+    def standard_train(self):
+        trainloader, val_loader = self.load_train_data()
+        self.model.train()
+        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
+        # 1 epoch
+        for i, (x, y) in enumerate(trainloader):
+            if type(x) == type([]):
+                x[0] = x[0].to(self.device)
+            else:
+                x = x.to(self.device)
+            y = y.to(self.device)
+            optimizer.zero_grad()
+            output = self.model(x)
+            loss = self.criterion(output, y)
+            loss.backward()
+            optimizer.step()
+
+    def load_train_data(self, batch_size=None):
+        if batch_size is None:
+            batch_size = self.batch_size
+        train_data = read_client_data(self.dataset, self.id, is_train=True)
+        val_idx = -int(self.val_ratio*len(train_data))
+        val_data = train_data[val_idx:]
+        train_data = train_data[:val_idx]
+
+        trainloader = DataLoader(train_data, self.batch_size, drop_last=True, shuffle=True)
+        val_loader = DataLoader(val_data, self.batch_size, drop_last=self.has_BatchNorm, shuffle=True)
+
+        return trainloader, val_loader
+    
+    def receive_models(self, ids, models):
+        self.received_ids = ids
+        self.received_models = models
+
+    def weight_cal(self, val_loader):
+        weight_list = []
+        L = self.recalculate_loss(self.old_model, val_loader)
+        for received_model in self.received_models:
+            params_dif = []
+            for param_n, param_i in zip(received_model.parameters(), self.old_model.parameters()):
+                params_dif.append((param_n - param_i).view(-1))
+            params_dif = torch.cat(params_dif)
+
+            d = L - self.recalculate_loss(received_model, val_loader)
+            if d > 0:
+                weight_list.append((d / (torch.norm(params_dif) + 1e-5)).item())
+            else:
+                weight_list.append(0.0)
+
+        if len(weight_list) != 0:
+            weight_list = np.array(weight_list)
+            weight_list /= (np.sum(weight_list) + 1e-10)
+
+        self.weight_vector_update(weight_list)
+        return torch.tensor(weight_list)
+        
+    def weight_vector_update(self, weight_list):
+        self.weight_vector = np.zeros(self.num_clients)
+        for w, id in zip(weight_list, self.received_ids):
+            self.weight_vector[id] += w.item()
+        self.weight_vector = torch.tensor(self.weight_vector).to(self.device)
+
+    def recalculate_loss(self, new_model, val_loader):
+        L = 0
+        for x, y in val_loader:
+            if type(x) == type([]):
+                x[0] = x[0].to(self.device)
+            else:
+                x = x.to(self.device)
+            y = y.to(self.device)
+            output = new_model(x)
+            loss = self.criterion(output, y)
+            L += (loss * y.shape[0]).item()
+        return L / len(val_loader.dataset)
+
+
+    def add_parameters(self, w, received_model):
+        for param, received_param in zip(self.model.parameters(), received_model.parameters()):
+            param.data += received_param.data.clone() * w
+        
+    def aggregate_parameters(self, val_loader):
+        weights = self.weight_cal(val_loader)
+
+        if len(weights) > 0 and sum(weights) > 0.0:
+            for param in self.model.parameters():
+                param.data.zero_()
+
+            for w, received_model in zip(weights, self.received_models):
+                self.add_parameters(w, received_model)
+
+    def weight_scale(self, weights):
+        weights = torch.maximum(weights, torch.tensor(0))
+        w_sum = torch.sum(weights)
+        if w_sum > 0:
+            weights = [w/w_sum for w in weights]
+            return torch.tensor(weights)
+        else:
+            return torch.tensor([])
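
`weight_cal` scores each received model by its validation-loss improvement over the previous local model, normalized by parameter distance, with negative scores clipped to zero; the same arithmetic on invented numbers:

    import numpy as np
    L_old = 0.90                               # old model's validation loss (toy)
    losses = np.array([0.70, 0.95, 0.80])      # received models' validation losses
    dists = np.array([2.0, 1.0, 4.0])          # ||theta_received - theta_old|| per model
    w = np.maximum(L_old - losses, 0.0) / (dists + 1e-5)
    w /= (w.sum() + 1e-10)
    print(w)                                   # ~[0.8, 0.0, 0.2]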

+ 53 - 0
system/flcore/clients/clientlgfedavg.py

@@ -0,0 +1,53 @@
+import copy
+import torch
+import torch.nn as nn
+import numpy as np
+from flcore.clients.clientbase import Client
+
+class clientLGFedAvg(Client):
+    def __init__(self, args, id, train_samples, test_samples, **kwargs):
+        super().__init__(args, id, train_samples, test_samples, **kwargs)
+        
+        self.criterion = nn.CrossEntropyLoss()
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
+
+    def train(self):
+        trainloader = self.load_train_data()
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
+        # self.model.to(self.device)
+        self.model.train()
+
+        max_local_steps = self.local_steps
+
+        for step in range(max_local_steps):
+            for i, (x, y) in enumerate(trainloader):
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+                for param in self.model.base.parameters():
+                    param.requires_grad = True
+                for param in self.model.predictor.parameters():
+                    param.requires_grad = False
+                self.optimizer.zero_grad()
+                output = self.model(x)
+                loss = self.criterion(output, y)
+                loss.backward()
+                self.optimizer.step()
+
+                for param in self.model.base.parameters():
+                    param.requires_grad = False
+                for param in self.model.predictor.parameters():
+                    param.requires_grad = True
+                self.optimizer.zero_grad()
+                output = self.model(x)
+                loss = self.criterion(output, y)
+                loss.backward()
+                self.optimizer.step()
+
+        # self.model.cpu()
+
+    def set_parameters(self, model):
+        for new_param, old_param in zip(model.parameters(), self.model.predictor.parameters()):
+            old_param.data = new_param.data.clone()

+ 39 - 0
system/flcore/clients/clientper.py

@@ -0,0 +1,39 @@
+import copy
+import torch
+import torch.nn as nn
+import numpy as np
+from flcore.clients.clientbase import Client
+
+class clientPer(Client):
+    def __init__(self, args, id, train_samples, test_samples, **kwargs):
+        super().__init__(args, id, train_samples, test_samples, **kwargs)
+        
+        self.criterion = nn.CrossEntropyLoss()
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
+
+    def train(self):
+        trainloader = self.load_train_data()
+
+        # self.model.to(self.device)
+        self.model.train()
+
+        max_local_steps = self.local_steps
+
+        for step in range(max_local_steps):
+            for i, (x, y) in enumerate(trainloader):
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+                self.optimizer.zero_grad()
+                output = self.model(x)
+                loss = self.criterion(output, y)
+                loss.backward()
+                self.optimizer.step()
+
+        # self.model.cpu()
+
+    def set_parameters(self, model):
+        for new_param, old_param in zip(model.parameters(), self.model.base.parameters()):
+            old_param.data = new_param.data.clone()

+ 85 - 0
system/flcore/clients/clientperavg.py

@@ -0,0 +1,85 @@
+import numpy as np
+from sklearn.preprocessing import label_binarize
+from sklearn import metrics
+import torch
+import copy
+import torch.nn as nn
+from flcore.clients.clientbase import Client
+
+class clientPerAvg(Client):
+    def __init__(self, args, id, train_samples, test_samples, **kwargs):
+        super().__init__(args, id, train_samples, test_samples, **kwargs)
+        self.alpha = args.alpha
+        self.beta = args.beta
+        self.criterion = nn.CrossEntropyLoss()
+        self.optimizer1 = torch.optim.SGD(self.model.parameters(), lr=self.alpha)
+        self.optimizer2 = torch.optim.SGD(self.model.parameters(), lr=self.beta)
+
+    def train(self):
+        trainloader = self.load_train_data(self.batch_size*2)
+        self.model.train()
+
+        max_local_steps = self.local_steps
+
+        for step in range(max_local_steps):  # local update
+            for X, Y in trainloader:
+                temp_model = copy.deepcopy(list(self.model.parameters()))
+
+                # step 1
+                if type(X) == type([]):
+                    x = [None, None]
+                    x[0] = X[0][:self.batch_size].to(self.device)
+                    x[1] = X[1][:self.batch_size]
+                else:
+                    x = X[:self.batch_size].to(self.device)
+                y = Y[:self.batch_size].to(self.device)
+                self.optimizer1.zero_grad()
+                output = self.model(x)
+                loss = self.criterion(output, y)
+                loss.backward()
+                self.optimizer1.step()
+
+                # step 2
+                if type(X) == type([]):
+                    x = [None, None]
+                    x[0] = X[0][self.batch_size:].to(self.device)
+                    x[1] = X[1][self.batch_size:]
+                else:
+                    x = X[self.batch_size:].to(self.device)
+                y = Y[self.batch_size:].to(self.device)
+                self.optimizer2.zero_grad()
+                output = self.model(x)
+                loss = self.criterion(output, y)
+                loss.backward()
+
+                # restore the model parameters to the one before first update
+                for old_param, new_param in zip(self.model.parameters(), temp_model):
+                    old_param.data = new_param.data.clone()
+
+                self.optimizer2.step()
+
+        # self.model.cpu()
+
+    def train_one_step(self):
+        trainloader = self.load_train_data(self.batch_size)
+        iter_trainloader = iter(trainloader)
+        self.model.train()
+        (x, y) = next(iter_trainloader)
+        if type(x) == type([]):
+            x[0] = x[0].to(self.device)
+        else:
+            x = x.to(self.device)
+        y = y.to(self.device)
+        self.optimizer2.zero_grad()
+        output = self.model(x)
+        loss = self.criterion(output, y)
+        loss.backward()
+        self.optimizer2.step()
+
+    # comment this method out when testing on new clients
+    def test_metrics(self, temp_model=None):
+        temp_model = copy.deepcopy(self.model)
+        self.train_one_step()
+        return_val = super().test_metrics(temp_model)
+        self.clone_model(temp_model, self.model)
+        return return_val
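
`train` above is the first-order Per-FedAvg (MAML-style) step: an inner update at rate alpha on the first half-batch, then a gradient computed after that step but applied to the pre-step weights at rate beta. A scalar sketch of the same mechanics:

    import torch
    alpha, beta = 0.1, 0.05
    w = torch.tensor([1.0], requires_grad=True)

    loss1 = (w - 2.0).pow(2).sum()                        # "first half-batch"
    g1, = torch.autograd.grad(loss1, w)
    w_inner = (w - alpha * g1).detach().requires_grad_(True)

    loss2 = (w_inner - 3.0).pow(2).sum()                  # "second half-batch"
    g2, = torch.autograd.grad(loss2, w_inner)
    with torch.no_grad():
        w -= beta * g2                                    # applied to the ORIGINAL w
    print(w)  # w is now ~1.18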

+ 67 - 0
system/flcore/clients/clientpfedme.py

@@ -0,0 +1,67 @@
+import numpy as np
+import copy
+import torch
+import torch.nn as nn
+from flcore.optimizers.fedoptimizer import pFedMeOptimizer
+from flcore.clients.clientbase import Client
+
+
+class clientpFedMe(Client):
+    def __init__(self, args, id, train_samples, test_samples, **kwargs):
+        super().__init__(args, id, train_samples, test_samples, **kwargs)
+
+        self.lambdaa = args.lambdaa
+        self.K = args.K
+        self.personalized_learning_rate = args.p_learning_rate
+
+        # these parameters are for personalized federated learning.
+        self.local_params = copy.deepcopy(list(self.model.parameters()))
+        self.personalized_params = copy.deepcopy(list(self.model.parameters()))
+        self.criterion = nn.CrossEntropyLoss()
+
+    def train(self):
+        trainloader = self.load_train_data()
+        self.model.train()
+
+        max_local_steps = self.local_steps
+
+        self.optimizer = pFedMeOptimizer(self.model.parameters(),
+                                         local_model=self.local_params,
+                                         lr=self.personalized_learning_rate,
+                                         lambdaa=self.lambdaa)
+        for step in range(max_local_steps):  # local update
+            for x, y in trainloader:
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+
+                # K is number of personalized steps
+                for i in range(self.K):
+                    self.optimizer.zero_grad()
+                    output = self.model(x)
+                    loss = self.criterion(output, y)
+                    loss.backward()
+                    # finding approximate theta
+                    self.personalized_params = self.optimizer.step()
+
+                # update local weight after finding approximate theta
+                for new_param, localweight in zip(self.personalized_params, self.local_params):
+                    localweight = localweight.to(self.device)
+                    localweight.data = localweight.data - self.lambdaa * self.learning_rate * (localweight.data - new_param.data)
+
+        # self.model.cpu()
+
+        self.update_parameters(self.model, self.local_params)
+
+
+    # comment this method out when testing on new clients
+    def get_eval_model(self, temp_model=None):
+        self.update_parameters(self.model, self.personalized_params)
+        return self.model
+
+    def set_parameters(self, model):
+        for new_param, old_param, local_param in zip(model.parameters(), self.model.parameters(), self.local_params):
+            old_param.data = new_param.data.clone()
+            local_param.data = new_param.data.clone()
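
After the K inner steps, the local weights move toward the personalized solution theta with step size lambda * eta: w <- w - lambda * eta * (w - theta). The same update on scalars (values invented):

    lambdaa, lr = 15.0, 0.005     # lambdaa and self.learning_rate (toy values)
    w, theta = 1.0, 0.4           # local weight and personalized weight
    w = w - lambdaa * lr * (w - theta)
    print(w)                      # 1.0 - 0.075 * 0.6 = 0.955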

+ 88 - 0
system/flcore/clients/clientpgfed.py

@@ -0,0 +1,88 @@
+import torch
+import numpy as np
+import copy
+import torch.nn as nn
+from flcore.clients.clientbase import Client
+from utils.tensor_utils import model_dot_product
+
+class clientPGFed(Client):
+    def __init__(self, args, id, train_samples, test_samples, **kwargs):
+        super().__init__(args, id, train_samples, test_samples, **kwargs)
+        self.criterion = nn.CrossEntropyLoss()
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
+        self.lambdaa = args.lambdaa  # eta_2 in the paper: the learning rate for a_i
+        self.latest_grad = copy.deepcopy(self.model)
+        self.prev_loss_minuses = {}
+        self.prev_mean_grad = None
+        self.prev_convex_comb_grad = None
+        self.a_i = None
+
+    def train(self):
+        trainloader = self.load_train_data()
+        self.model.train()
+        max_local_steps = self.local_steps
+
+        for step in range(max_local_steps):
+            for i, (x, y) in enumerate(trainloader):
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+                self.optimizer.zero_grad()
+                output = self.model(x)
+                loss = self.criterion(output, y)
+                loss.backward()
+
+                if self.prev_convex_comb_grad is not None:
+                    for p_m, p_prev_conv in zip(self.model.parameters(), self.prev_convex_comb_grad.parameters()):
+                        p_m.grad.data += p_prev_conv.data
+                    dot_prod = model_dot_product(self.model, self.prev_mean_grad, requires_grad=False)
+                    self.update_a_i(dot_prod)
+                self.optimizer.step()
+        
+        # get loss_minus and latest_grad
+        self.loss_minus = 0.0
+        test_num = 0
+        self.optimizer.zero_grad()
+        for i, (x, y) in enumerate(trainloader):
+            if type(x) == type([]):
+                x[0] = x[0].to(self.device)
+            else:
+                x = x.to(self.device)
+            y = y.to(self.device)
+            test_num += y.shape[0]
+            output = self.model(x)
+            loss = self.criterion(output, y)
+            self.loss_minus += (loss * y.shape[0]).item()
+            loss.backward()
+
+        self.loss_minus /= test_num
+        for p_l, p in zip(self.latest_grad.parameters(), self.model.parameters()):
+            p_l.data = p.grad.data.clone() / len(trainloader)
+        self.loss_minus -= model_dot_product(self.latest_grad, self.model, requires_grad=False)
+
+    def get_eval_model(self, temp_model=None):
+        model = self.model if temp_model is None else temp_model
+        return model
+
+    def update_a_i(self, dot_prod):
+        for clt_j, mu_loss_minus in self.prev_loss_minuses.items():
+            self.a_i[clt_j] -= self.lambdaa * (mu_loss_minus + dot_prod)
+            self.a_i[clt_j] = max(self.a_i[clt_j], 0.0)
+    
+    def set_model(self, old_m, new_m, momentum=0.0):
+        for p_old, p_new in zip(old_m.parameters(), new_m.parameters()):
+            p_old.data = (1 - momentum) * p_new.data.clone() + momentum * p_old.data.clone()
+
+    def set_prev_mean_grad(self, mean_grad):
+        if self.prev_mean_grad is None:
+            self.prev_mean_grad = copy.deepcopy(mean_grad)
+        else:
+            self.set_model(self.prev_mean_grad, mean_grad)
+        
+    def set_prev_convex_comb_grad(self, convex_comb_grad, momentum=0.0):
+        if self.prev_convex_comb_grad is None:
+            self.prev_convex_comb_grad = copy.deepcopy(convex_comb_grad)
+        else:
+            self.set_model(self.prev_convex_comb_grad, convex_comb_grad, momentum=momentum)
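The auxiliary weights a_i above take a projected gradient step, a_i[j] <- max(a_i[j] - lambdaa * (loss_minus_j + dot_prod), 0), with lambdaa playing the role of eta_2. A self-contained trace with plain floats (all values illustrative):

def update_a_i_sketch(a_i, prev_loss_minuses, dot_prod, lambdaa):
    # projected (non-negative) gradient step on each auxiliary weight
    for j, loss_minus_j in prev_loss_minuses.items():
        a_i[j] = max(a_i[j] - lambdaa * (loss_minus_j + dot_prod), 0.0)
    return a_i

update_a_i_sketch({0: 0.5, 1: 0.5}, {0: 0.1, 1: -0.2}, dot_prod=0.05, lambdaa=0.1)
# -> {0: 0.485, 1: 0.515}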

+ 67 - 0
system/flcore/clients/clientrep.py

@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+from flcore.clients.clientbase import Client
+
+class clientRep(Client):
+    def __init__(self, args, id, train_samples, test_samples, **kwargs):
+        super().__init__(args, id, train_samples, test_samples, **kwargs)
+        
+        self.criterion = nn.CrossEntropyLoss()
+        self.optimizer = torch.optim.SGD(self.model.base.parameters(), lr=self.learning_rate)
+        self.poptimizer = torch.optim.SGD(self.model.predictor.parameters(), lr=self.learning_rate)
+
+        self.plocal_steps = args.plocal_steps
+
+    def train(self):
+        trainloader = self.load_train_data()
+
+        # self.model.to(self.device)
+        self.model.train()
+
+        for param in self.model.base.parameters():
+            param.requires_grad = False
+        for param in self.model.predictor.parameters():
+            param.requires_grad = True
+
+        for step in range(self.plocal_steps):
+            for i, (x, y) in enumerate(trainloader):
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+                self.poptimizer.zero_grad()
+                output = self.model(x)
+                loss = self.criterion(output, y)
+                loss.backward()
+                self.poptimizer.step()
+                
+        max_local_steps = self.local_steps
+
+        for param in self.model.base.parameters():
+            param.requires_grad = True
+        for param in self.model.predictor.parameters():
+            param.requires_grad = False
+
+        for step in range(max_local_steps):
+            for i, (x, y) in enumerate(trainloader):
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+                self.optimizer.zero_grad()
+                output = self.model(x)
+                loss = self.criterion(output, y)
+                loss.backward()
+                self.optimizer.step()
+
+        # self.model.cpu()
+            
+    def set_parameters(self, base):
+        for new_param, old_param in zip(base.parameters(), self.model.base.parameters()):
+            old_param.data = new_param.data.clone()
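The alternating freeze pattern above (train the predictor with the base frozen, then the base with the predictor frozen) is the core of FedRep. A small helper that captures the pattern (hypothetical, not part of this commit):

def set_requires_grad(module, flag):
    # freeze (flag=False) or unfreeze (flag=True) every parameter of a module
    for p in module.parameters():
        p.requires_grad = flag

# usage sketch: set_requires_grad(model.base, False); set_requires_grad(model.predictor, True)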

+ 87 - 0
system/flcore/clients/clientrod.py

@@ -0,0 +1,87 @@
+import copy
+import torch
+import torch.nn as nn
+import numpy as np
+from flcore.clients.clientbase import Client
+import torch.nn.functional as F
+
+class RodEvalModel(nn.Module):
+    def __init__(self, glob_m, pers_pred):
+        super(RodEvalModel, self).__init__()
+        self.glob_m = glob_m
+        self.pers_pred = pers_pred
+    def forward(self, x):
+        rep = self.glob_m.base(x)
+        out = self.glob_m.predictor(rep)
+        out += self.pers_pred(rep)
+        return out
+
+class clientRoD(Client):
+    def __init__(self, args, id, train_samples, test_samples, **kwargs):
+        super().__init__(args, id, train_samples, test_samples, **kwargs)
+        
+        self.criterion = nn.CrossEntropyLoss()
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
+        
+        self.pred = copy.deepcopy(self.model.predictor)
+        self.opt_pred = torch.optim.SGD(self.pred.parameters(), lr=self.learning_rate)
+
+        self.sample_per_class = torch.zeros(self.num_classes)
+        trainloader = self.load_train_data()
+        for x, y in trainloader:
+            for yy in y:
+                self.sample_per_class[yy.item()] += 1
+        self.sample_per_class = self.sample_per_class / torch.sum(self.sample_per_class)
+
+
+    def train(self):
+        trainloader = self.load_train_data()
+
+        # self.model.to(self.device)
+        self.model.train()
+
+        max_local_steps = self.local_steps
+
+        for step in range(max_local_steps):
+            for i, (x, y) in enumerate(trainloader):
+                if type(x) == type([]):
+                    x[0] = x[0].to(self.device)
+                else:
+                    x = x.to(self.device)
+                y = y.to(self.device)
+                self.optimizer.zero_grad()
+                rep = self.model.base(x)
+                out_g = self.model.predictor(rep)
+                loss_bsm = balanced_softmax_loss(y, out_g, self.sample_per_class)
+                loss_bsm.backward()
+                self.optimizer.step()
+                
+                self.opt_pred.zero_grad()
+                out_p = self.pred(rep.detach())
+                loss = self.criterion(out_g.detach() + out_p, y)
+                loss.backward()
+                self.opt_pred.step()
+
+        # self.model.cpu()
+
+    # comment this method out when testing on new clients
+    def get_eval_model(self, temp_model=None):
+        # temp_model is the current round global model (after aggregation)
+        return RodEvalModel(temp_model, self.pred)
+
+# https://github.com/jiawei-ren/BalancedMetaSoftmax-Classification
+def balanced_softmax_loss(labels, logits, sample_per_class, reduction="mean"):
+    """Compute the Balanced Softmax Loss between `logits` and the ground truth `labels`.
+    Args:
+      labels: A int tensor of size [batch].
+      logits: A float tensor of size [batch, no_of_classes].
+      sample_per_class: A int tensor of size [no of classes].
+      reduction: string. One of "none", "mean", "sum"
+    Returns:
+      loss: A float tensor. Balanced Softmax Loss.
+    """
+    spc = sample_per_class.type_as(logits)
+    spc = spc.unsqueeze(0).expand(logits.shape[0], -1)
+    logits = logits + spc.log()
+    loss = F.cross_entropy(input=logits, target=labels, reduction=reduction)
+    return loss
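Quick sanity check for balanced_softmax_loss (with the function above in scope): under uniform class counts the log-prior shift is the same for every class, and since softmax is shift-invariant the loss reduces to plain cross-entropy. Shapes and values below are illustrative:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)            # [batch, num_classes]
labels = torch.randint(0, 10, (4,))    # [batch]
spc = torch.ones(10)                   # uniform counts: constant log-prior
assert torch.allclose(balanced_softmax_loss(labels, logits, spc),
                      F.cross_entropy(logits, labels))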

+ 21 - 0
system/flcore/optimizers/fedoptimizer.py

@@ -0,0 +1,21 @@
+from torch.optim import Optimizer
+
+class pFedMeOptimizer(Optimizer):
+    def __init__(self, params, local_model=None, lr=0.01, lambdaa=0.1, mu=0.001):
+        if lr < 0.0:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        defaults = dict(lr=lr, lambdaa=lambdaa, mu=mu)
+        super(pFedMeOptimizer, self).__init__(params, defaults)
+        self.weight_update = local_model.copy()
+    def step(self):
+        group = None
+        for group in self.param_groups:
+            for p, localweight in zip(group['params'], self.weight_update):
+                localweight = localweight.to(p)
+                # approximate local model
+                p.data = p.data - group['lr'] * (p.grad.data + group['lambdaa'] * (p.data - localweight.data) + group['mu'] * p.data)
+
+        return group['params']
+
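Usage sketch: as in clientpFedMe.train above, the optimizer is rebuilt each round around the current local weights; model, x, y, criterion, and local_params are assumed to be in scope, and lambdaa=15.0 is only an illustrative value:

opt = pFedMeOptimizer(model.parameters(), local_model=local_params,
                      lr=0.01, lambdaa=15.0)
opt.zero_grad()
loss = criterion(model(x), y)
loss.backward()
personalized_params = opt.step()  # returns the updated parameter list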

+ 44 - 0
system/flcore/servers/serverapfl.py

@@ -0,0 +1,44 @@
+from flcore.clients.clientapfl import clientAPFL
+from flcore.servers.serverbase import Server
+import logging
+import os
+
+
+class APFL(Server):
+    def __init__(self, args, times):
+        super().__init__(args, times)
+        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}, alpha:{args.alpha:.2f}"
+        clientObj = clientAPFL
+        self.message_hp_dash = self.message_hp.replace(", ", "-")
+        self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
+
+        self.set_clients(args, clientObj)
+
+        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
+        print("Finished creating server and clients.")
+
+    def train(self):
+        for i in range(self.global_rounds):
+            self.selected_clients = self.select_clients()
+            self.send_models()
+            
+            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
+            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
+            for client in self.selected_clients:
+                client.train()
+
+            self.receive_models()
+            self.aggregate_parameters()
+
+            if i % self.eval_gap == 0:
+                print("==> Evaluating personalized models...", flush=True)
+                self.send_models(mode="all")
+                self.evaluate()
+                if i == 80 and self.check_early_stopping():
+                    print("==> Performance is too low. Executing early stop.")
+                    break
+
+        print(f"\n==> Best mean personalized accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
+
+        self.save_results(fn=self.hist_result_fn)
+        message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
+        logging.info(self.message_hp + message_res)

+ 55 - 0
system/flcore/servers/serveravg.py

@@ -0,0 +1,55 @@
+import time
+from flcore.clients.clientavg import clientAVG
+from flcore.servers.serverbase import Server
+import os
+import logging
+import torch
+
+class FedAvg(Server):
+    def __init__(self, args, times):
+        super().__init__(args, times)
+        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
+        clientObj = clientAVG
+        self.message_hp_dash = self.message_hp.replace(", ", "-")
+        self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
+        self.last_ckpt_fn = os.path.join(self.ckpt_dir, "FedAvg-cifar10-100clt.pt")
+
+        self.set_clients(args, clientObj)
+
+        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
+        print("Finished creating server and clients.")
+
+        self.Budget = []
+
+    def train(self):
+        for i in range(self.global_rounds):
+            self.selected_clients = self.select_clients()
+            self.send_models()
+            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
+            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
+            for client in self.selected_clients:
+                client.train()
+
+            self.receive_models()
+            self.aggregate_parameters()
+
+            if i%self.eval_gap == 0:
+                print("==> Evaluating global models...", flush=True)
+                self.send_models(mode="all")
+                # self.evaluate(mode="global")
+                self.evaluate()
+                if i == 80 and self.check_early_stopping():
+                    print("==> Performance is too low. Executing early stop.")
+                    break
+
+        print(f"==> Best mean accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
+
+
+        self.save_results(fn=self.hist_result_fn)
+        message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
+        logging.info(self.message_hp + message_res)
+        # state = {
+        #     "global_model": self.global_model.cpu().state_dict(),
+        #     "clients_test_accs": self.clients_test_accs[-1]
+        # }
+        # self.save_global_model(model_path=self.last_ckpt_fn, state=state)
+

+ 92 - 0
system/flcore/servers/serverbabu.py

@@ -0,0 +1,92 @@
+from flcore.clients.clientbabu import clientBABU
+from flcore.servers.serverbase import Server
+import torch
+import os
+import sys
+import logging
+
+
+class FedBABU(Server):
+    def __init__(self, args, times):
+        super().__init__(args, times)
+        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}, alpha:{args.alpha:.5f}"
+        self.message_hp_dash = self.message_hp.replace(", ", "-")
+        self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
+
+        self.set_clients(args, clientBABU)
+
+        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
+        print("Finished creating server and clients.")
+
+        # self.load_model()
+
+    def train(self):
+        for i in range(self.global_rounds):
+            self.selected_clients = self.select_clients()
+            self.send_models()
+            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
+            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
+            for client in self.selected_clients:
+                client.train()
+            
+            self.receive_models()
+            self.aggregate_parameters()
+            if i%self.eval_gap == 0:
+                print("==> Evaluating global models...", flush=True)
+                self.send_models(mode="all")
+                self.evaluate(mode="global")
+                if i > 40 and self.check_early_stopping():
+                    print("==> Performance is too low. Executing early stop.")
+                    break
+            
+        print("\n--------------------- Fine-tuning ----------------------")
+        self.send_fine_tune_models(mode="all")
+        for client in self.clients:
+            client.fine_tune()
+        print("------------- Evaluating fine-tuned models -------------")
+        self.evaluate(mode="personalized")
+        print(f"==> Mean personalized accuracy: {self.rs_test_acc[-1]*100:.2f}", flush=True)
+        message_res = f"\ttest_acc:{self.rs_test_acc[-1]:.6f}"
+        logging.info(self.message_hp + message_res)
+
+    def aggregate_parameters(self):
+        assert (len(self.uploaded_models) > 0)
+        for param in self.global_model.parameters():
+            param.data.zero_()
+            
+        for w, client_model in zip(self.uploaded_weights, self.uploaded_models):
+            # self.uploaded_models is a list of client.model.base modules
+            self.add_parameters(w, client_model)
+        # after aggregation, self.global_model still holds base and predictor; only the base
+        # receives the averaged weights (zip stops at the shorter list), so the predictor stays zeroed
+
+    def send_fine_tune_models(self, mode="selected"):
+        if mode == "selected":
+            assert (len(self.selected_clients) > 0)
+            for client in self.selected_clients:
+                client.set_fine_tune_parameters(self.global_model)
+        elif mode == "all":
+            for client in self.clients:
+                client.set_fine_tune_parameters(self.global_model)
+        else:
+            raise NotImplementedError
+
+    def receive_models(self):
+        assert (len(self.selected_clients) > 0)
+
+        self.uploaded_weights = []
+        tot_samples = 0
+        self.uploaded_ids = []
+        self.uploaded_models = []
+        for client in self.selected_clients:
+            self.uploaded_weights.append(client.train_samples)
+            tot_samples += client.train_samples
+            self.uploaded_ids.append(client.id)
+            self.uploaded_models.append(client.model.base)
+        for i, w in enumerate(self.uploaded_weights):
+            self.uploaded_weights[i] = w / tot_samples
+
+    def load_model(self, model_path=None):
+        if model_path is None:
+            model_path = os.path.join("models", self.dataset)
+            model_path = os.path.join(model_path, self.algorithm + "_server" + ".pt")
+        assert (os.path.exists(model_path))
+        return torch.load(model_path)
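The overridden aggregate_parameters above relies on zip() truncation: self.global_model carries base and predictor parameters while each uploaded model is a base only, so add_parameters pairs and updates only the base parameters and leaves the zeroed predictor untouched. A minimal demonstration of that behavior (layer sizes illustrative):

import torch.nn as nn

full = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))  # stands in for base + predictor
base = nn.Linear(4, 4)                                   # stands in for a client upload
pairs = list(zip(full.parameters(), base.parameters()))
assert len(pairs) == 2  # only the first layer's weight and bias get paired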

+ 273 - 0
system/flcore/servers/serverbase.py

@@ -0,0 +1,273 @@
+import torch
+import os
+import numpy as np
+import h5py
+import copy
+import time
+import sys
+import random
+import logging
+
+from utils.data_utils import read_client_data
+
+
+class Server(object):
+    def __init__(self, args, times):
+        # Set up the main attributes
+        self.device = args.device
+        self.dataset = args.dataset
+        self.global_rounds = args.global_rounds
+        self.local_steps = args.local_steps
+        self.batch_size = args.batch_size
+        self.learning_rate = args.local_learning_rate
+        self.global_model = copy.deepcopy(args.model)
+        
+        self.num_clients = args.num_clients
+        self.join_ratio = args.join_ratio
+        self.join_clients = int(self.num_clients * self.join_ratio)
+        self.algorithm = args.algorithm
+        self.goal = args.goal
+        self.top_cnt = 100
+        self.best_mean_test_acc = -1.0
+        self.clients = []
+        self.selected_clients = []
+
+        self.uploaded_weights = []
+        self.uploaded_ids = []
+        self.uploaded_models = []
+
+        self.rs_test_acc = []
+        self.rs_test_auc = []
+        self.rs_test_loss = []
+        self.rs_train_loss = []
+        self.clients_test_accs = []
+        self.domain_mean_test_accs = []
+
+        self.times = times
+        self.eval_gap = args.eval_gap
+
+        self.set_seed(self.times)
+        self.set_path(args)
+
+        # infer the Dirichlet alpha the dataset was generated with (used in the run name)
+        if self.dataset.startswith("cifar"):
+            dir_alpha = 0.3
+        elif self.dataset == "organamnist25":
+            dir_alpha = 1.0
+        elif self.dataset.startswith("organamnist"):
+            dir_alpha = 0.3
+        else:
+            dir_alpha = float("nan")
+
+        self.actual_dataset = f"{self.dataset}-{self.num_clients}clients_alpha{dir_alpha:.1f}"
+        logger_fn = os.path.join(args.log_dir, f"{args.algorithm}-{self.actual_dataset}.log")
+        self.set_logger(save=True, fn=logger_fn)
+
+    def set_seed(self, seed):
+        np.random.seed(seed)
+        random.seed(seed)
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(seed)
+            torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.enabled = False
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+
+    def set_logger(self, save=False, fn=None):
+        if save:
+            fn = "testlog.log" if fn == None else fn
+            logging.basicConfig(
+                filename=fn,
+                filemode="a",
+                format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+                level=logging.DEBUG
+            )
+        else:
+            logging.basicConfig(
+                format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+                level=logging.DEBUG
+            )
+
+    def set_path(self, args):
+        self.hist_dir = args.hist_dir
+        self.log_dir = args.log_dir
+        self.ckpt_dir = args.ckpt_dir
+        if not os.path.exists(args.hist_dir):
+            os.makedirs(args.hist_dir)
+        if not os.path.exists(args.log_dir):
+            os.makedirs(args.log_dir)
+        if not os.path.exists(args.ckpt_dir):
+            os.makedirs(args.ckpt_dir)
+
+    def set_clients(self, args, clientObj):
+        self.new_clients = None
+        for i in range(self.num_clients):
+            train_data = read_client_data(self.dataset, i, is_train=True)
+            test_data = read_client_data(self.dataset, i, is_train=False)
+            client = clientObj(args, 
+                            id=i, 
+                            train_samples=len(train_data), 
+                            test_samples=len(test_data))
+            self.clients.append(client)
+
+    def select_clients(self):
+        selected_clients = list(np.random.choice(self.clients, self.join_clients, replace=False))
+        return selected_clients
+
+    def send_models(self, mode="selected"):
+        if mode == "selected":
+            assert (len(self.selected_clients) > 0)
+            for client in self.selected_clients:
+                client.set_parameters(self.global_model)
+        elif mode == "all":
+            for client in self.clients:
+                client.set_parameters(self.global_model)
+        else:
+            raise NotImplementedError
+
+    def receive_models(self):
+        assert (len(self.selected_clients) > 0)
+
+        self.uploaded_weights = []
+        tot_samples = 0
+        self.uploaded_ids = []
+        self.uploaded_models = []
+        for client in self.selected_clients:
+            self.uploaded_weights.append(client.train_samples)
+            tot_samples += client.train_samples
+            self.uploaded_ids.append(client.id)
+            self.uploaded_models.append(client.model)
+        for i, w in enumerate(self.uploaded_weights):
+            self.uploaded_weights[i] = w / tot_samples
+
+    def aggregate_parameters(self):
+        assert (len(self.uploaded_models) > 0)
+
+        self.global_model = copy.deepcopy(self.uploaded_models[0])
+        for param in self.global_model.parameters():
+            param.data.zero_()
+            
+        for w, client_model in zip(self.uploaded_weights, self.uploaded_models):
+            self.add_parameters(w, client_model)
+
+    def add_parameters(self, w, client_model):
+        for server_param, client_param in zip(self.global_model.parameters(), client_model.parameters()):
+            server_param.data += client_param.data.clone() * w
+
+    def prepare_global_model(self):
+        pass
+
+    def reset_records(self):
+        self.best_mean_test_acc = 0.0
+        self.clients_test_accs = []
+        self.rs_test_acc = []
+        self.rs_test_auc = []
+        self.rs_test_loss = []
+
+    def train_new_clients(self, epochs=20):
+        self.global_model = self.global_model.to(self.device)
+        self.clients = self.new_clients
+        self.reset_records()
+        for c in self.clients:
+            c.model = copy.deepcopy(self.global_model)
+        self.evaluate()
+        for epoch_idx in range(epochs):
+            for c in self.clients:
+                c.standard_train()
+            print(f"==> New clients epoch: [{epoch_idx+1:2d}/{epochs}] | Evaluating local models...", flush=True)
+            self.evaluate()
+        print(f"==> Best mean global accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
+        self.save_results(fn=self.hist_result_fn)
+        message_res = f"\tnew_clients_test_acc:{self.best_mean_test_acc:.6f}"
+        logging.info(self.message_hp + message_res)
+
+    def save_global_model(self, model_path=None, state=None):
+        if model_path is None:
+            model_path = os.path.join("models", self.dataset)
+            if not os.path.exists(model_path):
+                os.makedirs(model_path)
+            model_path = os.path.join(model_path, self.algorithm + "_server" + ".pt")
+        if state is None:
+            torch.save({"global_model": self.global_model.cpu().state_dict()}, model_path)
+        else:
+            torch.save(state, model_path)
+
+    def load_model(self, model_path=None):
+        if model_path is None:
+            model_path = os.path.join("models", self.dataset)
+            model_path = os.path.join(model_path, self.algorithm + "_server" + ".pt")
+        assert (os.path.exists(model_path))
+        self.global_model = torch.load(model_path)
+        
+    def save_results(self, fn=None):
+        if fn is None:
+            algo = self.dataset + "_" + self.algorithm
+            result_path = self.hist_dir
+
+        if (len(self.rs_test_acc)):
+            if fn is None:
+                algo = algo + "_" + self.goal + "_" + str(self.times+1)
+                file_path = os.path.join(result_path, "{}.h5".format(algo))
+            else:
+                file_path = fn
+            print("File path: " + file_path)
+
+            with h5py.File(file_path, 'w') as hf:
+                hf.create_dataset('rs_test_acc', data=self.rs_test_acc)
+                hf.create_dataset('rs_test_auc', data=self.rs_test_auc)
+                hf.create_dataset('rs_test_loss', data=self.rs_test_loss)
+                hf.create_dataset('clients_test_accs', data=self.clients_test_accs)
+                # hf.create_dataset('rs_train_loss', data=self.rs_train_loss)
+
+
+    def test_metrics(self, temp_model=None):
+        """ A personalized evaluation scheme (test_acc's do not average based on num_samples) """
+        test_accs, test_aucs, test_losses, test_nums = [], [], [], []
+        for c in self.clients:
+            test_acc, test_auc, test_loss, test_num = c.test_metrics(temp_model)  # test_acc, test_num, test_auc
+            test_accs.append(test_acc)
+            test_aucs.append(test_auc)
+            test_losses.append(test_loss)
+            test_nums.append(test_num)
+        ids = [c.id for c in self.clients]
+        return ids, test_accs, test_aucs, test_losses, test_nums
+
+    # evaluate selected clients
+    def evaluate(self, temp_model=None, mode="personalized"):
+        ids, test_accs, test_aucs, test_losses, test_nums = self.test_metrics(temp_model)
+        self.clients_test_accs.append(copy.deepcopy(test_accs))
+        if mode == "personalized":
+            mean_test_acc, mean_test_auc, mean_test_loss = np.mean(test_accs), np.mean(test_aucs), np.mean(test_losses)
+        elif mode == "global":
+            mean_test_acc, mean_test_auc, mean_test_loss = np.average(test_accs, weights=test_nums), np.average(test_aucs, weights=test_nums), np.average(test_losses, weights=test_nums)
+        else:
+            raise NotImplementedError
+        # for Office-Home, record per-domain mean accuracies whenever the best overall accuracy improves
+        if self.dataset.startswith("Office-home") and (mean_test_acc > self.best_mean_test_acc):
+            self.best_mean_test_acc = mean_test_acc
+            self.domain_mean_test_accs = np.mean(np.array(test_accs).reshape(4, -1), axis=1)
+        self.best_mean_test_acc = max(mean_test_acc, self.best_mean_test_acc)
+        self.rs_test_acc.append(mean_test_acc)
+        self.rs_test_auc.append(mean_test_auc)
+        self.rs_test_loss.append(mean_test_loss)
+        print(f"==> test_loss: {mean_test_loss:.5f} | mean_test_accs: {mean_test_acc*100:.2f}% | best_acc: {self.best_mean_test_acc*100:.2f}%\n")
+
+    def check_early_stopping(self, thresh=0.0):
+        # Early stopping
+        if thresh == 0.0:
+            if (self.dataset == "cifar100"):
+                thresh = 0.2
+            elif (self.dataset == "cifar10"):
+                thresh = 0.6
+            elif (self.dataset.startswith("organamnist")):
+                thresh = 0.8
+            else:
+                thresh = 0.23
+        return (self.rs_test_acc[-1] < thresh) and (self.rs_test_acc[-2] < thresh) and (self.rs_test_acc[-3] < thresh) and (self.rs_test_acc[-4] < thresh)
+

+ 87 - 0
system/flcore/servers/serverdyn.py

@@ -0,0 +1,87 @@
+import copy
+import torch
+from flcore.clients.clientdyn import clientDyn
+from flcore.servers.serverbase import Server
+import os
+import logging
+
+class FedDyn(Server):
+    def __init__(self, args, times):
+        super().__init__(args, times)
+        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}, alpha:{args.alpha:.5f}"
+        clientObj = clientDyn
+        self.message_hp_dash = self.message_hp.replace(", ", "-")
+        self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
+
+        self.set_clients(args, clientObj)
+
+        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
+        print("Finished creating server and clients.")
+
+        # self.load_model()
+        self.Budget = []
+
+        self.alpha = args.alpha
+        
+        self.server_state = copy.deepcopy(args.model)
+        for param in self.server_state.parameters():
+            param.data.zero_()
+
+    def train(self):
+        for i in range(self.global_rounds):
+            self.selected_clients = self.select_clients()
+            self.send_models()
+            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
+            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
+            for client in self.selected_clients:
+                client.train()
+
+            self.receive_models()
+            self.update_server_state()
+            self.aggregate_parameters()
+
+            if i%self.eval_gap == 0:
+                print("==> Evaluating global models...", flush=True)
+                self.send_models(mode="all")
+                # self.evaluate(mode="global")
+                self.evaluate()
+                if i == 80 and self.check_early_stopping():
+                    print("==> Performance is too low. Executing early stop.")
+                    break
+
+        print(f"==> Best mean global accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
+
+        self.save_results(fn=self.hist_result_fn)
+        message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
+        logging.info(self.message_hp + message_res)
+        # self.save_global_model()
+
+    def add_parameters(self, client_model):
+        for server_param, client_param in zip(self.global_model.parameters(), client_model.parameters()):
+            server_param.data += client_param.data.clone() / self.join_clients
+
+    def aggregate_parameters(self):
+        assert (len(self.uploaded_models) > 0)
+
+        self.global_model = copy.deepcopy(self.uploaded_models[0])
+        for param in self.global_model.parameters():
+            param.data.zero_()
+            
+        for client_model in self.uploaded_models:
+            self.add_parameters(client_model)
+
+        for server_param, state_param in zip(self.global_model.parameters(), self.server_state.parameters()):
+            server_param.data -= (1/self.alpha) * state_param.data
+
+    def update_server_state(self):
+        assert (len(self.uploaded_models) > 0)
+
+        model_delta = copy.deepcopy(self.uploaded_models[0])
+        for param in model_delta.parameters():
+            param.data.zero_()
+
+        for client_model in self.uploaded_models:
+            for server_param, client_param, delta_param in zip(self.global_model.parameters(), client_model.parameters(), model_delta.parameters()):
+                delta_param.data += (client_param - server_param) / self.num_clients
+
+        for state_param, delta_param in zip(self.server_state.parameters(), model_delta.parameters()):
+            state_param.data -= self.alpha * delta_param
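In tensor form, update_server_state and aggregate_parameters above implement h <- h - alpha * mean_i(theta_i - theta), followed by theta <- mean(theta_i) - (1/alpha) * h. A flat-vector sketch, assuming for simplicity that every client is selected (all tensors illustrative):

import torch

alpha, N = 0.01, 4
theta = torch.zeros(3)                        # previous global model
clients = [torch.randn(3) for _ in range(N)]  # uploaded client models
h = torch.zeros(3)                            # server_state
h -= alpha * sum(c - theta for c in clients) / N
new_global = sum(clients) / N - (1 / alpha) * h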

+ 97 - 0
system/flcore/servers/serverfomo.py

@@ -0,0 +1,97 @@
+import torch
+import copy
+import random
+import os
+import logging
+import numpy as np
+from flcore.clients.clientfomo import clientFomo
+from flcore.servers.serverbase import Server
+
+
+class FedFomo(Server):
+    def __init__(self, args, times):
+        super().__init__(args, times)
+        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
+        clientObj = clientFomo
+        self.message_hp_dash = self.message_hp.replace(", ", "-")
+        self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
+
+        self.set_clients(args, clientObj)
+
+        self.P = torch.diag(torch.ones(self.num_clients, device=self.device))
+        self.uploaded_models = [self.global_model]
+        self.uploaded_ids = []
+        self.M = min(args.M, self.join_clients)
+            
+        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
+        print("Finished creating server and clients.")
+
+    def train(self):
+        for i in range(self.global_rounds):
+            self.selected_clients = self.select_clients()
+            self.send_models()
+            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
+            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
+
+            for client in self.selected_clients:
+                client.train()
+
+            self.receive_models()
+            # self.aggregate_parameters()
+
+            if i%self.eval_gap == 0:
+                print("==> Evaluating personalized model")
+                self.evaluate()
+                if i == 80 and self.check_early_stopping():
+                    print("==> Performance is too low. Executing early stop.")
+                    break
+        print(f"==> Best mean personalized accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
+
+
+        self.save_results(fn=self.hist_result_fn)
+        message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
+        logging.info(self.message_hp + message_res)
+        # self.save_global_model()
+
+
+    def send_models(self):
+        assert (len(self.selected_clients) > 0)
+        for client in self.selected_clients:
+
+            if len(self.uploaded_ids) > 0:
+                M_ = min(self.M, len(self.uploaded_models)) # if clients dropped
+                indices = torch.topk(self.P[client.id][self.uploaded_ids], M_).indices.tolist()
+
+                uploaded_ids = []
+                uploaded_models = []
+                for i in indices:
+                    uploaded_ids.append(self.uploaded_ids[i])
+                    uploaded_models.append(self.uploaded_models[i])
+
+                client.receive_models(uploaded_ids, uploaded_models)
+
+    def prepare_global_model(self):
+        self.global_model = copy.deepcopy(self.clients[0].model)
+        for p in self.global_model.parameters():
+            p.data.zero_()
+        tot_samples = sum(c.train_samples for c in self.clients)
+        for c in self.clients:
+            # normalize so the client weights sum to 1
+            self.add_parameters(c.train_samples / tot_samples, c.model)
+        return
+    
+    def receive_models(self):
+        assert (len(self.selected_clients) > 0)
+
+        active_clients = random.sample(self.selected_clients, self.join_clients)
+
+        self.uploaded_ids = []
+        self.uploaded_weights = []
+        tot_samples = 0
+        self.uploaded_models = []
+        for client in active_clients:
+            self.uploaded_ids.append(client.id)
+            self.uploaded_weights.append(client.train_samples)
+            tot_samples += client.train_samples
+            self.uploaded_models.append(copy.deepcopy(client.model))
+            self.P[client.id] += client.weight_vector
+        for i, w in enumerate(self.uploaded_weights):
+            self.uploaded_weights[i] = w / tot_samples
+            

+ 74 - 0
system/flcore/servers/serverlgfedavg.py

@@ -0,0 +1,74 @@
+from flcore.clients.clientlgfedavg import clientLGFedAvg
+from flcore.servers.serverbase import Server
+import copy
+import os
+import logging
+
+class LGFedAvg(Server):
+    def __init__(self, args, times):
+        super().__init__(args, times)
+        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
+        clientObj = clientLGFedAvg
+        self.message_hp_dash = self.message_hp.replace(", ", "-")
+        self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
+
+        self.set_clients(args, clientObj)
+
+        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
+        print("Finished creating server and clients.")
+
+        # self.load_model()
+        self.global_model = self.global_model.predictor
+
+
+    def train(self):
+        for i in range(self.global_rounds):
+            self.selected_clients = self.select_clients()
+            self.send_models()
+
+            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
+            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
+            for client in self.selected_clients:
+                client.train()
+
+            self.receive_models()
+            self.aggregate_parameters()
+
+            if i%self.eval_gap == 0:
+                print("==> Evaluating personalized models...", flush=True)
+                self.send_models(mode="all")
+                self.evaluate(self.global_model)
+
+        print(f"==> Best mean personalized accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
+
+        self.save_results(fn=self.hist_result_fn)
+        message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
+        logging.info(self.message_hp + message_res)
+
+
+    def receive_models(self):
+        assert (len(self.selected_clients) > 0)
+
+        self.uploaded_weights = []
+        tot_samples = 0
+        self.uploaded_ids = []
+        self.uploaded_models = []
+        for client in self.selected_clients:
+            self.uploaded_weights.append(client.train_samples)
+            tot_samples += client.train_samples
+            self.uploaded_ids.append(client.id)
+            self.uploaded_models.append(client.model.predictor)
+        for i, w in enumerate(self.uploaded_weights):
+            self.uploaded_weights[i] = w / tot_samples
+
+    def prepare_global_model(self):
+        temp_model = copy.deepcopy(self.global_model) # predictor
+        self.global_model = copy.deepcopy(self.clients[0].model)
+        for p_t, p_g in zip(temp_model.parameters(), self.global_model.predictor.parameters()):
+            p_g.data = p_t.data.clone()
+        for p in self.global_model.base.parameters():
+            p.data.zero_()
+        tot_samples = sum(c.train_samples for c in self.clients)
+        for c in self.clients:
+            for p_g, p_c in zip(self.global_model.base.parameters(), c.model.base.parameters()):
+                p_g.data += p_c.data * (c.train_samples / tot_samples)  # normalized weights
+        return
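The per-part averaging in prepare_global_model above (and its mirror image in FedPer and FedRep below, which average predictors while keeping the aggregated base) is a sample-weighted parameter mean. A generic helper capturing it (hypothetical, not part of this commit):

import copy

def weighted_average(modules, weights):
    # weights are expected to sum to 1, e.g. train_samples / tot_samples
    avg = copy.deepcopy(modules[0])
    for p in avg.parameters():
        p.data.zero_()
    for m, w in zip(modules, weights):
        for p_a, p_m in zip(avg.parameters(), m.parameters()):
            p_a.data += p_m.data * w
    return avg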

+ 40 - 0
system/flcore/servers/serverlocal.py

@@ -0,0 +1,40 @@
+from flcore.clients.clientavg import clientAVG
+from flcore.servers.serverbase import Server
+import os
+import logging
+
+
+class Local(Server):
+    def __init__(self, args, times):
+        super().__init__(args, times)
+        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
+        clientObj = clientAVG
+        self.message_hp_dash = self.message_hp.replace(", ", "-")
+        self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
+
+        self.set_clients(args, clientObj)
+
+        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
+        print("Finished creating server and clients.")
+        
+        # self.load_model()
+
+
+    def train(self):
+        for i in range(self.global_rounds):
+            self.selected_clients = self.select_clients()
+            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
+            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
+            for client in self.selected_clients:
+                client.train()
+
+            if i%self.eval_gap == 0:
+                print("==> Evaluating local models...", flush=True)
+                self.evaluate()
+
+        print(f"==> Best accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
+        self.save_results(fn=self.hist_result_fn)
+        message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
+        logging.info(self.message_hp + message_res)
+

+ 72 - 0
system/flcore/servers/serverper.py

@@ -0,0 +1,72 @@
+from flcore.clients.clientper import clientPer
+from flcore.servers.serverbase import Server
+import copy
+import os
+import logging
+
+class FedPer(Server):
+    def __init__(self, args, times):
+        super().__init__(args, times)
+        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
+        clientObj = clientPer
+        self.message_hp_dash = self.message_hp.replace(", ", "-")
+        self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
+
+        self.set_clients(args, clientObj)
+
+        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
+        print("Finished creating server and clients.")
+
+        # self.load_model()
+
+
+    def train(self):
+        for i in range(self.global_rounds):
+            self.selected_clients = self.select_clients()
+            self.send_models()
+
+            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
+            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
+            for client in self.selected_clients:
+                client.train()
+            self.receive_models()
+            self.aggregate_parameters()
+            
+            if i%self.eval_gap == 0:
+                print("==> Evaluating personalized models...", flush=True)
+                self.send_models(mode="all")
+                self.evaluate(self.global_model)
+
+        print(f"==> Best mean personalized accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
+
+        self.save_results(fn=self.hist_result_fn)
+        message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
+        logging.info(self.message_hp + message_res)
+
+
+    def receive_models(self):
+        assert (len(self.selected_clients) > 0)
+
+        self.uploaded_weights = []
+        tot_samples = 0
+        self.uploaded_ids = []
+        self.uploaded_models = []
+        for client in self.selected_clients:
+            self.uploaded_weights.append(client.train_samples)
+            tot_samples += client.train_samples
+            self.uploaded_ids.append(client.id)
+            self.uploaded_models.append(client.model.base)
+        for i, w in enumerate(self.uploaded_weights):
+            self.uploaded_weights[i] = w / tot_samples
+
+    def prepare_global_model(self):
+        temp_model = copy.deepcopy(self.global_model) # base
+        self.global_model = copy.deepcopy(self.clients[0].model)
+        for p_t, p_g in zip(temp_model.parameters(), self.global_model.base.parameters()):
+            p_g.data = p_t.data.clone()
+        for p in self.global_model.predictor.parameters():
+            p.data.zero_()
+        tot_samples = sum(c.train_samples for c in self.clients)
+        for c in self.clients:
+            for p_g, p_c in zip(self.global_model.predictor.parameters(), c.model.predictor.parameters()):
+                p_g.data += p_c.data * (c.train_samples / tot_samples)  # normalized weights
+        return

+ 42 - 0
system/flcore/servers/serverperavg.py

@@ -0,0 +1,42 @@
+import copy
+import os
+import logging
+import torch
+from flcore.clients.clientperavg import clientPerAvg
+from flcore.servers.serverbase import Server
+
+
+class PerAvg(Server):
+    def __init__(self, args, times):
+        super().__init__(args, times)
+        self.message_hp = f"{args.algorithm}, alpha:{args.alpha:.5f}, beta:{args.beta:.5f}"
+        clientObj = clientPerAvg
+        self.message_hp_dash = self.message_hp.replace(", ", "-")
+        self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
+
+        self.set_clients(args, clientObj)
+
+        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
+        print("Finished creating server and clients.")
+
+    def train(self):
+        for i in range(self.global_rounds):
+            self.selected_clients = self.select_clients()
+            # send all parameters to the clients
+            self.send_models()
+            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
+            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
+            for client in self.selected_clients:
+                client.train()
+            self.receive_models()
+            self.aggregate_parameters()
+
+            if i%self.eval_gap == 0:
+                print("==> Evaluating personalized models...", flush=True)
+                self.send_models(mode="all")
+                self.evaluate(self.global_model)
+
+        print(f"==> Best mean personalized accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
+        self.save_results(fn=self.hist_result_fn)
+        message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
+        logging.info(self.message_hp + message_res)

+ 56 - 0
system/flcore/servers/serverpfedme.py

@@ -0,0 +1,56 @@
+import os
+import logging
+import copy
+import h5py
+from flcore.clients.clientpfedme import clientpFedMe
+from flcore.servers.serverbase import Server
+
+
+class pFedMe(Server):
+    def __init__(self, args, times):
+        super().__init__(args, times)
+        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
+        clientObj = clientpFedMe
+        self.message_hp_dash = self.message_hp.replace(", ", "-")
+        self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
+
+        self.set_clients(args, clientObj)
+
+        self.beta = args.beta
+        self.rs_train_acc_per = []
+        self.rs_train_loss_per = []
+        self.rs_test_acc_per = []
+
+        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
+        print("Finished creating server and clients.")
+
+    def train(self):
+        for i in range(self.global_rounds):
+            self.selected_clients = self.select_clients()
+            self.send_models()
+            
+            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
+            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
+
+            for client in self.selected_clients:
+                client.train()
+
+
+            if i%self.eval_gap == 0:
+                print("==> Evaluating personalized models...")
+                self.evaluate()
+
+            self.previous_global_model = copy.deepcopy(list(self.global_model.parameters()))
+            self.receive_models()
+            self.aggregate_parameters()
+            self.beta_aggregate_parameters()
+
+        self.save_results(fn=self.hist_result_fn)
+        message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
+        logging.info(self.message_hp + message_res)
+        # self.save_global_model()
+
+    def beta_aggregate_parameters(self):
+        # interpolate the averaged model with the previous global model using parameter beta
+        for pre_param, param in zip(self.previous_global_model, self.global_model.parameters()):
+            param.data = (1 - self.beta)*pre_param.data + self.beta*param.data
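beta_aggregate_parameters implements theta <- (1 - beta) * theta_prev + beta * theta_avg, so beta = 1 recovers plain FedAvg aggregation and smaller beta damps the global update (values illustrative):

import torch

beta = 0.5
prev, avg = torch.tensor([1.0]), torch.tensor([3.0])
new = (1 - beta) * prev + beta * avg  # tensor([2.]); beta=1.0 would return avg itself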

+ 129 - 0
system/flcore/servers/serverpgfed.py

@@ -0,0 +1,129 @@
+import copy
+from flcore.clients.clientpgfed import clientPGFed
+from flcore.servers.serverbase import Server
+import numpy as np
+import torch
+import h5py
+import os
+import logging
+
+
+class PGFed(Server):
+    def __init__(self, args, times):
+        super().__init__(args, times)
+        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}, mu:{args.mu:.5f}, lambda:{args.lambdaa:.5f}"
+        if self.algorithm == "PGFedMo":
+            self.momentum = args.beta
+            self.message_hp += f", beta:{args.beta:.5f}" # momentum
+        else:
+            self.momentum = 0.0
+        clientObj = clientPGFed
+        self.message_hp_dash = self.message_hp.replace(", ", "-")
+        self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
+
+        self.set_clients(args, clientObj)
+
+        self.mu = args.mu
+        self.alpha_mat = (torch.ones((self.num_clients, self.num_clients)) / self.join_clients).to(self.device)
+        self.uploaded_grads = {}
+        self.loss_minuses = {}
+        self.mean_grad = None
+        self.convex_comb_grad = None
+        self.best_global_mean_test_acc = 0.0
+        self.rs_global_test_acc = []
+        self.rs_global_test_auc = []
+        self.rs_global_test_loss = []
+        self.last_ckpt_fn = os.path.join(self.ckpt_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.pt")
+
+        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
+        print("Finished creating server and clients.")
+
+    def train(self):
+        early_stop = False
+        for i in range(self.global_rounds):
+            self.selected_clients = self.select_clients()
+            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
+            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
+            self.send_models()
+            for client in self.selected_clients:
+                client.train()
+
+            self.receive_models()
+            self.aggregate_parameters()
+
+            if i%self.eval_gap == 0:
+                print("==> Evaluating personalized models...", flush=True)
+                self.evaluate()
+                if i >= 40 and self.check_early_stopping():
+                    early_stop = True
+                    print("==> Performance is too low. Excecuting early stop.")
+                    break
+
+        print(f"==> Best mean personalized accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
+        if not early_stop:
+            self.save_results(fn=self.hist_result_fn)
+            # message_res = f"\tglobal_test_acc:{self.best_global_mean_test_acc:.6f}\ttest_acc:{self.best_mean_test_acc:.6f}"
+            message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
+            logging.info(self.message_hp + message_res)
+            # state = {
+            #         "model": self.global_model.cpu().state_dict(),
+            #         # "best_global_acc": self.best_global_mean_test_acc,
+            #         "best_personalized_acc": self.best_mean_test_acc,
+            #         "alpha_mat": self.alpha_mat.cpu()
+            #     }
+            # state.update({f"client{c.id}": c.model.cpu().state_dict() for c in self.clients})
+            # self.save_global_model(model_path=self.last_ckpt_fn, state=state)
+
+    def receive_models(self):
+        assert (len(self.selected_clients) > 0)
+        self.uploaded_ids = []
+        self.uploaded_grads = {}
+        self.loss_minuses = {}
+        self.uploaded_models = []
+        self.uploaded_weights = []
+        tot_samples = 0
+        for client in self.selected_clients:
+            self.uploaded_ids.append(client.id)
+            self.alpha_mat[client.id] = client.a_i
+            self.uploaded_grads[client.id] = client.latest_grad
+            self.loss_minuses[client.id] = client.loss_minus * self.mu
+
+            self.uploaded_weights.append(client.train_samples)
+            tot_samples += client.train_samples
+            self.uploaded_models.append(client.model)
+        for i, w in enumerate(self.uploaded_weights):
+            self.uploaded_weights[i] = w / tot_samples
+
+    def aggregate_parameters(self):
+        assert (len(self.uploaded_grads) > 0)
+        self.model_weighted_sum(self.global_model, self.uploaded_models, self.uploaded_weights)
+        w = self.mu/self.join_clients
+        weights = [w for _ in range(self.join_clients)]
+        self.mean_grad = copy.deepcopy(list(self.uploaded_grads.values())[0])
+        self.model_weighted_sum(self.mean_grad, list(self.uploaded_grads.values()), weights)
+
+    def model_weighted_sum(self, model, models, weights):
+        for p_m in model.parameters():
+            p_m.data.zero_()
+        for w, m_i in zip(weights, models):
+            for p_m, p_i in zip(model.parameters(), m_i.parameters()):
+                p_m.data += p_i.data.clone() * w
+
+    def send_models(self, mode="selected"):
+        assert (len(self.selected_clients) > 0)
+        for client in self.selected_clients:
+            client.a_i = self.alpha_mat[client.id]
+            client.set_parameters(self.global_model)
+        if len(self.uploaded_grads) == 0:
+            return
+        self.convex_comb_grad = copy.deepcopy(list(self.uploaded_grads.values())[0])
+        for client in self.selected_clients:
+            client.set_prev_mean_grad(self.mean_grad)
+            mu_a_i = self.alpha_mat[client.id] * self.mu
+            grads, weights = [], []
+            for clt_idx, grad in self.uploaded_grads.items():
+                weights.append(mu_a_i[clt_idx])
+                grads.append(grad)
+            self.model_weighted_sum(self.convex_comb_grad, grads, weights)
+            client.set_prev_convex_comb_grad(self.convex_comb_grad, momentum=self.momentum)
+            client.prev_loss_minuses = copy.deepcopy(self.loss_minuses)
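For each selected client i, send_models above builds the personalized gradient mixture grad_i = sum_j mu * a_i[j] * latest_grad_j over this round's uploaded gradients, which the client then adds onto its own gradient in train(). A scalar trace (values illustrative):

import torch

mu = 0.1
a_i = {3: 0.6, 8: 0.4}                                # one row of alpha_mat
grads = {3: torch.tensor([1.0]), 8: torch.tensor([-1.0])}
mix = sum(mu * a_i[j] * g for j, g in grads.items())  # tensor([0.0200])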

+ 90 - 0
system/flcore/servers/serverrep.py

@@ -0,0 +1,90 @@
+from flcore.clients.clientrep import clientRep
+from flcore.servers.serverbase import Server
+import os
+import logging
+import copy
+
+class FedRep(Server):
+    def __init__(self, args, times):
+        super().__init__(args, times)
+        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
+        clientObj = clientRep
+        self.message_hp_dash = self.message_hp.replace(", ", "-")
+        self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
+
+        self.set_clients(args, clientObj)
+
+        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
+        print("Finished creating server and clients.")
+
+        self.Budget = []
+
+    def train(self):
+        for i in range(self.global_rounds):
+            self.selected_clients = self.select_clients()
+            self.send_models()
+            
+            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
+            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
+            for client in self.selected_clients:
+                client.train()
+            self.receive_models()
+            self.aggregate_parameters()
+
+            if i%self.eval_gap == 0:
+                print("==> Evaluating personalized models...", flush=True)
+                self.send_models(mode="all")
+                self.evaluate(self.global_model)
+                if i == 80 and self.check_early_stopping():
+                    print("==> Performance is too low. Executing early stop.")
+                    break
+
+        print(f"==> Best mean personalized accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
+
+        self.save_results(fn=self.hist_result_fn)
+        message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
+        logging.info(self.message_hp + message_res)
+        # self.save_global_model()
+
+    def receive_models(self):
+        assert (len(self.selected_clients) > 0)
+
+        active_train_samples = 0
+        for client in self.selected_clients:
+            active_train_samples += client.train_samples
+
+        self.uploaded_weights = []
+        self.uploaded_ids = []
+        self.uploaded_models = []
+        for client in self.selected_clients:
+            self.uploaded_weights.append(client.train_samples / active_train_samples)
+            self.uploaded_ids.append(client.id)
+            self.uploaded_models.append(copy.deepcopy(client.model.base))
+
+    def prepare_global_model(self):
+        temp_model = copy.deepcopy(self.global_model) # base
+        self.global_model = copy.deepcopy(self.clients[0].model)
+        for p_t, p_g in zip(temp_model.parameters(), self.global_model.base.parameters()):
+            p_g.data = p_t.data.clone()
+        for p in self.global_model.predictor.parameters():
+            p.data.zero_()
+        for c in self.clients:
+            for p_g, p_c in zip(self.global_model.predictor.parameters(), c.model.predictor.parameters()):
+                p_g.data += p_c.data * c.train_samples
+        return
+
+    def train_new_clients(self, epochs=20):
+        self.global_model = self.global_model.to(self.device)
+        self.clients = self.new_clients
+        self.send_models(mode="all")
+        self.reset_records()
+        for c in self.clients:
+            c.model = copy.deepcopy(self.global_model)
+        for epoch_idx in range(epochs):
+            for c in self.clients:
+                c.standard_train()
+            print(f"==> New clients epoch: [{epoch_idx+1:2d}/{epochs}] | Evaluating local models...", flush=True)
+            self.evaluate()
+        print(f"==> Best mean global accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
+        self.save_results(fn=self.hist_result_fn)
+        message_res = f"\tnew_clients_test_acc:{self.best_mean_test_acc:.6f}"
+        logging.info(self.message_hp + message_res)

+ 45 - 0
system/flcore/servers/serverrod.py

@@ -0,0 +1,45 @@
+from flcore.clients.clientrod import clientRoD
+from flcore.servers.serverbase import Server
+from utils.data_utils import read_client_data
+import os
+import logging
+
+
+class FedRoD(Server):
+    def __init__(self, args, times):
+        super().__init__(args, times)
+        self.message_hp = f"{args.algorithm}, lr:{args.local_learning_rate:.5f}"
+        clientObj = clientRoD
+        self.message_hp_dash = self.message_hp.replace(", ", "-")
+        self.hist_result_fn = os.path.join(args.hist_dir, f"{self.actual_dataset}-{self.message_hp_dash}-{args.goal}-{self.times}.h5")
+
+        self.set_clients(args, clientObj)
+
+        print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}")
+        print("Finished creating server and clients.")
+
+        # self.load_model()
+
+    def train(self):
+        for i in range(self.global_rounds):
+            self.selected_clients = self.select_clients()
+            self.send_models()
+
+            print(f"\n------------- Round number: [{i+1:3d}/{self.global_rounds}]-------------")
+            print(f"==> Training for {len(self.selected_clients)} clients...", flush=True)
+
+            for client in self.selected_clients:
+                client.train()
+
+            self.receive_models()
+            self.aggregate_parameters()
+
+            if i%self.eval_gap == 0:
+                print("==> Evaluating personalized models...", flush=True)
+                self.evaluate(self.global_model)
+
+        print(f"==> Best mean personalized accuracy: {self.best_mean_test_acc*100:.2f}%", flush=True)
+
+        self.save_results(fn=self.hist_result_fn)
+        message_res = f"\ttest_acc:{self.best_mean_test_acc:.6f}"
+        logging.info(self.message_hp + message_res)

+ 58 - 0
system/flcore/trainmodel/models.py

@@ -0,0 +1,58 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+batch_size = 16
+
+
+class LocalModel(nn.Module):
+    def __init__(self, base, predictor):
+        super(LocalModel, self).__init__()
+
+        self.base = base
+        self.predictor = predictor
+        
+    def forward(self, x):
+        out = self.base(x)
+        out = self.predictor(out)
+
+        return out
+        
+
+# https://github.com/katsura-jp/fedavg.pytorch/blob/master/src/models/cnn.py
+class FedAvgCNN(nn.Module):
+    def __init__(self, in_features=1, num_classes=10, dim=1024):
+        super().__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_features,
+                        32,
+                        kernel_size=5,
+                        padding=0,
+                        stride=1,
+                        bias=True),
+            nn.ReLU(inplace=True), 
+            nn.MaxPool2d(kernel_size=(2, 2))
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(32,
+                        64,
+                        kernel_size=5,
+                        padding=0,
+                        stride=1,
+                        bias=True),
+            nn.ReLU(inplace=True), 
+            nn.MaxPool2d(kernel_size=(2, 2))
+        )
+        self.fc1 = nn.Sequential(
+            nn.Linear(dim, 512), 
+            nn.ReLU(inplace=True)
+        )
+        self.fc = nn.Linear(512, num_classes)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = torch.flatten(out, 1)
+        out = self.fc1(out)
+        out = self.fc(out)
+        return out
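LocalModel is what the base/predictor algorithms (FedRep, LGFedAvg, FedPer, FedRoD, FedBABU in main.py below) wrap around FedAvgCNN: the final fc layer is detached as the predictor and replaced with nn.Identity, so the base emits the 512-d feature. A minimal sketch of that split, assuming system/ is on PYTHONPATH so the classes above are importable:

import copy
import torch
import torch.nn as nn
from flcore.trainmodel.models import FedAvgCNN, LocalModel

model = FedAvgCNN(in_features=3, num_classes=10, dim=1600)  # cifar10-shaped input
predictor = copy.deepcopy(model.fc)   # keep the 512 -> num_classes head
model.fc = nn.Identity()              # the base now outputs the 512-d feature
model = LocalModel(model, predictor)  # forward(x) = predictor(base(x))
out = model(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 10])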

+ 212 - 0
system/main.py

@@ -0,0 +1,212 @@
+import copy
+import torch
+import argparse
+import os
+import time
+import warnings
+import numpy as np
+import torchvision
+
+from flcore.trainmodel.models import *
+
+warnings.simplefilter("ignore")
+torch.manual_seed(0)
+
+def run(args):
+    model_str = args.model
+    for i in range(args.prev, args.times):
+        print(f"\n============= Running time: [{i+1}th/{args.times}] =============", flush=True)
+        print("Creating server and clients ...")
+
+        # Generate args.model
+        if model_str == "cnn":
+            if args.dataset == "mnist" or args.dataset.startswith("organamnist"):
+                args.model = FedAvgCNN(in_features=1, num_classes=args.num_classes, dim=1024).to(args.device)
+            elif args.dataset.upper() == "CIFAR10" or args.dataset.upper() == "CIFAR100" or args.dataset.startswith("Office-home"):
+                args.model = FedAvgCNN(in_features=3, num_classes=args.num_classes, dim=1600).to(args.device)
+            else:
+                args.model = FedAvgCNN(in_features=3, num_classes=args.num_classes, dim=1600).to(args.device)
+
+        else:
+            raise NotImplementedError
+
+        # select algorithm
+        if args.algorithm.startswith("Local"):
+            from flcore.servers.serverlocal import Local
+            server = Local(args, i)
+
+        elif args.algorithm.startswith("FedAvg"):
+            from flcore.servers.serveravg import FedAvg
+            server = FedAvg(args, i)
+
+        elif args.algorithm.startswith("FedDyn"):
+            from flcore.servers.serverdyn import FedDyn
+            server = FedDyn(args, i)
+
+        elif args.algorithm.startswith("pFedMe"):
+            from flcore.servers.serverpfedme import pFedMe
+            server = pFedMe(args, i)
+
+        elif args.algorithm.startswith("FedFomo"):
+            from flcore.servers.serverfomo import FedFomo
+            server = FedFomo(args, i)
+
+        elif args.algorithm.startswith("APFL"):
+            from flcore.servers.serverapfl import APFL
+            server = APFL(args, i)
+
+        elif args.algorithm.startswith("FedRep"):
+            from flcore.servers.serverrep import FedRep
+            args.predictor = copy.deepcopy(args.model.fc)
+            args.model.fc = nn.Identity()
+            args.model = LocalModel(args.model, args.predictor)
+            server = FedRep(args, i)
+
+        elif args.algorithm.startswith("LGFedAvg"):
+            from flcore.servers.serverlgfedavg import LGFedAvg
+            args.predictor = copy.deepcopy(args.model.fc)
+            args.model.fc = nn.Identity()
+            args.model = LocalModel(args.model, args.predictor)
+            server = LGFedAvg(args, i)
+
+        elif args.algorithm.startswith("FedPer"):
+            from flcore.servers.serverper import FedPer
+            args.predictor = copy.deepcopy(args.model.fc)
+            args.model.fc = nn.Identity()
+            args.model = LocalModel(args.model, args.predictor)
+            server = FedPer(args, i)
+
+        elif args.algorithm.startswith("PerAvg"):
+            from flcore.servers.serverperavg import PerAvg
+            server = PerAvg(args, i)
+
+        elif args.algorithm.startswith("FedRoD"):
+            from flcore.servers.serverrod import FedRoD
+            args.predictor = copy.deepcopy(args.model.fc)
+            args.model.fc = nn.Identity()
+            args.model = LocalModel(args.model, args.predictor)
+            server = FedRoD(args, i)
+
+        elif args.algorithm.startswith("FedBABU"):
+            args.predictor = copy.deepcopy(args.model.fc)
+            args.model.fc = nn.Identity()
+            args.model = LocalModel(args.model, args.predictor)
+            from flcore.servers.serverbabu import FedBABU
+            server = FedBABU(args, i)
+
+        elif args.algorithm.startswith("PGFed"):
+            from flcore.servers.serverpgfed import PGFed
+            server = PGFed(args, i)
+            
+        else:
+            raise NotImplementedError
+
+
+
+        server.train()
+        if args.dataset.startswith("Office-home") and args.times != 1:
+            import logging
+            m = server.domain_mean_test_accs
+            logging.info(f"domains means and average:\t{m[0]:.6f}\t{m[1]:.6f}\t{m[2]:.6f}\t{m[3]:.6f}\t{server.best_mean_test_acc:.6f}")
+
+
+
+        # # comment the above block and uncomment the following block for fine-tuning on new clients
+        # if len(server.clients) == 100:
+        #     old_clients_num = 80
+        #     server.new_clients = server.clients[old_clients_num:]
+        #     server.clients = server.clients[:old_clients_num]
+        #     server.num_clients = old_clients_num
+        #     server.join_clients = int(old_clients_num * server.join_ratio)
+        # if not args.algorithm.startswith("Local"):
+        #     server.train()
+        #     server.prepare_global_model()
+        # n_epochs = 20
+        # print(f"\n\n==> Training for new clients for {n_epochs} epochs")
+        # server.train_new_clients(epochs=n_epochs)
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    # general
+    parser.add_argument('-go', "--goal", type=str, default="cnn", 
+                        help="The goal for this experiment")
+    parser.add_argument('-dev', "--device", type=str, default="cuda",
+                        choices=["cpu", "cuda"])
+    parser.add_argument('-did', "--device_id", type=str, default="0")
+    parser.add_argument('-data', "--dataset", type=str, default="cifar10",
+                        choices=["cifar10", "cifar100", "organaminist25", "organaminist50", "organaminist100", "Office-home20"])
+    parser.add_argument('-nb', "--num_classes", type=int, default=10)
+    parser.add_argument('-m', "--model", type=str, default="cnn")
+    parser.add_argument('-p', "--predictor", type=str, default="cnn")
+    parser.add_argument('-lbs', "--batch_size", type=int, default=10)
+    parser.add_argument('-lr', "--local_learning_rate", type=float, default=0.005,
+                        help="Local learning rate")
+    parser.add_argument('-gr', "--global_rounds", type=int, default=3)
+    parser.add_argument('-ls', "--local_steps", type=int, default=5)
+    parser.add_argument('-algo', "--algorithm", type=str, default="PGFed")
+    parser.add_argument('-jr', "--join_ratio", type=float, default=0.25,
+                        help="Ratio of clients per round")
+    parser.add_argument('-nc', "--num_clients", type=int, default=25,
+                        help="Total number of clients")
+    parser.add_argument('-pv', "--prev", type=int, default=0,
+                        help="Previous Running times")
+    parser.add_argument('-t', "--times", type=int, default=1,
+                        help="Running times")
+    parser.add_argument('-eg', "--eval_gap", type=int, default=1,
+                        help="Rounds gap for evaluation")
+
+    # FL algorithms (multiple algs)
+    parser.add_argument('-bt', "--beta", type=float, default=0.0,
+                        help="PGFed momentum, average moving parameter for pFedMe, Second learning rate of Per-FedAvg")
+    parser.add_argument('-lam', "--lambdaa", type=float, default=1.0,
+                        help="PGFed learning rate for a_i, Regularization weight for pFedMe")
+    parser.add_argument('-mu', "--mu", type=float, default=0,
+                        help="PGFed weight for aux risk, pFedMe weight")
+    parser.add_argument('-K', "--K", type=int, default=5,
+                        help="Number of personalized training steps for pFedMe")
+    parser.add_argument('-lrp', "--p_learning_rate", type=float, default=0.01,
+                        help="pFedMe personalized learning rate to caculate theta aproximately using K steps")
+    # FedFomo
+    parser.add_argument('-M', "--M", type=int, default=8,
+                        help="Server only sends M client models to one client at each round")
+    # APFL
+    parser.add_argument('-al', "--alpha", type=float, default=0.5)
+    # FedRep
+    parser.add_argument('-pls', "--plocal_steps", type=int, default=5)
+    # FedBABU
+    parser.add_argument('-fts', "--fine_tuning_steps", type=int, default=1)
+    # save directories
+    parser.add_argument("--hist_dir", type=str, default="../", help="dir path for output hist file")
+    parser.add_argument("--log_dir", type=str, default="../", help="dir path for log (main results) file")
+    parser.add_argument("--ckpt_dir", type=str, default="../", help="dir path for checkpoints")
+
+    args = parser.parse_args()
+    return args
+
+if __name__ == "__main__":
+    total_start = time.time()
+    args = get_args()
+    # os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
+
+    if args.device == "cuda" and not torch.cuda.is_available():
+        print("\ncuda is not avaiable.\n")
+        args.device = "cpu"
+    print("=" * 50)
+
+    print("Algorithm: {}".format(args.algorithm))
+    print("Local batch size: {}".format(args.batch_size))
+    print("Local steps: {}".format(args.local_steps))
+    print("Local learing rate: {}".format(args.local_learning_rate))
+    print("Total number of clients: {}".format(args.num_clients))
+    print("Clients join in each round: {}".format(args.join_ratio))
+    print("Global rounds: {}".format(args.global_rounds))
+    print("Running times: {}".format(args.times))
+    print("Dataset: {}".format(args.dataset))
+    print("Local model: {}".format(args.model))
+    print("Using device: {}".format(args.device))
+
+    if args.device == "cuda":
+        print("Cuda device id: {}".format(os.environ["CUDA_VISIBLE_DEVICES"]))
+    print("=" * 50)
+
+    run(args)

+ 48 - 0
system/traincifar10_25clt_example.sh

@@ -0,0 +1,48 @@
+# The following commands provide examples of how to conduct training with different FL/pFL algorithms.
+# They assume the dataset (in this case, cifar10 with 25 clients) has already been generated.
+# Each command trains the model for 2 global rounds (-gr flag).
+# In each round, 25% of the clients are selected (-jr flag).
+# Each selected client trains the model for 2 epochs/local steps (-ls flag).
+
+# Local
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo Local
+
+# FedAvg
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo FedAvg
+
+# FedDyn
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo FedDyn -al 0.1
+
+# pFedMe
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo pFedMe -bt 1.0 -lrp 0.01
+
+# FedFomo
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo FedFomo
+
+# APFL
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo APFL -al 0.5
+
+# FedRep
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo FedRep -pls 1
+
+# LGFedAvg
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo LGFedAvg
+
+# FedPer
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo FedPer
+
+# Per-FedAvg
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo PerAvg -al 0.005 -bt 0.005
+
+# FedRoD
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo FedRoD
+
+# FedBABU
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo FedBABU -al 0.001 -bt 0.01
+
+# PGFed
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo PGFed -mu 0.1 -lam 0.01 -bt 0.0
+
+# PGFedMo
+python main.py -data cifar10 -nc 25 -jr 0.25 -gr 2 -ls 2 -algo PGFed -mu 0.1 -lam 0.01 -bt 0.5
+

+ 92 - 0
system/utils/data_utils.py

@@ -0,0 +1,92 @@
+import ujson
+import numpy as np
+import os
+import torch
+
+# IMAGE_SIZE = 28
+# IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
+# NUM_CHANNELS = 1
+
+# IMAGE_SIZE_CIFAR = 32
+# NUM_CHANNELS_CIFAR = 3
+
+
+def batch_data(data, batch_size):
+    '''
+    data is a dict := {'x': [numpy array], 'y': [numpy array]} (on one client)
+    yields (x, y) mini-batches; each is a numpy array of length at most batch_size
+    '''
+    data_x = data['x']
+    data_y = data['y']
+
+    # randomly shuffle data
+    ran_state = np.random.get_state()
+    np.random.shuffle(data_x)
+    np.random.set_state(ran_state)
+    np.random.shuffle(data_y)
+
+    # loop through mini-batches
+    for i in range(0, len(data_x), batch_size):
+        batched_x = data_x[i:i+batch_size]
+        batched_y = data_y[i:i+batch_size]
+        yield (batched_x, batched_y)
+
+
+def get_random_batch_sample(data_x, data_y, batch_size):
+    # number of mini-batches, counting a possibly shorter trailing batch
+    num_parts = (len(data_x) + batch_size - 1) // batch_size
+    if(len(data_x) > batch_size):
+        batch_idx = np.random.choice(num_parts)
+        sample_index = batch_idx*batch_size
+        if(sample_index + batch_size > len(data_x)):
+            return (data_x[sample_index:], data_y[sample_index:])
+        else:
+            return (data_x[sample_index: sample_index+batch_size], data_y[sample_index: sample_index+batch_size])
+    else:
+        return (data_x, data_y)
+
+
+def get_batch_sample(data, batch_size):
+    data_x = data['x']
+    data_y = data['y']
+
+    # np.random.seed(100)
+    ran_state = np.random.get_state()
+    np.random.shuffle(data_x)
+    np.random.set_state(ran_state)
+    np.random.shuffle(data_y)
+
+    batched_x = data_x[0:batch_size]
+    batched_y = data_y[0:batch_size]
+    return (batched_x, batched_y)
+
+
+def read_data(dataset, idx, is_train=True):
+    if is_train:
+        train_data_dir = os.path.join('../dataset', dataset, 'train/')
+        train_file = train_data_dir + str(idx) + '.npz'
+        with open(train_file, 'rb') as f:
+            train_data = np.load(f, allow_pickle=True)['data'].tolist()
+        return train_data
+
+    else:
+        test_data_dir = os.path.join('../dataset', dataset, 'test/')
+        test_file = test_data_dir + str(idx) + '.npz'
+        with open(test_file, 'rb') as f:
+            test_data = np.load(f, allow_pickle=True)['data'].tolist()
+        return test_data
+
+
+def read_client_data(dataset, idx, is_train=True):
+    if is_train:
+        train_data = read_data(dataset, idx, is_train)
+        X_train = torch.Tensor(train_data['x']).type(torch.float32)
+        y_train = torch.Tensor(train_data['y']).type(torch.int64)
+
+        train_data = [(x, y) for x, y in zip(X_train, y_train)]
+        return train_data
+    else:
+        test_data = read_data(dataset, idx, is_train)
+        X_test = torch.Tensor(test_data['x']).type(torch.float32)
+        y_test = torch.Tensor(test_data['y']).type(torch.int64)
+        test_data = [(x, y) for x, y in zip(X_test, y_test)]
+        return test_data
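Since read_client_data returns a list of (x, y) tensor pairs, it plugs straight into a torch DataLoader. A minimal sketch, assuming the cifar10 dataset has already been generated under dataset/ and the snippet runs from the system/ directory:

from torch.utils.data import DataLoader
from utils.data_utils import read_client_data

train_data = read_client_data("cifar10", idx=0, is_train=True)  # client 0's local split
loader = DataLoader(train_data, batch_size=10, shuffle=True)
x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([10, 3, 32, 32]) torch.Size([10])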

+ 22 - 0
system/utils/tensor_utils.py

@@ -0,0 +1,22 @@
+
+import torch
+
+def l2_squared_diff(w1, w2, requires_grad=True):
+    """ Return the sum of squared difference between two models. """
+    diff = 0.0
+    for p1, p2 in zip(w1.parameters(), w2.parameters()):
+        if requires_grad:
+            diff += torch.sum(torch.pow(p1-p2, 2))
+        else:
+            diff += torch.sum(torch.pow(p1.data-p2.data, 2))
+    return diff
+
+def model_dot_product(w1, w2, requires_grad=True):
+    """ Return the sum of squared difference between two models. """
+    dot_product = 0.0
+    for p1, p2 in zip(w1.parameters(), w2.parameters()):
+        if requires_grad:
+            dot_product += torch.sum(p1 * p2)
+        else:
+            dot_product += torch.sum(p1.data * p2.data)
+    return dot_product
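Both helpers walk the two modules' parameters in zipped order, so they assume identical architectures; with requires_grad=True the result stays on the autograd graph (so it can enter a training loss), while requires_grad=False computes a detached scalar from .data. A quick sketch with toy linear layers, assuming system/ is on PYTHONPATH:

import torch
import torch.nn as nn
from utils.tensor_utils import l2_squared_diff, model_dot_product

torch.manual_seed(0)
w1, w2 = nn.Linear(4, 2), nn.Linear(4, 2)
dist = l2_squared_diff(w1, w2)                        # differentiable 0-dim tensor
dot = model_dot_product(w1, w2, requires_grad=False)  # detached 0-dim tensor
print(dist.item(), dot.item())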