
[Feature]: Federated self-supervised learning (#5)

Zhuang Weiming, 2 years ago
Commit
90acf8fb09

+ 3 - 0
.gitignore

@@ -7,5 +7,8 @@ __pycache__
 *.xlsx
 *.egg-info
 docs/build
+dist/
+data/
+saved_models/
 
 

+ 2 - 1
README.md

@@ -58,7 +58,8 @@ For more advanced usage, we provide a list of tutorials on:
 
 We have released the source code for the following papers under the `applications` folder:
 
-- FedReID: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/fedreid) for "Performance Optimization for Federated Person Re-identification via Benchmark Analysis", _ACMMM'2020_.
+- FedReID: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/fedreid) for [Performance Optimization for Federated Person Re-identification via Benchmark Analysis](https://dl.acm.org/doi/10.1145/3394171.3413814) (_ACMMM'2020_).
+- FedSSL: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/fedssl) for two papers: [Divergence-aware Federated Self-Supervised Learning](https://openreview.net/forum?id=oVE1z8NlNe) (_ICLR'2022_)  and [Collaborative Unsupervised Visual Representation Learning From Decentralized Data](https://openaccess.thecvf.com/content/ICCV2021/html/Zhuang_Collaborative_Unsupervised_Visual_Representation_Learning_From_Decentralized_Data_ICCV_2021_paper.html) (_ICCV'2021_)
 
 The following publications are developed using EasyFL:
 

+ 115 - 0
applications/fedssl/README.md

@@ -0,0 +1,115 @@
+# Federated Self-supervised Learning (FedSSL)
+> Also known as Federated Unsupervised Representation Learning (FedU)
+
+A common limitation of existing federated learning (FL) methods is that they rely heavily on data labels on decentralized clients. We propose a federated self-supervised learning (FedSSL) framework to learn visual representations from decentralized data without labels. 
+
+This repository contains the code for two papers:
+- Divergence-aware Federated Self-Supervised Learning, _ICLR'2022_. [[paper]](https://openreview.net/forum?id=oVE1z8NlNe)
+- Collaborative Unsupervised Visual Representation Learning From Decentralized Data, _ICCV'2021_. [[paper]](https://openaccess.thecvf.com/content/ICCV2021/html/Zhuang_Collaborative_Unsupervised_Visual_Representation_Learning_From_Decentralized_Data_ICCV_2021_paper.html)
+
+<img src="images/fedssl.png" width="700">
+
+The framework implements four self-supervised learning (SSL) methods based on Siamese networks in a federated manner (see the instantiation sketch after this list):
+1. BYOL
+2. SimSiam
+3. MoCo (MoCoV1 & MoCoV2)
+4. SimCLR
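+
+For reference, any of them can be instantiated through this application's `model.py`; a minimal sketch using `get_model` from this diff:
+
+```python
+from model import get_model
+
+# Federated BYOL with a ResNet-18 encoder and a 2-layer predictor head.
+net = get_model("byol", "resnet18", "2_layer")
+```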
+
+## Training
+
+You can train with any of the supported FedSSL methods, including our proposed FedEMA method. 
+
+> Remember to save the global model during training; it is needed for the evaluation below. 
+
+### FedEMA
+
+Run FedEMA with auto scaler $\tau=0.7$:
+```shell
+python applications/fedssl/main.py --task_id fedema --model byol \
+      --aggregate_encoder online --update_encoder dynamic_ema_online --update_predictor dynamic_dapu \
+      --auto_scaler y --auto_scaler_target 0.7 2>&1 | tee log/fedema.log
+```
+
+Run FedEMA with constant weight scaler $\lambda=1$:
+```shell
+python applications/fedssl/main.py --task_id fedema --model byol \
+      --aggregate_encoder online --update_encoder dynamic_ema_online --update_predictor dynamic_dapu \
+      --weight_scaler 1 2>&1 | tee log/fedema.log
+```
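+
+FedEMA updates each client's online encoder as a divergence-aware moving average of the global and local models. A minimal sketch of the rule implemented in `client.py` (`ema_online`), reusing names from this diff:
+
+```python
+# mu grows with model divergence, capped at 1; larger mu keeps more local weights.
+mu = min(1, weight_scaler * encoder_distance)
+for w_global, w_local in zip(global_encoder.parameters(), local_encoder.parameters()):
+    # updated online encoder = (1 - mu) * global + mu * local
+    w_global.data = (1 - mu) * w_global.data + mu * w_local.data
+```
+
+With `--auto_scaler y`, `weight_scaler` is computed once as `auto_scaler_target / divergence` of the first trained round, so the first EMA update uses $\mu = \tau$.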
+
+### Other SSL methods
+Run other FedSSL methods: 
+```shell
+python applications/fedssl/main.py --task_id fedbyol --model byol  \
+      --aggregate_encoder online --update_encoder online --update_predictor global
+```
+Replace `byol` in `--model byol` with other SSL methods, including `simclr`, `simsiam`, `moco`, and `moco_v2`.
+
+## Evaluation
+
+You can evaluate the saved model with either linear evaluation or semi-supervised evaluation.
+
+### Linear Evaluation
+```shell
+python applications/fedssl/linear_evaluation.py --dataset cifar10 \
+      --model byol --encoder_network resnet18 \
+      --model_path <path to the saved model with postfix '.pth'> \
+      2>&1 | tee log/linear_evaluation.log
+```
+
+### Semi-supervised Evaluation
+```shell
+python applications/fedssl/semi_supervised_evaluation.py --dataset cifar10 \
+      --model byol --encoder_network resnet18 \
+      --model_path <path to the saved model with postfix '.pth'> \
+      --label_ratio 0.1 --use_MLP \
+      2>&1 | tee log/semi_supervised_evaluation.log
+```
+
+## File Structure
+```
+├── client.py <client implementation of federated learning>
+├── communication.py <constants for model update>
+├── dataset.py <dataset for semi-supervised learning>
+├── eval_dataset.py <dataset preprocessing for evaluation>
+├── knn_monitor.py <kNN monitoring>
+├── main.py <entry point to start training>
+├── model.py <SSL models>
+├── resnet.py <network architectures used>
+├── server.py <server implementation of federated learning>
+├── transform.py <image transformations>
+├── linear_evaluation.py <linear evaluation of models after training>
+├── semi_supervised_evaluation.py <semi-supervised evaluation of models after training>
+└── utils.py <utility functions>
+```
+
+## Citation
+
+If you use this code in your research, please cite the following papers:
+
+```
+@inproceedings{zhuang2022fedema,
+  title={Divergence-aware Federated Self-Supervised Learning},
+  author={Weiming Zhuang and Yonggang Wen and Shuai Zhang},
+  booktitle={International Conference on Learning Representations},
+  year={2022},
+  url={https://openreview.net/forum?id=oVE1z8NlNe}
+}
+
+@inproceedings{zhuang2021fedu,
+  title={Collaborative Unsupervised Visual Representation Learning from Decentralized Data},
+  author={Zhuang, Weiming and Gan, Xin and Wen, Yonggang and Zhang, Shuai and Yi, Shuai},
+  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
+  pages={4912--4921},
+  year={2021}
+}
+
+@article{zhuang2022easyfl,
+  title={EasyFL: A Low-Code Federated Learning Platform for Dummies},
+  author={Zhuang, Weiming and Gan, Xin and Wen, Yonggang and Zhang, Shuai},
+  journal={IEEE Internet of Things Journal},
+  year={2022},
+  publisher={IEEE}
+}
+```

+ 0 - 0
applications/fedssl/__init__.py


+ 366 - 0
applications/fedssl/client.py

@@ -0,0 +1,366 @@
+import copy
+import gc
+import logging
+import time
+from collections import Counter
+
+import numpy as np
+import torch
+import torch._utils
+import torch.nn as nn
+import torch.nn.functional as F
+
+import model
+import utils
+from communication import ONLINE, TARGET, BOTH, LOCAL, GLOBAL, DAPU, NONE, EMA, DYNAMIC_DAPU, DYNAMIC_EMA_ONLINE, SELECTIVE_EMA
+from easyfl.client.base import BaseClient
+from easyfl.distributed.distributed import CPU
+
+logger = logging.getLogger(__name__)
+
+L2 = "l2"
+
+
+class FedSSLClient(BaseClient):
+    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
+        super(FedSSLClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
+        self._local_model = None
+        self.DAPU_predictor = LOCAL
+        self.encoder_distance = 1
+        self.encoder_distances = []
+        self.previous_trained_round = -1
+        self.weight_scaler = None
+
+    def decompression(self):
+        if self.model is None:
+            # Initialization at beginning of the task
+            self.model = self.compressed_model
+
+        self.update_model()
+
+    def update_model(self):
+        if self.conf.model in [model.MoCo, model.MoCoV2]:
+            self.model.encoder_q = self.compressed_model.encoder_q
+            # self.model.encoder_k = copy.deepcopy(self._local_model.encoder_k)
+        elif self.conf.model == model.SimCLR:
+            self.model.online_encoder = self.compressed_model.online_encoder
+        elif self.conf.model in [model.SimSiam, model.SimSiamNoSG]:
+            if self._local_model is None:
+                self.model.online_encoder = self.compressed_model.online_encoder
+                self.model.online_predictor = self.compressed_model.online_predictor
+                return
+
+            if self.conf.update_encoder == ONLINE:
+                online_encoder = self.compressed_model.online_encoder
+            else:
+                raise ValueError(f"Encoder: aggregate {self.conf.aggregate_encoder}, "
+                                 f"update {self.conf.update_encoder} is not supported")
+
+            if self.conf.update_predictor == GLOBAL:
+                predictor = self.compressed_model.online_predictor
+            else:
+                raise ValueError(f"Predictor: {self.conf.update_predictor} is not supported")
+
+            self.model.online_encoder = copy.deepcopy(online_encoder)
+            self.model.online_predictor = copy.deepcopy(predictor)
+
+        elif self.conf.model in [model.Symmetric, model.SymmetricNoSG]:
+            self.model.online_encoder = self.compressed_model.online_encoder
+
+        elif self.conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
+
+            if self._local_model is None:
+                logger.info("Use aggregated encoder and predictor")
+                self.model.online_encoder = self.compressed_model.online_encoder
+                self.model.target_encoder = self.compressed_model.online_encoder
+                self.model.online_predictor = self.compressed_model.online_predictor
+                return
+
+            def ema_online():
+                self._calculate_weight_scaler()
+                logger.info(f"Encoder: update online with EMA of global encoder @ round {self.conf.round_id}")
+                weight = self.encoder_distance
+                weight = min(1, self.weight_scaler * weight)
+                weight = 1 - weight
+                self.compressed_model = self.compressed_model.cpu()
+                online_encoder = self.compressed_model.online_encoder
+                target_encoder = self._local_model.target_encoder
+                ema_updater = model.EMA(weight)
+                model.update_moving_average(ema_updater, online_encoder, self._local_model.online_encoder)
+                return online_encoder, target_encoder
+
+            def ema_predictor():
+                logger.info(f"Predictor: use dynamic DAPU")
+                distance = self.encoder_distance
+                distance = min(1, distance * self.weight_scaler)
+                if distance > 0.5:
+                    weight = distance
+                    ema_updater = model.EMA(weight)
+                    predictor = self._local_model.online_predictor
+                    model.update_moving_average(ema_updater, predictor, self.compressed_model.online_predictor)
+                else:
+                    weight = 1 - distance
+                    ema_updater = model.EMA(weight)
+                    predictor = self.compressed_model.online_predictor
+                    model.update_moving_average(ema_updater, predictor, self._local_model.online_predictor)
+                return predictor
+
+            if self.conf.aggregate_encoder == ONLINE and self.conf.update_encoder == ONLINE:
+                logger.info("Encoder: aggregate online, update online")
+                online_encoder = self.compressed_model.online_encoder
+                target_encoder = self._local_model.target_encoder
+            elif self.conf.aggregate_encoder == TARGET and self.conf.update_encoder == ONLINE:
+                logger.info("Encoder: aggregate target, update online")
+                online_encoder = self.compressed_model.target_encoder
+                target_encoder = self._local_model.target_encoder
+            elif self.conf.aggregate_encoder == TARGET and self.conf.update_encoder == TARGET:
+                logger.info("Encoder: aggregate target, update target")
+                online_encoder = self._local_model.online_encoder
+                target_encoder = self.compressed_model.target_encoder
+            elif self.conf.aggregate_encoder == ONLINE and self.conf.update_encoder == TARGET:
+                logger.info("Encoder: aggregate online, update target")
+                online_encoder = self._local_model.online_encoder
+                target_encoder = self.compressed_model.online_encoder
+            elif self.conf.aggregate_encoder == ONLINE and self.conf.update_encoder == BOTH:
+                logger.info("Encoder: aggregate online, update both")
+                online_encoder = self.compressed_model.online_encoder
+                target_encoder = self.compressed_model.online_encoder
+            elif self.conf.aggregate_encoder == TARGET and self.conf.update_encoder == BOTH:
+                logger.info("Encoder: aggregate target, update both")
+                online_encoder = self.compressed_model.target_encoder
+                target_encoder = self.compressed_model.target_encoder
+            elif self.conf.update_encoder == NONE:
+                logger.info("Encoder: use local online and target encoders")
+                online_encoder = self._local_model.online_encoder
+                target_encoder = self._local_model.target_encoder
+            elif self.conf.update_encoder == EMA:
+                logger.info(f"Encoder: use EMA, weight {self.conf.encoder_weight}")
+                online_encoder = self._local_model.online_encoder
+                ema_updater = model.EMA(self.conf.encoder_weight)
+                model.update_moving_average(ema_updater, online_encoder, self.compressed_model.online_encoder)
+                target_encoder = self._local_model.target_encoder
+            elif self.conf.update_encoder == DYNAMIC_EMA_ONLINE:
+                # Use FedEMA to update online encoder
+                online_encoder, target_encoder = ema_online()
+            elif self.conf.update_encoder == SELECTIVE_EMA:
+                # Use FedEMA to update online encoder
+                # For random selection, only update with EMA when the client is selected in previous round.
+                if self.previous_trained_round + 1 == self.conf.round_id:
+                    online_encoder, target_encoder = ema_online()
+                else:
+                    logger.info(f"Encoder: update online and target @ round {self.conf.round_id}")
+                    online_encoder = self.compressed_model.online_encoder
+                    target_encoder = self.compressed_model.online_encoder
+            else:
+                raise ValueError(f"Encoder: aggregate {self.conf.aggregate_encoder}, "
+                                 f"update {self.conf.update_encoder} is not supported")
+
+            if self.conf.update_predictor == GLOBAL:
+                logger.info("Predictor: use global predictor")
+                predictor = self.compressed_model.online_predictor
+            elif self.conf.update_predictor == LOCAL:
+                logger.info("Predictor: use local predictor")
+                predictor = self._local_model.online_predictor
+            elif self.conf.update_predictor == DAPU:
+                # Divergence-aware predictor update (DAPU)
+                logger.info(f"Predictor: use DAPU, mu {self.conf.dapu_threshold}")
+                if self.DAPU_predictor == GLOBAL:
+                    predictor = self.compressed_model.online_predictor
+                elif self.DAPU_predictor == LOCAL:
+                    predictor = self._local_model.online_predictor
+                else:
+                    raise ValueError(f"Predictor: DAPU predictor can either use local or global predictor")
+            elif self.conf.update_predictor == DYNAMIC_DAPU:
+                # Use FedEMA to update predictor
+                predictor = ema_predictor()
+            elif self.conf.update_predictor == SELECTIVE_EMA:
+                # For random selection, only update with EMA when the client is selected in previous round.
+                if self.previous_trained_round + 1 == self.conf.round_id:
+                    predictor = ema_predictor()
+                else:
+                    logger.info("Predictor: use global predictor")
+                    predictor = self.compressed_model.online_predictor
+            elif self.conf.update_predictor == EMA:
+                logger.info(f"Predictor: use EMA, weight {self.conf.predictor_weight}")
+                predictor = self._local_model.online_predictor
+                ema_updater = model.EMA(self.conf.predictor_weight)
+                model.update_moving_average(ema_updater, predictor, self.compressed_model.online_predictor)
+            else:
+                raise ValueError(f"Predictor: {self.conf.update_predictor} is not supported")
+
+            self.model.online_encoder = copy.deepcopy(online_encoder)
+            self.model.target_encoder = copy.deepcopy(target_encoder)
+            self.model.online_predictor = copy.deepcopy(predictor)
+
+    def train(self, conf, device=CPU):
+        start_time = time.time()
+        loss_fn, optimizer = self.pretrain_setup(conf, device)
+        if conf.model in [model.MoCo, model.MoCoV2]:
+            self.model.reset_key_encoder()
+        self.train_loss = []
+        self.model.to(device)
+        old_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
+        for i in range(conf.local_epoch):
+            batch_loss = []
+            for (batched_x1, batched_x2), _ in self.train_loader:
+                x1, x2 = batched_x1.to(device), batched_x2.to(device)
+                optimizer.zero_grad()
+
+                if conf.model in [model.MoCo, model.MoCoV2]:
+                    loss = self.model(x1, x2, device)
+                elif conf.model == model.SimCLR:
+                    images = torch.cat((x1, x2), dim=0)
+                    features = self.model(images)
+                    logits, labels = self.info_nce_loss(features)
+                    loss = loss_fn(logits, labels)
+                else:
+                    loss = self.model(x1, x2)
+
+                loss.backward()
+                optimizer.step()
+                batch_loss.append(loss.item())
+
+                if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor] and conf.momentum_update:
+                    self.model.update_moving_average()
+
+            current_epoch_loss = sum(batch_loss) / len(batch_loss)
+            self.train_loss.append(float(current_epoch_loss))
+        self.train_time = time.time() - start_time
+
+        # store trained model locally
+        self._local_model = copy.deepcopy(self.model).cpu()
+        self.previous_trained_round = conf.round_id
+        if conf.update_predictor in [DAPU, DYNAMIC_DAPU, SELECTIVE_EMA] or conf.update_encoder in [DYNAMIC_EMA_ONLINE, SELECTIVE_EMA]:
+            new_model = copy.deepcopy(nn.Sequential(*list(self.model.children())[:-1])).cpu()
+            self.encoder_distance = self._calculate_divergence(old_model, new_model)
+            self.encoder_distances.append(self.encoder_distance.item())
+            self.DAPU_predictor = self._DAPU_predictor_usage(self.encoder_distance)
+            if self.conf.auto_scaler == 'y' and self.conf.random_selection:
+                self._calculate_weight_scaler()
+            if (conf.round_id + 1) % 100 == 0:
+                logger.info(f"Client {self.cid}, encoder distances: {self.encoder_distances}")
+
+    def _DAPU_predictor_usage(self, distance):
+        if distance < self.conf.dapu_threshold:
+            return GLOBAL
+        else:
+            return LOCAL
+
+    def _calculate_divergence(self, old_model, new_model, typ=L2):
+        size = 0
+        total_distance = 0
+        old_dict = old_model.state_dict()
+        new_dict = new_model.state_dict()
+        for name, param in old_model.named_parameters():
+            if 'conv' in name and 'weight' in name:
+                total_distance += self._calculate_distance(old_dict[name].detach().clone().view(1, -1),
+                                                           new_dict[name].detach().clone().view(1, -1),
+                                                           typ)
+                size += 1
+        distance = total_distance / size
+        logger.info(f"Model distance: {distance} = {total_distance}/{size}")
+        return distance
+
+    def _calculate_distance(self, m1, m2, typ=L2):
+        if typ == L2:
+            return torch.dist(m1, m2, 2)
+        raise ValueError(f"Distance type {typ} is not supported")
+
+    def _calculate_weight_scaler(self):
+        if not self.weight_scaler:
+            if self.conf.auto_scaler == 'y':
+                self.weight_scaler = self.conf.auto_scaler_target / self.encoder_distance
+            else:
+                self.weight_scaler = self.conf.weight_scaler
+            logger.info(f"Client {self.cid}: weight scaler {self.weight_scaler}")
+
+    def load_loader(self, conf):
+        drop_last = conf.drop_last
+        train_loader = self.train_data.loader(conf.batch_size,
+                                              self.cid,
+                                              shuffle=True,
+                                              drop_last=drop_last,
+                                              seed=conf.seed,
+                                              transform=self._load_transform(conf))
+        _print_label_count(self.cid, self.train_data.data[self.cid]['y'])
+        return train_loader
+
+    def load_optimizer(self, conf):
+        lr = conf.optimizer.lr
+        if conf.optimizer.lr_type == "cosine":
+            lr = compute_lr(conf.round_id, conf.rounds, 0, conf.optimizer.lr)
+
+        # MoCo v1 uses the default (constant) learning rate rather than cosine decay
+        if conf.model == model.MoCo:
+            lr = conf.optimizer.lr
+
+        params = self.model.parameters()
+        if conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
+            params = [
+                {'params': self.model.online_encoder.parameters()},
+                {'params': self.model.online_predictor.parameters()}
+            ]
+
+        if conf.optimizer.type == "Adam":
+            optimizer = torch.optim.Adam(params, lr=lr)
+        else:
+            optimizer = torch.optim.SGD(params,
+                                        lr=lr,
+                                        momentum=conf.optimizer.momentum,
+                                        weight_decay=conf.optimizer.weight_decay)
+        return optimizer
+
+    def _load_transform(self, conf):
+        transformation = utils.get_transformation(conf.model)
+        return transformation(conf.image_size, conf.gaussian)
+
+    def post_upload(self):
+        if self.conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
+            del self.model
+            del self.compressed_model
+            self.model = None
+            self.compressed_model = None
+            assert self.model is None
+            assert self.compressed_model is None
+            gc.collect()
+            torch.cuda.empty_cache()
+
+    def info_nce_loss(self, features, n_views=2, temperature=0.07):
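+        # InfoNCE loss (SimCLR-style): with n_views=2, the other augmented view of each
+        # image is the positive; the remaining 2 * (batch_size - 1) samples in the batch
+        # are negatives. Returns logits ordered [positive | negatives] with labels of 0.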
+        labels = torch.cat([torch.arange(self.conf.batch_size) for i in range(n_views)], dim=0)
+        labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
+        labels = labels.to(self.device)
+
+        features = F.normalize(features, dim=1)
+
+        similarity_matrix = torch.matmul(features, features.T)
+        # assert similarity_matrix.shape == (
+        #     n_views * self.conf.batch_size, n_views * self.conf.batch_size)
+        # assert similarity_matrix.shape == labels.shape
+
+        # discard the main diagonal from both: labels and similarities matrix
+        mask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.device)
+        labels = labels[~mask].view(labels.shape[0], -1)
+        similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
+        # assert similarity_matrix.shape == labels.shape
+
+        # select and combine multiple positives
+        positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
+
+        # select only the negatives
+        negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
+
+        logits = torch.cat([positives, negatives], dim=1)
+        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.device)
+
+        logits = logits / temperature
+        return logits, labels
+
+
+def compute_lr(current_round, rounds=800, eta_min=0, eta_max=0.3):
+    """Compute learning rate as cosine decay"""
+    pi = np.pi
+    eta_t = eta_min + 0.5 * (eta_max - eta_min) * (np.cos(pi * current_round / rounds) + 1)
+    return eta_t
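+
+# Example (assumed values matching the main.py defaults, lr=0.032, rounds=100):
+# compute_lr(0, 100, 0, 0.032) -> 0.032; compute_lr(50, 100, 0, 0.032) -> ~0.016;
+# compute_lr(100, 100, 0, 0.032) -> 0.0.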
+
+
+def _print_label_count(cid, labels):
+    logger.info(f"client {cid}: {Counter(labels)}")

+ 11 - 0
applications/fedssl/communication.py

@@ -0,0 +1,11 @@
+ONLINE = "online"
+TARGET = "target"
+BOTH = "both"
+NONE = "none"
+LOCAL = "local"
+GLOBAL = "global"
+DAPU = "dapu"
+DYNAMIC_DAPU = "dynamic_dapu"
+EMA = "ema"
+DYNAMIC_EMA_ONLINE = "dynamic_ema_online"
+SELECTIVE_EMA = "selective_ema"
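+
+# These string constants are the values accepted by the --aggregate_encoder,
+# --update_encoder, and --update_predictor command-line flags defined in main.py.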

+ 98 - 0
applications/fedssl/dataset.py

@@ -0,0 +1,98 @@
+import logging
+import os
+
+import torchvision
+import torchvision.transforms as transforms
+
+from easyfl.datasets import FederatedTensorDataset
+from easyfl.datasets.data import CIFAR100
+from easyfl.datasets.simulation import data_simulation
+from easyfl.datasets.utils.util import save_dict, load_dict
+from utils import get_transformation
+
+logger = logging.getLogger(__name__)
+
+
+def semi_supervised_preprocess(dataset, num_of_client, split_type, weights, alpha, min_size, class_per_client,
+                               label_ratio=0.01):
+    setting = f"{dataset}_{split_type}_{num_of_client}_{min_size}_{class_per_client}_{alpha}_{0}_{label_ratio}"
+    data_path = f"./data/{dataset}"
+    data_folder = os.path.join(data_path, setting)
+    if not os.path.exists(data_folder):
+        os.makedirs(data_folder)
+    train_path = os.path.join(data_folder, "train")
+    test_path = os.path.join(data_folder, "test")
+    labeled_path = os.path.join(data_folder, "labeled")
+
+    if os.path.exists(train_path):
+        print("Load existing data")
+        return load_dict(train_path), load_dict(test_path), load_dict(labeled_path)
+
+    if dataset == CIFAR100:
+        train_set = torchvision.datasets.CIFAR100(root=data_path, train=True, download=True)
+        test_set = torchvision.datasets.CIFAR100(root=data_path, train=False, download=True)
+    else:
+        train_set = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True)
+        test_set = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True)
+    train_size = len(train_set.data)
+    label_size = int(train_size * label_ratio)
+    labeled_data = {
+        'x': train_set.data[:label_size],
+        'y': train_set.targets[:label_size],
+    }
+    train_data = {
+        'x': train_set.data[label_size:],
+        'y': train_set.targets[label_size:],
+    }
+    test_data = {
+        'x': test_set.data,
+        'y': test_set.targets,
+    }
+    print(f"{dataset} data simulation begins")
+    _, train_data = data_simulation(train_data['x'],
+                                    train_data['y'],
+                                    num_of_client,
+                                    split_type,
+                                    weights,
+                                    alpha,
+                                    min_size,
+                                    class_per_client)
+    print(f"{dataset} data simulation is done")
+
+    save_dict(train_data, train_path)
+    save_dict(test_data, test_path)
+    save_dict(labeled_data, labeled_path)
+
+    return train_data, test_data, labeled_data
+
+
+def get_semi_supervised_dataset(dataset, num_of_client, split_type, class_per_client, label_ratio=0.01, image_size=32,
+                                gaussian=False):
+    train_data, test_data, labeled_data = semi_supervised_preprocess(dataset, num_of_client, split_type, None, 0.5, 10,
+                                                                     class_per_client, label_ratio)
+
+    fine_tune_transform = transforms.Compose([
+        torchvision.transforms.ToPILImage(mode='RGB'),
+        torchvision.transforms.Resize(size=image_size),
+        torchvision.transforms.ToTensor(),
+    ])
+
+    train_data = FederatedTensorDataset(train_data,
+                                        simulated=True,
+                                        do_simulate=False,
+                                        process_x=None,
+                                        process_y=None,
+                                        transform=get_transformation("byol")(image_size, gaussian))
+    test_data = FederatedTensorDataset(test_data,
+                                       simulated=False,
+                                       do_simulate=False,
+                                       process_x=None,
+                                       process_y=None,
+                                       transform=get_transformation("byol")(image_size, gaussian).test_transform)
+    labeled_data = FederatedTensorDataset(labeled_data,
+                                          simulated=False,
+                                          do_simulate=False,
+                                          process_x=None,
+                                          process_y=None,
+                                          transform=fine_tune_transform)
+    return train_data, test_data, labeled_data
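+
+
+# Usage sketch (assumed values matching the main.py defaults): split CIFAR-10 across
+# 5 clients with a class-based non-IID partition, holding out 1% as labeled data:
+#   train, test, labeled = get_semi_supervised_dataset(
+#       "cifar10", num_of_client=5, split_type="class", class_per_client=2)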

+ 39 - 0
applications/fedssl/eval_dataset.py

@@ -0,0 +1,39 @@
+import torch
+from torchvision import datasets
+
+from dataset import get_semi_supervised_dataset
+from easyfl.datasets.data import CIFAR100
+from transform import SimCLRTransform
+
+
+def get_data_loaders(dataset, image_size=32, batch_size=512, num_workers=8):
+    transformation = SimCLRTransform(size=image_size, gaussian=False).test_transform
+
+    if dataset == CIFAR100:
+        data_path = "./data/cifar100"
+        train_dataset = datasets.CIFAR100(data_path, download=True, transform=transformation)
+        test_dataset = datasets.CIFAR100(data_path, train=False, download=True, transform=transformation)
+    else:
+        data_path = "./data/cifar10"
+        train_dataset = datasets.CIFAR10(data_path, download=True, transform=transformation)
+        test_dataset = datasets.CIFAR10(data_path, train=False, download=True, transform=transformation)
+
+    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
+    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
+
+    return train_loader, test_loader
+
+
+def get_semi_supervised_data_loaders(dataset, data_distribution, class_per_client, label_ratio, batch_size=512, num_workers=8, image_size=32):
+    transformation = SimCLRTransform(size=image_size, gaussian=False).test_transform
+    if dataset == CIFAR100:
+        data_path = "./data/cifar100"
+        test_dataset = datasets.CIFAR100(data_path, train=False, download=True, transform=transformation)
+    else:
+        data_path = "./data/cifar10"
+        test_dataset = datasets.CIFAR10(data_path, train=False, download=True, transform=transformation)
+
+    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
+
+    _, _, labeled_data = get_semi_supervised_dataset(dataset, 5, data_distribution, class_per_client, label_ratio)
+    return labeled_data.loader(batch_size), test_loader

BIN
applications/fedssl/images/fedssl.png


+ 64 - 0
applications/fedssl/knn_monitor.py

@@ -0,0 +1,64 @@
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+
+
+# code is obtained from https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb#scrollTo=lzFyFnhbk8hj
+# test using a knn monitor
+def knn_monitor(net, memory_data_loader, test_data_loader, k=200, t=0.1, hide_progress=False, device=None):
+    net.eval()
+    classes = len(memory_data_loader.dataset.classes)
+    total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
+    with torch.no_grad():
+        # generate feature bank
+        for data, target in tqdm(memory_data_loader, desc='Feature extracting', leave=False, disable=hide_progress):
+            if device is None:
+                data = data.cuda(non_blocking=True)
+            else:
+                data = data.to(device, non_blocking=True)
+            feature = net(data)
+            feature = F.normalize(feature, dim=1)
+            feature_bank.append(feature)
+        # [D, N]
+        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
+        # [N]
+        feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
+        # loop test data to predict the label by weighted knn search
+        test_bar = tqdm(test_data_loader, desc='kNN', disable=hide_progress)
+        for data, target in test_bar:
+            if device is None:
+                data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
+            else:
+                data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
+            feature = net(data)
+            feature = F.normalize(feature, dim=1)
+
+            pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, k, t)
+
+            total_num += data.size(0)
+            total_top1 += (pred_labels[:, 0] == target).float().sum().item()
+            test_bar.set_postfix({'Accuracy': total_top1 / total_num * 100})
+        print("Accuracy: {}".format(total_top1 / total_num * 100))
+    return total_top1 / total_num * 100
+
+
+# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
+# implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
+def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
+    # compute cos similarity between each feature vector and feature bank ---> [B, N]
+    sim_matrix = torch.mm(feature, feature_bank)
+    # [B, K]
+    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
+    # [B, K]
+    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
+    sim_weight = (sim_weight / knn_t).exp()
+
+    # counts for each class
+    one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
+    # [B*K, C]
+    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
+    # weighted score ---> [B, C]
+    pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)
+
+    pred_labels = pred_scores.argsort(dim=-1, descending=True)
+    return pred_labels
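+
+
+# Usage sketch (hypothetical loaders): track representation quality during training with
+#   acc = knn_monitor(encoder, memory_loader, test_loader, k=200, t=0.1)
+# where memory_loader iterates the labeled training set used as the kNN feature bank.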

+ 161 - 0
applications/fedssl/linear_evaluation.py

@@ -0,0 +1,161 @@
+import argparse
+from collections import defaultdict
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from easyfl.datasets.data import CIFAR100
+from eval_dataset import get_data_loaders
+from model import get_encoder_network
+
+
+def inference(loader, model, device):
+    feature_vector = []
+    labels_vector = []
+    model.eval()
+    for step, (x, y) in enumerate(loader):
+        x = x.to(device)
+
+        # get encoding
+        with torch.no_grad():
+            h = model(x)
+
+        h = h.squeeze()
+        h = h.detach()
+
+        feature_vector.extend(h.cpu().detach().numpy())
+        labels_vector.extend(y.numpy())
+
+        if step % 5 == 0:
+            print(f"Step [{step}/{len(loader)}]\t Computing features...")
+
+    feature_vector = np.array(feature_vector)
+    labels_vector = np.array(labels_vector)
+    print("Features shape {}".format(feature_vector.shape))
+    return feature_vector, labels_vector
+
+
+def get_features(model, train_loader, test_loader, device):
+    train_X, train_y = inference(train_loader, model, device)
+    test_X, test_y = inference(test_loader, model, device)
+    return train_X, train_y, test_X, test_y
+
+
+def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test, batch_size):
+    train = torch.utils.data.TensorDataset(
+        torch.from_numpy(X_train), torch.from_numpy(y_train)
+    )
+    train_loader = torch.utils.data.DataLoader(
+        train, batch_size=batch_size, shuffle=False
+    )
+
+    test = torch.utils.data.TensorDataset(
+        torch.from_numpy(X_test), torch.from_numpy(y_test)
+    )
+    test_loader = torch.utils.data.DataLoader(
+        test, batch_size=batch_size, shuffle=False
+    )
+    return train_loader, test_loader
+
+
+def test_result(test_loader, logreg, device, model_path):
+    # Test fine-tuned model
+    print("### Calculating final testing performance ###")
+    logreg.eval()
+    metrics = defaultdict(list)
+    for step, (h, y) in enumerate(test_loader):
+        h = h.to(device)
+        y = y.to(device)
+
+        outputs = logreg(h)
+
+        # calculate accuracy and save metrics
+        accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0)
+        metrics["Accuracy/test"].append(accuracy)
+
+    print(f"Final test performance: " + model_path)
+    for k, v in metrics.items():
+        print(f"{k}: {np.array(v).mean():.4f}")
+    return np.array(metrics["Accuracy/test"]).mean()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", default="cifar10", type=str)
+    parser.add_argument("--model_path", required=True, type=str, help="Path to pre-trained model (e.g. model-10.pt)")
+    parser.add_argument('--model', default='simsiam', type=str, help='name of the network')
+    parser.add_argument("--image_size", default=32, type=int, help="Image size")
+    parser.add_argument("--learning_rate", default=3e-3, type=float, help="Initial learning rate.")
+    parser.add_argument("--batch_size", default=512, type=int, help="Batch size for training.")
+    parser.add_argument("--num_epochs", default=200, type=int, help="Number of epochs to train for.")
+    parser.add_argument("--encoder_network", default="resnet18", type=str, help="Encoder network architecture.")
+    parser.add_argument("--num_workers", default=8, type=int, help="Number of data workers (caution with nodes!)")
+    parser.add_argument("--fc", default="identity", help="options: identity, remove")
+    args = parser.parse_args()
+    print(args)
+
+    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+
+    # get data loaders
+    train_loader, test_loader = get_data_loaders(args.dataset, args.image_size, args.batch_size, args.num_workers)
+
+    # get model
+    resnet = get_encoder_network(args.model, args.encoder_network)
+    resnet.load_state_dict(torch.load(args.model_path, map_location=device))
+    resnet = resnet.to(device)
+    num_features = list(resnet.children())[-1].in_features
+    if args.fc == "remove":
+        resnet = nn.Sequential(*list(resnet.children())[:-1])  # throw away fc layer
+    else:
+        resnet.fc = nn.Identity()
+
+    n_classes = 10
+    if args.dataset == CIFAR100:
+        n_classes = 100
+
+    # fine-tune model
+    logreg = nn.Sequential(nn.Linear(num_features, n_classes))
+    logreg = logreg.to(device)
+
+    # loss / optimizer
+    criterion = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(params=logreg.parameters(), lr=args.learning_rate)
+
+    # compute features (only needs to be done once, since it does not backprop during fine-tuning)
+    print("Creating features from pre-trained model")
+    (train_X, train_y, test_X, test_y) = get_features(
+        resnet, train_loader, test_loader, device
+    )
+
+    train_loader, test_loader = create_data_loaders_from_arrays(
+        train_X, train_y, test_X, test_y, 2048
+    )
+
+    # Train fine-tuned model
+    logreg.train()
+    for epoch in range(args.num_epochs):
+        metrics = defaultdict(list)
+        for step, (h, y) in enumerate(train_loader):
+            h = h.to(device)
+            y = y.to(device)
+
+            outputs = logreg(h)
+
+            loss = criterion(outputs, y)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            # calculate accuracy and save metrics
+            accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0)
+            metrics["Loss/train"].append(loss.item())
+            metrics["Accuracy/train"].append(accuracy)
+
+        print(f"Epoch [{epoch}/{args.num_epochs}]: " + "\t".join(
+            [f"{k}: {np.array(v).mean()}" for k, v in metrics.items()]))
+
+        if epoch % 100 == 0:
+            print("======epoch {}======".format(epoch))
+            test_result(test_loader, logreg, device, args.model_path)
+    test_result(test_loader, logreg, device, args.model_path)

+ 170 - 0
applications/fedssl/main.py

@@ -0,0 +1,170 @@
+import argparse
+
+import easyfl
+from client import FedSSLClient
+from dataset import get_semi_supervised_dataset
+from easyfl.datasets.data import CIFAR100
+from easyfl.distributed import slurm
+from model import get_model, BYOLNoEMA, BYOL, BYOLNoSG, BYOLNoEMA_NoSG
+from server import FedSSLServer
+
+
+def run():
+    parser = argparse.ArgumentParser(description='FedSSL')
+    parser.add_argument("--task_id", type=str, default="")
+    parser.add_argument("--dataset", type=str, default='cifar10', help='options: cifar10, cifar100')
+    parser.add_argument("--data_partition", type=str, default='class', help='options: class, iid, dir')
+    parser.add_argument("--dir_alpha", type=float, default=0.1, help='alpha for dirichlet sampling')
+    parser.add_argument('--model', default='byol', type=str, help='options: byol, simsiam, simclr, moco, moco_v2')
+    parser.add_argument('--encoder_network', default='resnet18', type=str,
+                        help='network architecture of encoder, options: resnet18, resnet50')
+    parser.add_argument('--predictor_network', default='2_layer', type=str,
+                        help='network of predictor, options: 1_layer, 2_layer')
+
+    parser.add_argument('--batch_size', default=128, type=int)
+    parser.add_argument('--local_epoch', default=5, type=int)
+    parser.add_argument('--rounds', default=100, type=int)
+    parser.add_argument('--num_of_clients', default=5, type=int)
+    parser.add_argument('--clients_per_round', default=5, type=int)
+    parser.add_argument('--class_per_client', default=2, type=int,
+                        help='for non-IID setting, number of classes each client, based on CIFAR10')
+    parser.add_argument('--optimizer_type', default='SGD', type=str, help='optimizer type')
+    parser.add_argument('--lr', default=0.032, type=float)
+    parser.add_argument('--lr_type', default='cosine', type=str, help='cosine decay learning rate')
+    parser.add_argument('--random_selection', action='store_true', help='whether randomly select clients')
+
+    parser.add_argument('--aggregate_encoder', default='online', type=str, help='options: online, target')
+    parser.add_argument('--update_encoder', default='online', type=str, help='options: online, target, both, none')
+    parser.add_argument('--update_predictor', default='global', type=str, help='options: global, local, dapu')
+    parser.add_argument('--dapu_threshold', default=0.4, type=float, help='DAPU threshold value')
+    parser.add_argument('--weight_scaler', default=1.0, type=float, help='weight scaler for different class per client')
+    parser.add_argument('--auto_scaler', default='y', type=str,
+                        help="set 'y' to compute the weight scaler automatically from the first-round divergence")
+    parser.add_argument('--auto_scaler_target', default=0.8, type=float,
+                        help='target weight for the first time scaling')
+    parser.add_argument('--encoder_weight', type=float, default=0,
+                        help='for ema encoder update, apply on local encoder')
+    parser.add_argument('--predictor_weight', type=float, default=0,
+                        help='for ema predictor update, apply on local predictor')
+
+    parser.add_argument('--test_every', default=10, type=int, help='test every x rounds')
+    parser.add_argument('--save_model_every', default=10, type=int, help='save model every x rounds')
+    parser.add_argument('--save_predictor', action='store_true', help='whether save predictor')
+
+    parser.add_argument('--semi_supervised', action='store_true', help='whether to train with semi-supervised data')
+    parser.add_argument('--label_ratio', default=0.01, type=float, help='percentage of labeled data')
+
+    parser.add_argument('--gpu', default=0, type=int)
+    parser.add_argument('--run_count', default=0, type=int)
+
+    args = parser.parse_args()
+    print("arguments: ", args)
+
+    class_per_client = args.class_per_client
+    if args.dataset == CIFAR100:
+        class_per_client *= 10
+
+    task_id = args.task_id
+    if task_id == "":
+        task_id = f"{args.dataset}_{args.model}_{args.encoder_network}_{args.data_partition}_" \
+                  f"aggregate_{args.aggregate_encoder}_update_{args.update_encoder}_predictor_{args.update_predictor}_" \
+                  f"run{args.run_count}"
+
+    momentum_update = True
+    if args.model == BYOLNoEMA:
+        args.model = BYOL
+        momentum_update = False
+    elif args.model == BYOLNoEMA_NoSG:
+        args.model = BYOLNoSG
+        momentum_update = False
+
+    image_size = 32
+
+    config = {
+        "task_id": task_id,
+        "data": {
+            "dataset": args.dataset,
+            "num_of_clients": args.num_of_clients,
+            "split_type": args.data_partition,
+            "class_per_client": class_per_client,
+            "data_amount": 1,
+            "iid_fraction": 1,
+            "min_size": 10,
+            "alpha": args.dir_alpha,
+        },
+        "model": args.model,
+        "test_mode": "test_in_server",
+        "server": {
+            "batch_size": args.batch_size,
+            "rounds": args.rounds,
+            "test_every": args.test_every,
+            "save_model_every": args.save_model_every,
+            "clients_per_round": args.clients_per_round,
+            "random_selection": args.random_selection,
+            "save_predictor": args.save_predictor,
+            "test_all": True,
+        },
+        "client": {
+            "drop_last": True,
+            "batch_size": args.batch_size,
+            "local_epoch": args.local_epoch,
+            "optimizer": {
+                "type": args.optimizer_type,
+                "lr_type": args.lr_type,
+                "lr": args.lr,
+                "momentum": 0.9,
+                "weight_decay": 0.0005,
+            },
+            # application specific
+            "model": args.model,
+            "rounds": args.rounds,
+            "gaussian": False,
+            "image_size": image_size,
+
+            "aggregate_encoder": args.aggregate_encoder,
+            "update_encoder": args.update_encoder,
+            "update_predictor": args.update_predictor,
+            "dapu_threshold": args.dapu_threshold,
+            "weight_scaler": args.weight_scaler,
+            "auto_scaler": args.auto_scaler,
+            "auto_scaler_target": args.auto_scaler_target,
+            "random_selection": args.random_selection,
+
+            "encoder_weight": args.encoder_weight,
+            "predictor_weight": args.predictor_weight,
+
+            "momentum_update": momentum_update,
+        },
+        'resource_heterogeneous': {"grouping_strategy": ""}
+    }
+
+    if args.gpu > 1:
+        rank, local_rank, world_size, host_addr = slurm.setup()
+        distribute_config = {
+            "gpu": world_size,
+            "distributed": {
+                "rank": rank,
+                "local_rank": local_rank,
+                "world_size": world_size,
+                "init_method": host_addr
+            },
+        }
+        config.update(distribute_config)
+
+    if args.semi_supervised:
+        train_data, test_data, _ = get_semi_supervised_dataset(args.dataset,
+                                                               args.num_of_clients,
+                                                               args.data_partition,
+                                                               class_per_client,
+                                                               args.label_ratio)
+        easyfl.register_dataset(train_data, test_data)
+
+    model = get_model(args.model, args.encoder_network, args.predictor_network)
+    easyfl.register_model(model)
+    easyfl.register_client(FedSSLClient)
+    easyfl.register_server(FedSSLServer)
+    easyfl.init(config, init_all=True)
+    easyfl.run()
+
+
+if __name__ == '__main__':
+    run()

+ 483 - 0
applications/fedssl/model.py

@@ -0,0 +1,483 @@
+import copy
+
+import torch
+import torch.nn.functional as F
+import torchvision.models as models
+from torch import nn
+
+from easyfl.models.model import BaseModel
+from easyfl.models.resnet import ResNet18, ResNet50
+
+SimSiam = "simsiam"
+SimSiamNoSG = "simsiam_no_sg"
+SimCLR = "simclr"
+MoCo = "moco"
+MoCoV2 = "moco_v2"
+BYOL = "byol"
+BYOLNoSG = "byol_no_sg"
+BYOLNoEMA = "byol_no_ema"
+BYOLNoEMA_NoSG = "byol_no_ema_no_sg"
+BYOLNoPredictor = "byol_no_p"
+Symmetric = "symmetric"
+SymmetricNoSG = "symmetric_no_sg"
+
+OneLayer = "1_layer"
+TwoLayer = "2_layer"
+
+RESNET18 = "resnet18"
+RESNET50 = "resnet50"
+
+
+def get_encoder(arch=RESNET18):
+    return models.__dict__[arch]
+
+
+def get_model(model, encoder_network, predictor_network=TwoLayer):
+    mlp = False
+    T = 0.07
+    stop_gradient = True
+    has_predictor = True
+    if model == SymmetricNoSG:
+        stop_gradient = False
+        model = Symmetric
+    elif model == SimSiamNoSG:
+        stop_gradient = False
+        model = SimSiam
+    elif model == BYOLNoSG:
+        stop_gradient = False
+        model = BYOL
+    elif model == BYOLNoPredictor:
+        has_predictor = False
+        model = BYOL
+    elif model == MoCoV2:
+        model = MoCo
+        mlp = True
+        T = 0.2
+
+    if model == Symmetric:
+        if encoder_network == RESNET50:
+            return SymmetricModel(net=ResNet50(), stop_gradient=stop_gradient)
+        else:
+            return SymmetricModel(stop_gradient=stop_gradient)
+    elif model == SimSiam:
+        net = ResNet18()
+        if encoder_network == RESNET50:
+            net = ResNet50()
+        return SimSiamModel(net=net, stop_gradient=stop_gradient)
+    elif model == MoCo:
+        net = ResNet18
+        if encoder_network == RESNET50:
+            net = ResNet50
+        return MoCoModel(net=net, mlp=mlp, T=T)
+    elif model == BYOL:
+        net = ResNet18()
+        if encoder_network == RESNET50:
+            net = ResNet50()
+        return BYOLModel(net=net, stop_gradient=stop_gradient, has_predictor=has_predictor,
+                         predictor_network=predictor_network)
+    elif model == SimCLR:
+        net = ResNet18()
+        if encoder_network == RESNET50:
+            net = ResNet50()
+        return SimCLRModel(net=net)
+    else:
+        raise NotImplementedError
+
+
+def get_encoder_network(model, encoder_network, num_classes=10, projection_size=2048, projection_hidden_size=4096):
+    if model in [MoCo, MoCoV2]:
+        num_classes = 128
+
+    if encoder_network == RESNET18:
+        resnet = ResNet18(num_classes=num_classes)
+    elif encoder_network == RESNET50:
+        resnet = ResNet50(num_classes=num_classes)
+    else:
+        raise NotImplementedError
+
+    if model in [Symmetric, SimSiam, BYOL, SymmetricNoSG, SimSiamNoSG, BYOLNoSG, SimCLR]:
+        resnet.fc = MLP(resnet.feature_dim, projection_size, projection_hidden_size)
+    if model == MoCoV2:
+        resnet.fc = MLP(resnet.feature_dim, num_classes, resnet.feature_dim)
+
+    return resnet
+
+
+# ------------- SymmetricModel Model -----------------
+
+
+class SymmetricModel(BaseModel):
+    def __init__(
+            self,
+            net=ResNet18(),
+            image_size=32,
+            projection_size=2048,
+            projection_hidden_size=4096,
+            stop_gradient=True
+    ):
+        super().__init__()
+
+        self.online_encoder = net
+        self.online_encoder.fc = MLP(net.feature_dim, projection_size, projection_hidden_size)  # projector
+
+        self.stop_gradient = stop_gradient
+
+        # send a mock image tensor to instantiate singleton parameters
+        self.forward(torch.randn(2, 3, image_size, image_size), torch.randn(2, 3, image_size, image_size))
+
+    def forward(self, image_one, image_two):
+        f = self.online_encoder
+        z1, z2 = f(image_one), f(image_two)
+        if self.stop_gradient:
+            loss = D(z1, z2)
+        else:
+            loss = D_NO_SG(z1, z2)
+        return loss
+
+
+# ------------- SimSiam Model -----------------
+
+
+class SimSiamModel(BaseModel):
+    def __init__(
+            self,
+            net=ResNet18(),
+            image_size=32,
+            projection_size=2048,
+            projection_hidden_size=4096,
+            stop_gradient=True,
+    ):
+        super().__init__()
+
+        self.online_encoder = net
+        self.online_encoder.fc = MLP(net.feature_dim, projection_size, projection_hidden_size)  # projector
+
+        self.online_predictor = MLP(
+            projection_size, projection_size, projection_hidden_size
+        )
+
+        self.stop_gradient = stop_gradient
+
+        # send a mock image tensor to instantiate singleton parameters
+        self.forward(torch.randn(2, 3, image_size, image_size), torch.randn(2, 3, image_size, image_size))
+
+    def forward(self, image_one, image_two):
+        f, h = self.online_encoder, self.online_predictor
+        z1, z2 = f(image_one), f(image_two)
+        p1, p2 = h(z1), h(z2)
+        if self.stop_gradient:
+            loss = D(p1, z2) / 2 + D(p2, z1) / 2
+        else:
+            loss = D_NO_SG(p1, z2) / 2 + D_NO_SG(p2, z1) / 2
+
+        return loss
+
+
+class MLP(nn.Module):
+    def __init__(self, dim, projection_size, hidden_size=4096, num_layer=TwoLayer):
+        super().__init__()
+        self.in_features = dim
+        if num_layer == OneLayer:
+            self.net = nn.Sequential(
+                nn.Linear(dim, projection_size),
+            )
+        elif num_layer == TwoLayer:
+            self.net = nn.Sequential(
+                nn.Linear(dim, hidden_size),
+                nn.BatchNorm1d(hidden_size),
+                nn.ReLU(inplace=True),
+                nn.Linear(hidden_size, projection_size),
+            )
+        else:
+            raise NotImplementedError(f"Not defined MLP: {num_layer}")
+
+    def forward(self, x):
+        return self.net(x)
+
+
+def D(p, z, version='simplified'):  # negative cosine similarity
+    if version == 'original':
+        z = z.detach()  # stop gradient
+        p = F.normalize(p, dim=1)  # l2-normalize
+        z = F.normalize(z, dim=1)  # l2-normalize
+        return -(p * z).sum(dim=1).mean()
+
+    elif version == 'simplified':  # equivalent to the original version, but faster
+        return - F.cosine_similarity(p, z.detach(), dim=-1).mean()
+    else:
+        raise Exception
+
+
+def D_NO_SG(p, z, version='simplified'):  # negative cosine similarity without stop gradient
+    if version == 'original':
+        p = F.normalize(p, dim=1)  # l2-normalize
+        z = F.normalize(z, dim=1)  # l2-normalize
+        return -(p * z).sum(dim=1).mean()
+
+    elif version == 'simplified':  # equivalent to the original version, but faster
+        return - F.cosine_similarity(p, z, dim=-1).mean()
+    else:
+        raise Exception
+
+
+# ------------- BYOL Model -----------------
+
+
+class BYOLModel(BaseModel):
+    def __init__(
+            self,
+            net=ResNet18(),
+            image_size=32,
+            projection_size=2048,
+            projection_hidden_size=4096,
+            moving_average_decay=0.99,
+            stop_gradient=True,
+            has_predictor=True,
+            predictor_network=TwoLayer,
+    ):
+        super().__init__()
+
+        self.online_encoder = net
+        if not hasattr(net, 'feature_dim'):
+            feature_dim = list(net.children())[-1].in_features
+        else:
+            feature_dim = net.feature_dim
+        self.online_encoder.fc = MLP(feature_dim, projection_size, projection_hidden_size)  # projector
+
+        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size, predictor_network)
+        self.target_encoder = None
+        self.target_ema_updater = EMA(moving_average_decay)
+
+        self.stop_gradient = stop_gradient
+        self.has_predictor = has_predictor
+
+        # debug purpose
+        # self.forward(torch.randn(2, 3, image_size, image_size), torch.randn(2, 3, image_size, image_size))
+        # self.reset_moving_average()
+
+    def _get_target_encoder(self):
+        target_encoder = copy.deepcopy(self.online_encoder)
+        return target_encoder
+
+    def reset_moving_average(self):
+        del self.target_encoder
+        self.target_encoder = None
+
+    def update_moving_average(self):
+        assert (
+                self.target_encoder is not None
+        ), "target encoder has not been created yet"
+        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
+
+    def forward(self, image_one, image_two):
+        online_pred_one = self.online_encoder(image_one)
+        online_pred_two = self.online_encoder(image_two)
+
+        if self.has_predictor:
+            online_pred_one = self.online_predictor(online_pred_one)
+            online_pred_two = self.online_predictor(online_pred_two)
+
+        if self.stop_gradient:
+            with torch.no_grad():
+                if self.target_encoder is None:
+                    self.target_encoder = self._get_target_encoder()
+                target_proj_one = self.target_encoder(image_one)
+                target_proj_two = self.target_encoder(image_two)
+
+                target_proj_one = target_proj_one.detach()
+                target_proj_two = target_proj_two.detach()
+
+        else:
+            if self.target_encoder is None:
+                self.target_encoder = self._get_target_encoder()
+            target_proj_one = self.target_encoder(image_one)
+            target_proj_two = self.target_encoder(image_two)
+
+        loss_one = byol_loss_fn(online_pred_one, target_proj_two)
+        loss_two = byol_loss_fn(online_pred_two, target_proj_one)
+        loss = loss_one + loss_two
+
+        return loss.mean()
+
+
+class EMA:
+    def __init__(self, beta):
+        super().__init__()
+        self.beta = beta
+
+    def update_average(self, old, new):
+        if old is None:
+            return new
+        return old * self.beta + (1 - self.beta) * new
+
+
+def update_moving_average(ema_updater, ma_model, current_model):
+    for current_params, ma_params in zip(
+            current_model.parameters(), ma_model.parameters()
+    ):
+        old_weight, up_weight = ma_params.data, current_params.data
+        ma_params.data = ema_updater.update_average(old_weight, up_weight)
+
+
+def byol_loss_fn(x, y):
+    x = F.normalize(x, dim=-1, p=2)
+    y = F.normalize(y, dim=-1, p=2)
+    return 2 - 2 * (x * y).sum(dim=-1)
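+# Note: for l2-normalized x and y, 2 - 2 * (x * y).sum(dim=-1) equals
+# 2 * (1 - cosine_similarity(x, y)), so byol_loss_fn is an affine variant of
+# the negative cosine similarity D defined above.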
+
+
+# ------------- MoCo Model -----------------
+
+
+class MoCoModel(BaseModel):
+    def __init__(self, net=ResNet18, dim=128, K=4096, m=0.99, T=0.1, bn_splits=8, symmetric=True, mlp=False):
+        super().__init__()
+
+        self.K = K
+        self.m = m
+        self.T = T
+        self.symmetric = symmetric
+
+        # create the encoders
+        self.encoder_q = net(num_classes=dim)
+        self.encoder_k = net(num_classes=dim)
+
+        if mlp:
+            feature_dim = self.encoder_q.feature_dim
+            self.encoder_q.fc = MLP(feature_dim, dim, feature_dim)
+            self.encoder_k.fc = MLP(feature_dim, dim, feature_dim)
+
+        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
+            param_k.data.copy_(param_q.data)  # initialize
+            param_k.requires_grad = False  # not update by gradient
+
+        # create the queue
+        self.register_buffer("queue", torch.randn(dim, K))
+        self.queue = nn.functional.normalize(self.queue, dim=0)
+
+        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
+
+    @torch.no_grad()
+    def reset_key_encoder(self):
+        """
+        Reset the key encoder to a copy of the query encoder
+        """
+        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
+            param_k.data.copy_(param_q.data)  # initialize
+            param_k.requires_grad = False  # not update by gradient
+
+    @torch.no_grad()
+    def _momentum_update_key_encoder(self):
+        """
+        Momentum update of the key encoder
+        """
+        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
+            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
+
+    @torch.no_grad()
+    def _dequeue_and_enqueue(self, keys):
+        batch_size = keys.shape[0]
+
+        ptr = int(self.queue_ptr)
+        assert self.K % batch_size == 0  # for simplicity
+
+        # replace the keys at ptr (dequeue and enqueue)
+        self.queue[:, ptr:ptr + batch_size] = keys.t()  # transpose
+        ptr = (ptr + batch_size) % self.K  # move pointer
+
+        self.queue_ptr[0] = ptr
+
+    @torch.no_grad()
+    def _batch_shuffle_single_gpu(self, x, device):
+        """
+        Batch shuffle, for making use of BatchNorm.
+        """
+        # random shuffle index
+        idx_shuffle = torch.randperm(x.shape[0]).to(device)
+
+        # index for restoring
+        idx_unshuffle = torch.argsort(idx_shuffle)
+
+        return x[idx_shuffle], idx_unshuffle
+
+    @torch.no_grad()
+    def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
+        """
+        Undo batch shuffle.
+        """
+        return x[idx_unshuffle]
+
+    def contrastive_loss(self, im_q, im_k, device):
+        # compute query features
+        q = self.encoder_q(im_q)  # queries: NxC
+        q = nn.functional.normalize(q, dim=1)  # l2-normalize queries
+
+        # compute key features
+        with torch.no_grad():  # no gradient to keys
+            # shuffle for making use of BN
+            im_k_, idx_unshuffle = self._batch_shuffle_single_gpu(im_k, device)
+
+            k = self.encoder_k(im_k_)  # keys: NxC
+            k = nn.functional.normalize(k, dim=1)  # l2-normalize keys
+
+            # undo shuffle
+            k = self._batch_unshuffle_single_gpu(k, idx_unshuffle)
+
+        # compute logits
+        # Einstein sum is more intuitive
+        # positive logits: Nx1
+        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
+        # negative logits: NxK
+        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
+
+        # logits: Nx(1+K)
+        logits = torch.cat([l_pos, l_neg], dim=1)
+
+        # apply temperature
+        logits /= self.T
+
+        # labels: the positive key is always at index 0 of the logits
+        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)
+
+        loss = nn.CrossEntropyLoss().to(device)(logits, labels)
+
+        return loss, q, k
+
+    def forward(self, im1, im2, device):
+        """
+        Input:
+            im1: a batch of query images
+            im2: a batch of key images
+        Output:
+            loss
+        """
+
+        # update the key encoder
+        with torch.no_grad():  # no gradient to keys
+            self._momentum_update_key_encoder()
+
+        # compute loss
+        if self.symmetric:  # symmetric loss: compute in both view orders
+            loss_12, q1, k2 = self.contrastive_loss(im1, im2, device)
+            loss_21, q2, k1 = self.contrastive_loss(im2, im1, device)
+            loss = loss_12 + loss_21
+            k = torch.cat([k1, k2], dim=0)
+        else:  # asymmetric loss
+            loss, q, k = self.contrastive_loss(im1, im2, device)
+
+        self._dequeue_and_enqueue(k)
+
+        return loss
+
+
+# ------------- SimCLR Model -----------------
+
+
+class SimCLRModel(BaseModel):
+    def __init__(self, net=ResNet18(), image_size=32, projection_size=2048, projection_hidden_size=4096):
+        super().__init__()
+
+        self.online_encoder = net
+        self.online_encoder.fc = MLP(net.feature_dim, projection_size, projection_hidden_size)  # projector
+
+    def forward(self, image):
+        return self.online_encoder(image)
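
A minimal training-step sketch for the BYOL model defined in this file — hedged: it assumes `model.py` is importable as `model`, and the optimizer settings and random tensors (stand-ins for two augmented views) are illustrative, not part of this commit:

```python
# Sketch only: BYOLModel comes from the model.py added above; the rest is
# standard PyTorch with placeholder data.
import torch
from model import BYOLModel

net = BYOLModel(image_size=32)  # ResNet18 backbone + projector + predictor
optimizer = torch.optim.SGD(net.parameters(), lr=0.03, momentum=0.9)

view_one = torch.randn(8, 3, 32, 32)  # stand-ins for two augmented views
view_two = torch.randn(8, 3, 32, 32)

loss = net(view_one, view_two)   # symmetric BYOL loss over both view orders
optimizer.zero_grad()
loss.backward()
optimizer.step()
net.update_moving_average()      # EMA update of the target encoder
```

The target encoder is created lazily on the first forward pass and is only ever updated through the EMA, never by gradients, which matches the BYOL design.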

+ 175 - 0
applications/fedssl/semi_supervised_evaluation.py

@@ -0,0 +1,175 @@
+import argparse
+from collections import defaultdict
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from easyfl.datasets.data import CIFAR100
+from eval_dataset import get_semi_supervised_data_loaders
+from model import get_encoder_network
+
+
+def test_whole(resnet, logreg, device, test_loader, model_path):
+    print("### Calculating final testing performance ###")
+    resnet.eval()
+    logreg.eval()
+    metrics = defaultdict(list)
+    for step, (h, y) in enumerate(test_loader):
+        h = h.to(device)
+        y = y.to(device)
+        with torch.no_grad():
+            outputs = logreg(resnet(h))
+
+        # calculate accuracy and save metrics
+        accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0)
+        metrics["Accuracy/test"].append(accuracy)
+
+    print(f"Final test performance: " + "\t".join([f"{k}: {np.array(v).mean()}" for k, v in metrics.items()]))
+    return np.array(metrics["Accuracy/test"]).mean()
+
+
+def finetune_internal(model, epochs, label_loader, test_loader, num_class, device, lr=3e-3):
+    model = model.to(device)
+    num_features = model.feature_dim
+
+    n_classes = num_class  # e.g. CIFAR-10 has 10 classes
+
+    # fine-tune model
+    logreg = nn.Sequential(nn.Linear(num_features, n_classes))
+    logreg = logreg.to(device)
+
+    # loss / optimizer
+    criterion = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(params=logreg.parameters(), lr=lr)
+
+    # Train fine-tuned model
+    model.train()
+    logreg.train()
+    for epoch in range(epochs):
+        metrics = defaultdict(list)
+        for step, (h, y) in enumerate(label_loader):
+            h = h.to(device)
+            y = y.to(device)
+            outputs = logreg(model(h))
+            loss = criterion(outputs, y)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            # calculate accuracy and save metrics
+            accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0)
+            metrics["Loss/train"].append(loss.item())
+            metrics["Accuracy/train"].append(accuracy)
+
+        if epoch % 100 == 0:
+            print("======epoch {}======".format(epoch))
+            test_whole(model, logreg, device, test_loader, "test_whole")
+    final_accuracy = test_whole(model, logreg, device, test_loader, "test_whole")
+    print(metrics)
+    return final_accuracy
+
+
+class MLP(nn.Module):
+    def __init__(self, dim, projection_size, hidden_size=4096):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_size),
+            nn.BatchNorm1d(hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden_size, projection_size),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dataset", default="cifar10", type=str, help="cifar10/cifar100.")
+    parser.add_argument('--model', default='simsiam', type=str, help='name of the network')
+    parser.add_argument("--encoder_network", default="resnet18", type=str, help="Encoder network architecture.")
+    parser.add_argument("--model_path", required=True, type=str, help="Path to pre-trained model (e.g. model-10.pt)")
+    parser.add_argument("--image_size", default=32, type=int, help="Image size")
+    parser.add_argument("--learning_rate", default=1e-3, type=float, help="Initial learning rate.")
+    parser.add_argument("--batch_size", default=128, type=int, help="Batch size for training.")
+    parser.add_argument("--num_epochs", default=100, type=int, help="Number of epochs to train for.")
+    parser.add_argument("--data_distribution", default="class", type=str, help="class/iid")
+    parser.add_argument("--label_ratio", default=0.01, type=float, help="ratio of labeled data for fine tune")
+    parser.add_argument('--class_per_client', default=2, type=int,
+                        help='for non-IID setting, number of class each client, based on CIFAR10')
+    parser.add_argument("--use_MLP", action='store_true',
+                        help="whether use MLP, if use, one hidden layer MLP, else, Linear Layer.")
+    parser.add_argument("--num_workers", default=8, type=int,
+                        help="Number of data loading workers (caution with nodes!)")
+    args = parser.parse_args()
+
+    print(args)
+
+    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+
+    print('==> Preparing data..')
+    class_per_client = args.class_per_client
+    n_classes = 10
+    if args.dataset == CIFAR100:
+        class_per_client = 10 * class_per_client  # CIFAR-100 has 10x the classes of CIFAR-10
+        n_classes = 100
+
+    train_loader, test_loader = get_semi_supervised_data_loaders(args.dataset,
+                                                                 args.data_distribution,
+                                                                 class_per_client,
+                                                                 args.label_ratio,
+                                                                 args.batch_size,
+                                                                 args.num_workers)
+
+    print('==> Building model..')
+    resnet = get_encoder_network(args.model, args.encoder_network)
+    resnet.load_state_dict(torch.load(args.model_path, map_location=device))
+    resnet = resnet.to(device)
+    num_features = list(resnet.children())[-1].in_features
+    resnet.fc = nn.Identity()
+
+    # fine-tune model
+    if args.use_MLP:
+        logreg = MLP(num_features, n_classes, 4096)
+        logreg = logreg.to(device)
+    else:
+        logreg = nn.Sequential(nn.Linear(num_features, n_classes))
+        logreg = logreg.to(device)
+
+    # loss / optimizer
+    criterion = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(params=logreg.parameters(), lr=args.learning_rate)
+
+    # Train fine-tuned model
+    logreg.train()
+    resnet.train()
+    accs = []
+    for epoch in range(args.num_epochs):
+        print("======epoch {}======".format(epoch))
+        metrics = defaultdict(list)
+        for step, (h, y) in enumerate(train_loader):
+            h = h.to(device)
+            y = y.to(device)
+
+            outputs = logreg(resnet(h))
+
+            loss = criterion(outputs, y)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            # calculate accuracy and save metrics
+            accuracy = (outputs.argmax(1) == y).sum().item() / y.size(0)
+            metrics["Loss/train"].append(loss.item())
+            metrics["Accuracy/train"].append(accuracy)
+
+        print(f"Epoch [{epoch}/{args.num_epochs}]: " + "\t".join(
+            [f"{k}: {np.array(v).mean()}" for k, v in metrics.items()]))
+
+        # evaluate after every epoch
+        acc = test_whole(resnet, logreg, device, test_loader, args.model_path)
+        if epoch <= 100:
+            accs.append(acc)
+    test_whole(resnet, logreg, device, test_loader, args.model_path)
+    print(args.model_path)
+    print(f"Best one for 100 epoch is {max(accs):.4f}")

+ 162 - 0
applications/fedssl/server.py

@@ -0,0 +1,162 @@
+import copy
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+from torchvision import datasets
+
+import model
+import utils
+from communication import TARGET
+from easyfl.datasets.data import CIFAR100
+from easyfl.distributed import reduce_models
+from easyfl.distributed.distributed import CPU
+from easyfl.server import strategies
+from easyfl.server.base import BaseServer, MODEL, DATA_SIZE
+from easyfl.tracking import metric
+from knn_monitor import knn_monitor
+
+logger = logging.getLogger(__name__)
+
+
+class FedSSLServer(BaseServer):
+    def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
+        super(FedSSLServer, self).__init__(conf, test_data, val_data, is_remote, local_port)
+        self.train_loader = None
+        self.test_loader = None
+
+    def aggregation(self):
+        if self.conf.client.auto_scaler == 'y' and self.conf.server.random_selection:
+            self._retain_weight_scaler()
+
+        uploaded_content = self.get_client_uploads()
+        models = list(uploaded_content[MODEL].values())
+        weights = list(uploaded_content[DATA_SIZE].values())
+
+        # Aggregate networks gradually with different components.
+        if self.conf.model in [model.Symmetric, model.SymmetricNoSG, model.SimSiam, model.SimSiamNoSG, model.BYOL,
+                               model.BYOLNoSG, model.BYOLNoPredictor, model.SimCLR]:
+            online_encoders = [m.online_encoder for m in models]
+            online_encoder = self._federated_averaging(online_encoders, weights)
+            self._model.online_encoder.load_state_dict(online_encoder.state_dict())
+
+        if self.conf.model in [model.SimSiam, model.SimSiamNoSG, model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
+            predictors = [m.online_predictor for m in models]
+            predictor = self._federated_averaging(predictors, weights)
+            self._model.online_predictor.load_state_dict(predictor.state_dict())
+
+        if self.conf.model in [model.BYOL, model.BYOLNoSG, model.BYOLNoPredictor]:
+            target_encoders = [m.target_encoder for m in models]
+            target_encoder = self._federated_averaging(target_encoders, weights)
+            self._model.target_encoder = copy.deepcopy(target_encoder)
+
+        if self.conf.model in [model.MoCo, model.MoCoV2]:
+            encoder_qs = [m.encoder_q for m in models]
+            encoder_q = self._federated_averaging(encoder_qs, weights)
+            self._model.encoder_q.load_state_dict(encoder_q.state_dict())
+
+            encoder_ks = [m.encoder_k for m in models]
+            encoder_k = self._federated_averaging(encoder_ks, weights)
+            self._model.encoder_k.load_state_dict(encoder_k.state_dict())
+
+    def _retain_weight_scaler(self):
+        self.client_id_to_index = {c.cid: i for i, c in enumerate(self._clients)}
+
+        client_index = self.client_id_to_index[self.grouped_clients[0].cid]
+        weight_scaler = self.grouped_clients[0].weight_scaler if self.grouped_clients[0].weight_scaler else 0
+        scaler = torch.tensor((client_index, weight_scaler)).to(self.conf.device)
+        scalers = [torch.zeros_like(scaler) for _ in self.selected_clients]
+        dist.barrier()
+        dist.all_gather(scalers, scaler)
+
+        logger.info(f"Synced scaler {scalers}")
+        for i, client in enumerate(self._clients):
+            for scaler in scalers:
+                scaler = scaler.cpu().numpy()
+                if self.client_id_to_index[client.cid] == int(scaler[0]) and not client.weight_scaler:
+                    self._clients[i].weight_scaler = scaler[1]
+
+    def _federated_averaging(self, models, weights):
+        fn_average = strategies.federated_averaging
+        fn_sum = strategies.weighted_sum
+        fn_reduce = reduce_models
+
+        if self.conf.is_distributed:
+            dist.barrier()
+            model_, sample_sum = fn_sum(models, weights)
+            fn_reduce(model_, torch.tensor(sample_sum).to(self.conf.device))
+        else:
+            model_ = fn_average(models, weights)
+        return model_
+
+    def test_in_server(self, device=CPU):
+        testing_model = self._get_testing_model()
+        testing_model.eval()
+        testing_model.to(device)
+
+        self._get_test_data()
+
+        with torch.no_grad():
+            accuracy = knn_monitor(testing_model, self.train_loader, self.test_loader)
+
+        test_results = {
+            metric.TEST_ACCURACY: float(accuracy),
+            metric.TEST_LOSS: 0,
+        }
+        return test_results
+
+    def _get_test_data(self):
+        transformation = self._load_transform()
+        if self.train_loader is None or self.test_loader is None:
+            if self.conf.data.dataset == CIFAR100:
+                data_path = "./data/cifar100"
+                train_dataset = datasets.CIFAR100(data_path, download=True, transform=transformation)
+                test_dataset = datasets.CIFAR100(data_path, train=False, download=True, transform=transformation)
+            else:
+                data_path = "./data/cifar10"
+                train_dataset = datasets.CIFAR10(data_path, download=True, transform=transformation)
+                test_dataset = datasets.CIFAR10(data_path, train=False, download=True, transform=transformation)
+
+            if self.train_loader is None:
+                self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, num_workers=8)
+
+            if self.test_loader is None:
+                self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=512, num_workers=8)
+
+    def _load_transform(self):
+        transformation = utils.get_transformation(self.conf.model)
+        return transformation().test_transform
+
+    def _get_testing_model(self, net=False):
+        if self.conf.model in [model.MoCo, model.MoCoV2]:
+            testing_model = self._model.encoder_q
+        elif self.conf.model in [model.SimSiam, model.SimSiamNoSG, model.Symmetric, model.SymmetricNoSG, model.SimCLR]:
+            testing_model = self._model.online_encoder
+        else:
+            # BYOL
+            if self.conf.client.aggregate_encoder == TARGET:
+                self.print_("Use aggregated target encoder for testing")
+                testing_model = self._model.target_encoder
+            else:
+                self.print_("Use aggregated online encoder for testing")
+                testing_model = self._model.online_encoder
+        return testing_model
+
+    def save_model(self):
+        if self._do_every(self.conf.server.save_model_every, self._current_round, self.conf.server.rounds) and self.is_primary_server():
+            save_path = self.conf.server.save_model_path
+            if save_path == "":
+                save_path = os.path.join(os.getcwd(), "saved_models", self.conf.task_id)
+            os.makedirs(save_path, exist_ok=True)
+            save_path = os.path.join(save_path,
+                                     "{}_global_model_r_{}.pth".format(self.conf.task_id, self._current_round))
+
+            torch.save(self._get_testing_model().cpu().state_dict(), save_path)
+            self.print_("Encoder model saved at {}".format(save_path))
+
+            if self.conf.server.save_predictor:
+                if self.conf.model in [model.SimSiam, model.BYOL]:
+                    save_path = save_path.replace("global_model", "predictor")
+                    torch.save(self._model.online_predictor.cpu().state_dict(), save_path)
+                    self.print_("Predictor model saved at {}".format(save_path))

+ 190 - 0
applications/fedssl/transform.py

@@ -0,0 +1,190 @@
+import numpy as np
+import torch
+import torchvision
+from torch import nn
+from torchvision import transforms
+
+
+class MoCoTransform:
+    """
+    A stochastic data augmentation module that transforms any given data example randomly
+    resulting in two correlated views of the same example,
+    denoted x̃i and x̃j, which we consider as a positive pair.
+    """
+
+    def __init__(self, size=32, gaussian=False):
+        self.train_transform = transforms.Compose([
+            transforms.ToPILImage(mode='RGB'),
+            transforms.RandomResizedCrop(size),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
+            transforms.RandomGrayscale(p=0.2),
+            transforms.ToTensor()])
+
+        self.test_transform = transforms.Compose([
+            torchvision.transforms.Resize(size=size),
+            transforms.ToTensor()]
+        )
+
+    def __call__(self, x):
+        return self.train_transform(x), self.train_transform(x)
+
+
+class SimSiamTransform:
+    """
+    A stochastic data augmentation module that transforms any given data example randomly
+    resulting in two correlated views of the same example,
+    denoted x̃i and x̃j, which we consider as a positive pair.
+    """
+
+    def __init__(self, size=32, gaussian=False):
+        s = 1
+        color_jitter = torchvision.transforms.ColorJitter(
+            0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
+        )
+        if gaussian:
+            self.train_transform = torchvision.transforms.Compose(
+                [
+                    torchvision.transforms.ToPILImage(mode='RGB'),
+                    torchvision.transforms.RandomResizedCrop(size=size),
+                    torchvision.transforms.RandomHorizontalFlip(),  # with 0.5 probability
+                    torchvision.transforms.RandomApply([color_jitter], p=0.8),
+                    torchvision.transforms.RandomGrayscale(p=0.2),
+                    GaussianBlur(kernel_size=int(0.1 * size)),
+                    torchvision.transforms.ToTensor(),
+                ]
+            )
+        else:
+            self.train_transform = torchvision.transforms.Compose(
+                [
+                    torchvision.transforms.ToPILImage(mode='RGB'),
+                    torchvision.transforms.RandomResizedCrop(size=size),
+                    torchvision.transforms.RandomHorizontalFlip(),  # with 0.5 probability
+                    torchvision.transforms.RandomApply([color_jitter], p=0.8),
+                    torchvision.transforms.RandomGrayscale(p=0.2),
+                    torchvision.transforms.ToTensor(),
+                ]
+            )
+
+        self.test_transform = torchvision.transforms.Compose(
+            [
+                torchvision.transforms.Resize(size=size),
+                torchvision.transforms.ToTensor(),
+            ]
+        )
+
+    def __call__(self, x):
+        return self.train_transform(x), self.train_transform(x)
+
+
+class SimCLRTransform:
+    """
+    A stochastic data augmentation module that transforms any given data example randomly
+    resulting in two correlated views of the same example,
+    denoted x̃i and x̃j, which we consider as a positive pair.
+    data_format is array or image
+    """
+
+    def __init__(self, size=32, gaussian=False, data_format="array"):
+        s = 1
+        color_jitter = torchvision.transforms.ColorJitter(
+            0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
+        )
+        if gaussian:
+            self.train_transform = torchvision.transforms.Compose(
+                [
+                    torchvision.transforms.ToPILImage(mode='RGB'),
+                    # torchvision.transforms.Resize(size=size),
+                    torchvision.transforms.RandomResizedCrop(size=size),
+                    torchvision.transforms.RandomHorizontalFlip(),  # with 0.5 probability
+                    torchvision.transforms.RandomApply([color_jitter], p=0.8),
+                    torchvision.transforms.RandomGrayscale(p=0.2),
+                    GaussianBlur(kernel_size=int(0.1 * size)),
+                    # RandomApply(torchvision.transforms.GaussianBlur((3, 3), (1.0, 2.0)), p=0.2),
+                    torchvision.transforms.ToTensor(),
+                ]
+            )
+        else:
+            if data_format == "array":
+                self.train_transform = torchvision.transforms.Compose(
+                    [
+                        torchvision.transforms.ToPILImage(mode='RGB'),
+                        # torchvision.transforms.Resize(size=size),
+                        torchvision.transforms.RandomResizedCrop(size=size),
+                        torchvision.transforms.RandomHorizontalFlip(),  # with 0.5 probability
+                        torchvision.transforms.RandomApply([color_jitter], p=0.8),
+                        torchvision.transforms.RandomGrayscale(p=0.2),
+                        torchvision.transforms.ToTensor(),
+                    ]
+                )
+            else:
+                self.train_transform = torchvision.transforms.Compose(
+                    [
+                        torchvision.transforms.RandomResizedCrop(size=size),
+                        torchvision.transforms.RandomHorizontalFlip(),  # with 0.5 probability
+                        torchvision.transforms.RandomApply([color_jitter], p=0.8),
+                        torchvision.transforms.RandomGrayscale(p=0.2),
+                        torchvision.transforms.ToTensor(),
+                    ]
+                )
+
+        self.test_transform = torchvision.transforms.Compose(
+            [
+                torchvision.transforms.Resize(size=size),
+                torchvision.transforms.ToTensor(),
+            ]
+        )
+
+        self.fine_tune_transform = torchvision.transforms.Compose(
+            [
+                torchvision.transforms.ToPILImage(mode='RGB'),
+                torchvision.transforms.Resize(size=size),
+                torchvision.transforms.ToTensor(),
+            ]
+        )
+
+    def __call__(self, x):
+        return self.train_transform(x), self.train_transform(x)
+
+
+class GaussianBlur(object):
+    """blur a single image on CPU"""
+
+    def __init__(self, kernel_size):
+        radius = kernel_size // 2
+        kernel_size = radius * 2 + 1  # force an odd kernel size
+        self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
+                                stride=1, padding=0, bias=False, groups=3)
+        self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
+                                stride=1, padding=0, bias=False, groups=3)
+        self.k = kernel_size
+        self.r = radius
+
+        self.blur = nn.Sequential(
+            nn.ReflectionPad2d(radius),
+            self.blur_h,
+            self.blur_v
+        )
+
+        self.pil_to_tensor = transforms.ToTensor()
+        self.tensor_to_pil = transforms.ToPILImage()
+
+    def __call__(self, img):
+        img = self.pil_to_tensor(img).unsqueeze(0)
+
+        sigma = np.random.uniform(0.1, 2.0)
+        x = np.arange(-self.r, self.r + 1)
+        x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
+        x = x / x.sum()
+        x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
+
+        self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
+        self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))
+
+        with torch.no_grad():
+            img = self.blur(img)
+            img = img.squeeze()
+
+        img = self.tensor_to_pil(img)
+
+        return img
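
A quick sketch of how the two-view transforms above are consumed — hedged: the array shape assumes a CIFAR-style HxWxC uint8 input, matching `data_format="array"`:

```python
# Sketch: two stochastic views from one CIFAR-style array image.
import numpy as np
from transform import SimCLRTransform

t = SimCLRTransform(size=32, gaussian=False, data_format="array")
image = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
view_one, view_two = t(image)  # each view is a 3x32x32 float tensor
assert view_one.shape == (3, 32, 32)
```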

+ 33 - 0
applications/fedssl/utils.py

@@ -0,0 +1,33 @@
+import torch
+
+import transform
+from model import SimSiam, MoCo
+
+
+def get_transformation(model):
+    if model == SimSiam:
+        transformation = transform.SimSiamTransform
+    elif model == MoCo:
+        transformation = transform.MoCoTransform
+    else:
+        transformation = transform.SimCLRTransform
+    return transformation
+
+
+def calculate_model_distance(m1, m2):
+    distance, count = 0, 0
+    d1, d2 = m1.state_dict(), m2.state_dict()
+    for name, param in m1.named_parameters():
+        if 'conv' in name and 'weight' in name:
+            distance += torch.dist(d1[name].detach().clone().view(1, -1), d2[name].detach().clone().view(1, -1), 2)
+            count += 1
+    return distance / count
+
+
+def normalize(arr):
+    maxx = max(arr)
+    minn = min(arr)
+    diff = maxx - minn
+    if diff == 0:
+        return arr
+    return [(x - minn) / diff for x in arr]
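
Since `calculate_model_distance` feeds FedEMA's divergence-aware scaling, here is a small hedged example of both helpers; the one-layer nets are hypothetical stand-ins for encoders (note the helper only counts parameters whose names contain both 'conv' and 'weight'):

```python
# Hypothetical usage: compare conv weights of two small stand-in encoders.
import copy
from collections import OrderedDict

import torch.nn as nn

from utils import calculate_model_distance, normalize

net_a = nn.Sequential(OrderedDict(conv=nn.Conv2d(3, 8, kernel_size=3)))
net_b = copy.deepcopy(net_a)
nn.init.normal_(net_b.conv.weight)                 # perturb the copy

distance = calculate_model_distance(net_a, net_b)  # mean L2 over conv weights
print(float(distance))
print(normalize([0.0, float(distance), 1.0]))      # min-max scale distances
```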

+ 1 - 2
docs/en/index.rst

@@ -1,4 +1,4 @@
-Welcome to MMDetection's documentation!
+Welcome to EasyFL's documentation!
 =======================================
 
 .. toctree::
@@ -9,7 +9,6 @@ Welcome to MMDetection's documentation!
 
 .. toctree::
    :maxdepth: 2
-   :caption: Get Started
 
    get_started.md
 

+ 2 - 2
docs/en/projects.md

@@ -6,8 +6,8 @@ We have been doing research on federated learning for several years and publishe
 
 We have released the following implementations of federated learning applications:
 
-- FedReID: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/fedreid) for "Performance Optimization for Federated Person Re-identification via Benchmark Analysis", _ACMMM'2020_.
-
+- FedReID: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/fedreid) for [Performance Optimization for Federated Person Re-identification via Benchmark Analysis](https://dl.acm.org/doi/10.1145/3394171.3413814) (_ACMMM'2020_).
+- FedSSL: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/fedssl) for two papers: [Divergence-aware Federated Self-Supervised Learning](https://openreview.net/forum?id=oVE1z8NlNe) (_ICLR'2022_) and [Collaborative Unsupervised Visual Representation Learning From Decentralized Data](https://openaccess.thecvf.com/content/ICCV2021/html/Zhuang_Collaborative_Unsupervised_Visual_Representation_Learning_From_Decentralized_Data_ICCV_2021_paper.html) (_ICCV'2021_).
 
 ## Papers
 

+ 1 - 1
easyfl/coordinator.py

@@ -43,7 +43,7 @@ class Coordinator(object):
         self._client_class = None
         self.tracker = None
 
-    def init(self, conf, init_all=False):
+    def init(self, conf, init_all=True):
         """Initialize coordinator
 
         Args: