Pārlūkot izejas kodu

[Feature] Federated Unsupervised Person Re-identification (#14)

Zhuang Weiming 1 gadu atpakaļ
vecāks
revīzija
20189eced1
30 mainītis faili ar 1643 papildinājumiem un 0 dzēšanām
  1. 1 0
      .gitignore
  2. 83 0
      applications/fedureid/README.md
  3. 0 0
      applications/fedureid/__init__.py
  4. 30 0
      applications/fedureid/config.yaml
  5. 35 0
      applications/fedureid/dataset.py
  6. 111 0
      applications/fedureid/evaluate.py
  7. BIN
      applications/fedureid/images/datasets.png
  8. BIN
      applications/fedureid/images/fedureid.png
  9. BIN
      applications/fedureid/images/results.png
  10. 240 0
      applications/fedureid/main.py
  11. 10 0
      applications/fedureid/reid/__init__.py
  12. 236 0
      applications/fedureid/reid/bottomup.py
  13. 10 0
      applications/fedureid/reid/evaluation_metrics/__init__.py
  14. 19 0
      applications/fedureid/reid/evaluation_metrics/classification.py
  15. 130 0
      applications/fedureid/reid/evaluation_metrics/ranking.py
  16. 113 0
      applications/fedureid/reid/evaluators.py
  17. 41 0
      applications/fedureid/reid/exclusive_loss.py
  18. 7 0
      applications/fedureid/reid/feature_extraction/__init__.py
  19. 30 0
      applications/fedureid/reid/feature_extraction/cnn.py
  20. 48 0
      applications/fedureid/reid/models/__init__.py
  21. 61 0
      applications/fedureid/reid/models/end2end.py
  22. 56 0
      applications/fedureid/reid/models/model.py
  23. 113 0
      applications/fedureid/reid/models/resnet.py
  24. 87 0
      applications/fedureid/reid/trainers.py
  25. 21 0
      applications/fedureid/reid/utils/__init__.py
  26. 23 0
      applications/fedureid/reid/utils/meters.py
  27. 11 0
      applications/fedureid/reid/utils/osutils.py
  28. 60 0
      applications/fedureid/reid/utils/serialization.py
  29. 1 0
      applications/fedureid/reid/utils/transform/__init__.py
  30. 66 0
      applications/fedureid/reid/utils/transform/transforms.py

+ 1 - 0
.gitignore

@@ -6,6 +6,7 @@ __pycache__
 *.xls
 *.xlsx
 *.egg-info
+.idea/
 docs/build
 dist/
 data/

+ 83 - 0
applications/fedureid/README.md

@@ -0,0 +1,83 @@
+# Federated Unsupervised Person Re-identification (FedUReID)
+
+This repository implements federated unsupervised person re-identification (FedUReID). FedUReID learnS person ReID models without any labels while preserving privacy. FedUReID enables in-situ model training on edges with unlabeled data. A cloud server aggregates models from edges instead of centralizing raw data to preserve data privacy. Extensive experiments on eight person ReID datasets demonstrate that FedUReID not only achieves higher accuracy but also reduces computation cost by 29%.
+
+The paper is accepted in ACMMM 2021 - **[Joint optimization in edge-cloud continuum for federated unsupervised person re-identification](https://arxiv.org/abs/2108.06493)**
+
+System architecture and workflow.
+
+<img src="images/fedureid.png" width="700">
+
+## Prerequisite
+
+It requires the following Python libraries:
+```
+torch
+torchvision
+easyfl
+scikit_learn==0.22.2.post1
+```
+
+Please refer to the [documentation](https://easyfl.readthedocs.io/en/latest/get_started.html#installation) to install `easyfl`.
+
+## Datasets
+
+**We use 8 popular ReID datasets for the benchmark.**
+<img src="images/datasets.png" width="700">
+
+>
+> Please [email us](weiming001@e.ntu.edu.sg) to request for the datasets with:
+> 1. A short self-introduction.
+> 2. The purposes of using these datasets.
+>
+> *⚠️ Further distribution of the datasets are prohibited.*
+
+## Run the experiments
+
+Put the processed datasets in `data_dir` and run the experiments with the following scripts.
+
+```
+python main.py --data_dir ${data_dir} --dataset Market Duke cuhk03 cuhk01 prid viper 3dpes ilids
+```
+
+You can refer to the `main.py` to run experiments with more options and configurations.
+
+> Note: you can run experiments with multiple GPUs by setting `--gpu`. The default implementation supports running with multiple GPUs in a _slurm cluster_. You may need to modify `main.py` to use `multiprocess`.
+
+Currently, we provide the implementation of the baseline and personalized epoch method described in the paper. 
+
+## Results
+
+<img src="images/results.png" width="700">
+
+
+## Citation
+```
+@inproceedings{zhuang2021fedureid,
+  title={Joint optimization in edge-cloud continuum for federated unsupervised person re-identification},
+  author={Zhuang, Weiming and Wen, Yonggang and Zhang, Shuai},
+  booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
+  pages={433--441},
+  year={2021}
+}
+
+@inproceedings{zhuang2020fedreid,
+  title={Performance Optimization of Federated Person Re-identification via Benchmark Analysis},
+  author={Zhuang, Weiming and Wen, Yonggang and Zhang, Xuesen and Gan, Xin and Yin, Daiying and Zhou, Dongzhan and Zhang, Shuai and Yi, Shuai},
+  booktitle={Proceedings of the 28th ACM International Conference on Multimedia},
+  pages={955--963},
+  year={2020}
+}
+
+@article{zhuang2023fedreid,
+  title={Optimizing performance of federated person re-identification: Benchmarking and analysis},
+  author={Zhuang, Weiming and Gan, Xin and Wen, Yonggang and Zhang, Shuai},
+  journal={ACM Transactions on Multimedia Computing, Communications and Applications},
+  volume={19},
+  number={1s},
+  pages={1--18},
+  year={2023},
+  publisher={ACM New York, NY}
+}
+```
+

+ 0 - 0
applications/fedureid/__init__.py


+ 30 - 0
applications/fedureid/config.yaml

@@ -0,0 +1,30 @@
+tid: "fedureid"
+server:
+  test_all: True
+  clients_per_round: 8
+  test_every: 1
+  rounds: 20
+  batch_size: 32
+  random_selection: False
+resource_heterogeneous:
+  grouping_strategy: "none"
+client:
+  local_epoch: 1
+  track: False
+  batch_size: 32
+  optimizer:
+    type: "SGD"
+    lr: 0.05
+    momentum: 0.9
+  buc:
+    batch_size: 16
+    eval_batch_size: 64
+    size_penalty: 0.003
+    merge_percent: 0.05
+    feature_size: 2048
+    upload_frequency: 1
+    global_evaluation: False
+    initial_epochs: 20
+    local_epochs: 1
+test_mode: "test_in_client"
+test_method: "average"

+ 35 - 0
applications/fedureid/dataset.py

@@ -0,0 +1,35 @@
+import os
+
+from reid.utils.transform.transforms import TRANSFORM_TRAIN_LIST, TRANSFORM_VAL_LIST
+from easyfl.datasets import FederatedImageDataset
+
+
+def prepare_train_data(db_names, data_dir):
+    client_ids = []
+    roots = []
+    for d in db_names:
+        client_ids.append(d)
+        data_path = os.path.join(data_dir, d, 'pytorch')
+        roots.append(os.path.join(data_path, 'train_all'))
+    data = FederatedImageDataset(root=roots,
+                                 simulated=True,
+                                 do_simulate=False,
+                                 transform=TRANSFORM_TRAIN_LIST,
+                                 client_ids=client_ids)
+    return data
+
+
+def prepare_test_data(db_names, data_dir):
+    roots = []
+    client_ids = []
+    for d in db_names:
+        test_gallery = os.path.join(data_dir, d, 'pytorch', 'gallery')
+        test_query = os.path.join(data_dir, d, 'pytorch', 'query')
+        roots.extend([test_gallery, test_query])
+        client_ids.extend(["{}_{}".format(d, "gallery"), "{}_{}".format(d, "query")])
+    data = FederatedImageDataset(root=roots,
+                                 simulated=True,
+                                 do_simulate=False,
+                                 transform=TRANSFORM_VAL_LIST,
+                                 client_ids=client_ids)
+    return data

+ 111 - 0
applications/fedureid/evaluate.py

@@ -0,0 +1,111 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import scipy.io
+
+
+def extract_feature(model, dataloaders, device, ms=[1]):
+    features = torch.FloatTensor()
+    model = model.to(device)
+    for data in dataloaders:
+        img, label = data
+        n, c, h, w = img.size()
+        ff = torch.FloatTensor(n, 2048).zero_().to(device)
+        for i in range(2):
+            if i == 1:
+                img = fliplr(img)
+            input_img = Variable(img.to(device))
+            for scale in ms:
+                if scale != 1:
+                    # bicubic is only  available in pytorch>= 1.1
+                    input_img = nn.functional.interpolate(input_img, scale_factor=scale, mode='bicubic',
+                                                          align_corners=False)
+                _, outputs = model(input_img)
+                ff += outputs
+        # # norm feature
+        fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
+        ff = ff.div(fnorm.expand_as(ff))
+        features = torch.cat((features, ff.data.cpu()), 0)
+    return features
+
+
+def test_evaluate(path, dataset, device):
+    result = scipy.io.loadmat(path)
+    query_feature = torch.FloatTensor(result['query_f'])
+    query_cam = result['query_cam'][0]
+    query_label = result['query_label'][0]
+    gallery_feature = torch.FloatTensor(result['gallery_f'])
+    gallery_cam = result['gallery_cam'][0]
+    gallery_label = result['gallery_label'][0]
+    query_feature = query_feature.to(device)
+    gallery_feature = gallery_feature.to(device)
+    CMC = torch.IntTensor(len(gallery_label)).zero_()
+    ap = 0.0
+    for i in range(len(query_label)):
+        ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], query_cam[i], gallery_feature, gallery_label,
+                                   gallery_cam)
+        if CMC_tmp[0] == -1:
+            continue
+        CMC = CMC + CMC_tmp
+        ap += ap_tmp
+    CMC = CMC.float()
+    CMC = CMC / len(query_label)  # average CMC
+    return CMC[0], CMC[4], CMC[9], ap / len(query_label)
+
+
+def evaluate(qf, ql, qc, gf, gl, gc):
+    query = qf.view(-1, 1)
+    score = torch.mm(gf, query)
+    score = score.squeeze(1).cpu()
+    score = score.numpy()
+    # predict index
+    index = np.argsort(score)  # from small to large
+    index = index[::-1]
+    # good index
+    query_index = np.argwhere(gl == ql)
+    camera_index = np.argwhere(gc == qc)
+
+    good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
+    junk_index1 = np.argwhere(gl == -1)
+    junk_index2 = np.intersect1d(query_index, camera_index)
+    junk_index = np.append(junk_index2, junk_index1)  # .flatten())
+
+    CMC_tmp = compute_mAP(index, good_index, junk_index)
+    return CMC_tmp
+
+
+def compute_mAP(index, good_index, junk_index):
+    ap = 0
+    cmc = torch.IntTensor(len(index)).zero_()
+    if good_index.size == 0:  # if empty
+        cmc[0] = -1
+        return ap, cmc
+
+    # remove junk_index
+    mask = np.in1d(index, junk_index, invert=True)
+    index = index[mask]
+
+    # find good_index index
+    ngood = len(good_index)
+    mask = np.in1d(index, good_index)
+    rows_good = np.argwhere(mask == True)
+    rows_good = rows_good.flatten()
+
+    cmc[rows_good[0]:] = 1
+    for i in range(ngood):
+        d_recall = 1.0 / ngood
+        precision = (i + 1) * 1.0 / (rows_good[i] + 1)
+        if rows_good[i] != 0:
+            old_precision = i * 1.0 / rows_good[i]
+        else:
+            old_precision = 1.0
+        ap = ap + d_recall * (old_precision + precision) / 2
+
+    return ap, cmc
+
+
+def fliplr(img):
+    inv_idx = torch.arange(img.size(3) - 1, -1, -1).long()  # N x C x H x W
+    img_flip = img.index_select(3, inv_idx)
+    return img_flip

BIN
applications/fedureid/images/datasets.png


BIN
applications/fedureid/images/fedureid.png


BIN
applications/fedureid/images/results.png


+ 240 - 0
applications/fedureid/main.py

@@ -0,0 +1,240 @@
+import argparse
+import os
+import time
+
+import torch._utils
+
+import easyfl
+from dataset import prepare_train_data, prepare_test_data
+from reid.bottomup import *
+from reid.models.model import BUCModel
+from easyfl.client.base import BaseClient
+from easyfl.distributed import slurm
+from easyfl.distributed.distributed import CPU
+from easyfl.pb import common_pb2 as common_pb
+from easyfl.pb import server_service_pb2 as server_pb
+from easyfl.protocol import codec
+from easyfl.tracking import metric
+
+logger = logging.getLogger(__name__)
+
+
+LOCAL_TEST = "local_test"
+GLOBAL_TEST = "global_test"
+
+RELABEL_LOCAL = "local"
+RELABEL_GLOBAL = "global"
+
+
+class FedUReIDClient(BaseClient):
+    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0,
+                 is_remote=False, local_port=23000, server_addr="localhost:22999", tracker_addr="localhost:12666"):
+        super(FedUReIDClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time,
+                                                 is_remote, local_port, server_addr, tracker_addr)
+        logger.info(conf)
+        self.conf = conf
+        self.current_step = -1
+
+        self._local_model = None  # for caching local model in testing
+        self.gallery_cam = None
+        self.gallery_label = None
+        self.query_cam = None
+        self.query_label = None
+        self.test_gallery_loader = None
+        self.test_query_loader = None
+
+        self.train_data = train_data
+        self.test_data = test_data
+
+        self.labeled_ys = self.train_data.data[self.cid]['y']
+        self.unlabeled_ys = [i for i in range(len(self.labeled_ys))]
+        # initialize unlabeled transform
+        self.train_data.data[self.cid]['y'] = self.unlabeled_ys
+
+        num_classes = len(np.unique(np.array(self.unlabeled_ys)))
+
+        merge_percent = conf.buc.merge_percent
+        self.nums_to_merge = int(num_classes * conf.buc.merge_percent)
+        self.steps = int(1 / merge_percent) - 1
+
+        self.buc = BottomUp(cid=self.cid,
+                            model=self.model,  # model is None
+                            batch_size=conf.buc.batch_size,
+                            eval_batch_size=conf.buc.eval_batch_size,
+                            num_classes=num_classes,
+                            train_data=self.train_data,
+                            test_data=self.test_data,
+                            device=device,
+                            initial_epochs=conf.buc.initial_epochs,
+                            local_epochs=conf.buc.local_epochs,
+                            embedding_feature_size=conf.buc.feature_size,
+                            seed=conf.seed)
+
+    def train(self, conf, device=CPU):
+        logger.info("--------- training -------- cid: {}, on {}".format(self.cid, device))
+
+        start_time = time.time()
+
+        step_to_upload = self.current_step + conf.buc.upload_frequency
+        total_steps = self.steps
+
+        while self.current_step < step_to_upload:
+            self.current_step += 1
+            logger.info("current step: {}".format(self.current_step))
+            logger.info("training transform amount: {}".format(len(self.unlabeled_ys)))
+
+            if conf.buc.relabel == RELABEL_GLOBAL:
+                if self.current_step > 0:
+                    logger.info("-------- bottom-up clustering: relabel train transform with global aggregated model")
+                    self.unlabeled_ys = self.buc.relabel_train_data(device,
+                                                                    self.unlabeled_ys,
+                                                                    self.labeled_ys,
+                                                                    self.nums_to_merge,
+                                                                    size_penalty=conf.buc.size_penalty)
+                    self.train_data.data[self.cid]['y'] = self.unlabeled_ys
+
+            self.buc.set_model(self.model, self.current_step)
+            model = self.buc.train(self.current_step, conf.buc.dynamic_epoch)
+            self._local_model = copy.deepcopy(self.model)
+            self.model.load_state_dict(model.state_dict())
+
+            rank1, rank5, rank10, mAP = self.buc.evaluate(self.cid)
+            logger.info("Local test {}, step {}, mAP: {:4.2%}, Rank@1: {:4.2%}, Rank@5: {:4.2%}, Rank@10: {:4.2%}"
+                        .format(self.cid, self.current_step, mAP, rank1, rank5, rank10))
+
+            if self.current_step == total_steps:
+                logger.info("Total steps just reached, force global update")
+                break
+
+            # get new train transform for the next iteration
+            if self.current_step > total_steps:
+                logger.info("Total steps reached, skip relabeling")
+                continue
+
+            if conf.buc.relabel == RELABEL_LOCAL:
+                logger.info("-------- bottom-up clustering: relabel train transform with local trained model")
+                self.unlabeled_ys = self.buc.relabel_train_data(device,
+                                                                self.unlabeled_ys,
+                                                                self.labeled_ys,
+                                                                self.nums_to_merge,
+                                                                size_penalty=conf.buc.size_penalty)
+
+                self.train_data.data[self.cid]['y'] = self.unlabeled_ys
+
+        self.save_model(LOCAL_TEST, device)
+        self.current_round_time = time.time() - start_time
+        logger.info("Local training time {}".format(self.current_round_time))
+        self.track(metric.TRAIN_TIME, self.current_round_time)
+
+        self.model = self.model.to(device)
+
+    def test(self, conf, device=CPU):
+        rank1 = 0
+        if conf.buc.global_evaluation:
+            logger.info("-------- evaluation -------- {}: {}".format(GLOBAL_TEST, self.cid))
+            rank1, rank5, rank10, mAP = self.buc.evaluate(self.cid, self.model)
+            logger.info("Global test {}, step {}, mAP: {:4.2%}, Rank@1: {:4.2%}, Rank@5: {:4.2%}, Rank@10: {:4.2%}"
+                        .format(self.cid, self.current_step, mAP, rank1, rank5, rank10))
+            self.save_model(GLOBAL_TEST, device)
+
+        self._upload_holder = server_pb.UploadContent(
+            data=codec.marshal(server_pb.Performance(accuracy=rank1, loss=0)),  # loss not applicable
+            type=common_pb.DATA_TYPE_PERFORMANCE,
+            data_size=len(self.train_data.data[self.cid]['x']),
+        )
+
+    def save_model(self, typ=LOCAL_TEST, device=CPU):
+        path = os.path.join(os.getcwd(), "saved_models")
+        if not os.path.exists(path):
+            os.makedirs(path)
+        if typ == GLOBAL_TEST:
+            save_path = os.path.join(path, "{}_global_model_{}.pth".format(self.current_step, time.time()))
+            if device == 0 or device == CPU:
+                torch.save(self.model.cpu().state_dict(), save_path)
+        else:
+            save_path = os.path.join(path, "{}_{}_local_model_{}.pth".format(self.current_step, self.cid, time.time()))
+            torch.save(self.model.cpu().state_dict(), save_path)
+        logger.info("save model {}".format(save_path))
+
+
+def get_merge_percent(num_images, num_identities, rounds):
+    nums_to_merge = int((num_images - num_identities) / rounds)
+    merge_percent = nums_to_merge / num_images
+    return merge_percent, nums_to_merge
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('--data_dir', type=str, metavar='PATH', default="datasets/fedreid_data")
+    parser.add_argument("--datasets", nargs="+", default=["ilids"])
+    parser.add_argument('--batch_size', type=int, default=16, help='training batch size')
+    parser.add_argument('--upload_frequency', type=int, default=1, help='frequency of upload for aggregation')
+    parser.add_argument('--merge_percent', type=float, default=0.05, help='merge percentage of each step')
+    parser.add_argument('--steps', type=int, default=0, help='steps to decide merge percent')
+    parser.add_argument('--initial_epochs', type=int, default=20, help='local epochs for first step/round')
+    parser.add_argument('--local_epochs', type=int, default=1, help='local epochs after first step/round')
+    parser.add_argument('--dynamic_epoch', default=False, action='store_true', help='dynamic local epochs')
+    parser.add_argument('--relabel', type=str, default='local', help='use "local" or "global" model to relabel')
+    parser.add_argument('--merge', default=False, action='store_true')
+    args = parser.parse_args()
+
+    print("args:", args)
+
+    # MAIN
+    train_data = prepare_train_data(args.datasets, args.data_dir)
+    test_data = prepare_test_data(args.datasets, args.data_dir)
+    easyfl.register_dataset(train_data, test_data)
+    easyfl.register_model(BUCModel)
+    easyfl.register_client(FedUReIDClient)
+
+    # configurations
+    global_evaluation = False
+    if args.steps:
+        rounds = args.steps
+    else:
+        rounds = int(1 / args.merge_percent)
+
+    config = {
+        "server": {
+            "rounds": rounds,
+        },
+        "client": {
+            "buc": {
+                "global_evaluation": global_evaluation,
+                "relabel": args.relabel,
+                "initial_epochs": args.initial_epochs,
+                "local_epochs": args.local_epochs,
+                "dynamic_epoch": args.dynamic_epoch,
+                "batch_size": args.batch_size,
+                "upload_frequency": args.upload_frequency,
+                "merge_percent": args.merge_percent,
+                "steps": args.steps,
+            },
+            "datasets": args.datasets,
+        }
+    }
+
+    # For distributed training over multiple GPUs only
+    try:
+        rank, local_rank, world_size, host_addr = slurm.setup()
+        global_evaluation = True if world_size > 1 else False
+        config["client"]["buc"]["global_evaluation"] = global_evaluation
+        distributed_config = {
+            "gpu": world_size,
+            "distributed": {
+                "rank": rank,
+                "local_rank": local_rank,
+                "world_size": world_size,
+                "init_method": host_addr,
+                "backend": "nccl",
+            },
+        }
+        config.update(distributed_config)
+    except KeyError:
+        pass
+    config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.yaml")
+    config = easyfl.load_config(config_file, config)
+
+    print("config:", config)
+    easyfl.init(config, init_all=True)
+    easyfl.run()

+ 10 - 0
applications/fedureid/reid/__init__.py

@@ -0,0 +1,10 @@
+from __future__ import absolute_import
+
+from . import evaluation_metrics
+from . import feature_extraction
+from . import models
+from . import utils
+from . import evaluators
+from . import trainers
+
+__version__ = '0.2.0'

+ 236 - 0
applications/fedureid/reid/bottomup.py

@@ -0,0 +1,236 @@
+import copy
+import logging
+import sys
+
+import numpy as np
+import torch
+
+from .evaluators import Evaluator, extract_features
+from .exclusive_loss import ExLoss
+from .trainers import Trainer
+from .utils.transform.transforms import TRANSFORM_VAL_LIST
+
+logger = logging.getLogger(__name__)
+
+
+class BottomUp:
+    def __init__(self,
+                 cid,
+                 model,
+                 batch_size,
+                 eval_batch_size,
+                 num_classes,
+                 train_data,
+                 test_data,
+                 device,
+                 embedding_feature_size=2048,
+                 initial_epochs=20,
+                 local_epochs=2,
+                 step_size=16,
+                 seed=0):
+        self.cid = cid
+        self.model = model
+        self.num_classes = num_classes
+        self.batch_size = batch_size
+        self.eval_batch_size = eval_batch_size
+        self.device = device
+
+        self.seed = seed
+
+        self.gallery_cam = None
+        self.gallery_label = None
+        self.query_cam = None
+        self.query_label = None
+        self.test_gallery_loader = None
+        self.test_query_loader = None
+
+        self.train_data = train_data
+        self.test_data = test_data
+
+        self.initial_epochs = initial_epochs
+        self.local_epochs = local_epochs
+        self.step_size = step_size
+
+        self.embedding_feature_size = embedding_feature_size
+
+        self.fixed_layer = False
+
+        self.old_features = None
+        self.feature_distance = 0
+
+        self.criterion = ExLoss(self.embedding_feature_size, self.num_classes, t=10).to(device)
+
+    def set_model(self, model, current_step):
+        if current_step == 0:
+            self.model = model.to(self.device)
+        else:
+            self.model.load_state_dict(model.state_dict())
+            self.model = self.model.to(self.device)
+
+    def train(self, step, dynamic_epoch=False):
+        self.model = self.model.train()
+
+        # adjust training epochs and learning rate
+        epochs = self.initial_epochs if step == 0 else self.local_epochs
+
+        init_lr = 0.1 if step == 0 else 0.01
+        step_size = self.step_size if step == 0 else sys.maxsize
+
+        logger.info("create train transform loader with batch size {}".format(self.batch_size))
+        loader = self.train_data.loader(self.batch_size, self.cid, seed=self.seed, num_workers=6)
+
+        # the base parameters for the backbone (e.g. ResNet50)
+        base_param_ids = set(map(id, self.model.CNN.base.parameters()))
+
+        # we fixed the first three blocks to save GPU memory
+        base_params_need_for_grad = filter(lambda p: p.requires_grad, self.model.CNN.base.parameters())
+
+        # params of the new layers
+        new_params = [p for p in self.model.parameters() if id(p) not in base_param_ids]
+
+        # set the learning rate for backbone to be 0.1 times
+        param_groups = [
+            {'params': base_params_need_for_grad, 'lr_mult': 0.1},
+            {'params': new_params, 'lr_mult': 1.0}]
+
+        optimizer = torch.optim.SGD(param_groups, lr=init_lr, momentum=0.9, weight_decay=5e-4, nesterov=True)
+
+        # change the learning rate by step
+        def adjust_lr(epoch, step_size):
+            lr = init_lr / (10 ** (epoch // step_size))
+            for g in optimizer.param_groups:
+                g['lr'] = lr * g.get('lr_mult', 1)
+
+        logger.info("number of epochs, {}: {}".format(self.cid, epochs))
+
+        """ main training process """
+        trainer = Trainer(self.model, self.criterion, self.device, fixed_layer=self.fixed_layer)
+        for epoch in range(epochs):
+            adjust_lr(epoch, step_size)
+            stop_local_training = trainer.train(epoch, loader, optimizer, print_freq=max(5, len(loader) // 30 * 10))
+            # Dynamically decide number of local epochs, based on conditions inside trainer.
+            if step > 0 and dynamic_epoch and stop_local_training:
+                logger.info("Dynamic epoch: in step {}, stop training {} after epoch {}".format(step, self.cid, epoch))
+                break
+        return self.model
+
+    def evaluate(self, cid, model=None):
+        # getting cid from argument is because of merged training
+        if model is None:
+            model = self.model
+        model = model.eval()
+        model = model.to(self.device)
+
+        gallery_id = '{}_{}'.format(cid, 'gallery')
+        query_id = '{}_{}'.format(cid, 'query')
+
+        logger.info("create test transform loader with batch size {}".format(self.eval_batch_size))
+        gallery_loader = self.test_data.loader(batch_size=self.eval_batch_size,
+                                               client_id=gallery_id,
+                                               shuffle=False,
+                                               num_workers=6)
+        query_loader = self.test_data.loader(batch_size=self.eval_batch_size,
+                                             client_id=query_id,
+                                             shuffle=False,
+                                             num_workers=6)
+
+        evaluator = Evaluator(model, self.test_data, query_id, gallery_id, self.device)
+        rank1, rank5, rank10, mAP = evaluator.evaluate(query_loader, gallery_loader)
+        return rank1, rank5, rank10, mAP
+
+    # New get_new_train_data
+    def relabel_train_data(self, device, unlabeled_ys, labeled_ys, nums_to_merge, size_penalty):
+        # extract feature/classifier
+        self.model = self.model.to(device)
+        loader = self.train_data.loader(self.batch_size,
+                                        self.cid,
+                                        shuffle=False,
+                                        num_workers=6,
+                                        transform=TRANSFORM_VAL_LIST)
+        features = extract_features(self.model, loader, device)
+
+        # calculate cosine distance of features
+        if self.old_features:
+            similarities = []
+            for old_feature, new_feature in zip(self.old_features, features):
+                m = torch.cosine_similarity(old_feature, new_feature, dim=0)
+                similarities.append(m)
+            self.feature_distance = 1 - sum(similarities) / len(similarities)
+            logger.info("Cosine distance between features, {}: {}".format(self.cid, self.feature_distance))
+        self.old_features = copy.deepcopy(features)
+
+        features = np.array([logit.numpy() for logit in features])
+
+        # images of the same cluster
+        label_to_images = {}
+        for idx, l in enumerate(unlabeled_ys):
+            label_to_images[l] = label_to_images.get(l, []) + [idx]
+
+        dists = self.calculate_distance(features)
+
+        idx1, idx2 = self.select_merge_data(features, unlabeled_ys, label_to_images, size_penalty, dists)
+
+        unlabeled_ys = self.relabel_new_train_data(idx1, idx2, labeled_ys, unlabeled_ys, nums_to_merge)
+
+        num_classes = len(np.unique(np.array(unlabeled_ys)))
+
+        # change the criterion classifier
+        self.criterion = ExLoss(self.embedding_feature_size, num_classes, t=10).to(device)
+
+        return unlabeled_ys
+
+    def relabel_new_train_data(self, idx1, idx2, labeled_ys, label, num_to_merge):
+        correct = 0
+        num_before_merge = len(np.unique(np.array(label)))
+        # merge clusters with minimum dissimilarity
+        for i in range(len(idx1)):
+            label1 = label[idx1[i]]
+            label2 = label[idx2[i]]
+            if label1 < label2:
+                label = [label1 if x == label2 else x for x in label]
+            else:
+                label = [label2 if x == label1 else x for x in label]
+            if labeled_ys[idx1[i]] == labeled_ys[idx2[i]]:
+                correct += 1
+            num_merged = num_before_merge - len(np.sort(np.unique(np.array(label))))
+            if num_merged == num_to_merge:
+                break
+        # set new label to the new training transform
+        unique_label = np.sort(np.unique(np.array(label)))
+        for i in range(len(unique_label)):
+            label_now = unique_label[i]
+            label = [i if x == label_now else x for x in label]
+
+        self.train_data.data[self.cid]['y'] = label
+
+        num_after_merge = len(np.unique(np.array(label)))
+        logger.info("num of label before merge: {}, after merge: {}, sub: {}".format(
+            num_before_merge, num_after_merge, num_before_merge - num_after_merge))
+        return label
+
+    def calculate_distance(self, u_feas):
+        # calculate distance between features
+        x = torch.from_numpy(u_feas)
+        y = x
+        m = len(u_feas)
+        dists = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, m) + \
+                torch.pow(y, 2).sum(dim=1, keepdim=True).expand(m, m).t()
+        dists.addmm_(1, -2, x, y.t())
+        return dists
+
+    def select_merge_data(self, u_feas, label, label_to_images, ratio_n, dists):
+        dists.add_(torch.tril(100000 * torch.ones(len(u_feas), len(u_feas))))
+
+        cnt = torch.FloatTensor([len(label_to_images[label[idx]]) for idx in range(len(u_feas))])
+        dists += ratio_n * (cnt.view(1, len(cnt)) + cnt.view(len(cnt), 1))
+
+        for idx in range(len(u_feas)):
+            for j in range(idx + 1, len(u_feas)):
+                if label[idx] == label[j]:
+                    dists[idx, j] = 100000
+
+        dists = dists.numpy()
+        ind = np.unravel_index(np.argsort(dists, axis=None), dists.shape)
+        idx1 = ind[0]
+        idx2 = ind[1]
+        return idx1, idx2

+ 10 - 0
applications/fedureid/reid/evaluation_metrics/__init__.py

@@ -0,0 +1,10 @@
+from __future__ import absolute_import
+
+from .classification import accuracy
+from .ranking import cmc, mean_ap
+
+__all__ = [
+    'accuracy',
+    'cmc',
+    'mean_ap',
+]

+ 19 - 0
applications/fedureid/reid/evaluation_metrics/classification.py

@@ -0,0 +1,19 @@
+from __future__ import absolute_import
+
+from ..utils import to_torch
+
+
+def accuracy(output, target, topk=(1,)):
+    output, target = to_torch(output), to_torch(target)
+    maxk = max(topk)
+    batch_size = target.size(0)
+
+    _, pred = output.topk(maxk, 1, True, True)
+    pred = pred.t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+    ret = []
+    for k in topk:
+        correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
+        ret.append(correct_k.mul_(1. / batch_size))
+    return ret

+ 130 - 0
applications/fedureid/reid/evaluation_metrics/ranking.py

@@ -0,0 +1,130 @@
+from __future__ import absolute_import
+from collections import defaultdict
+
+import numpy as np
+from sklearn.metrics.base import _average_binary_score
+from sklearn.metrics import precision_recall_curve, auc
+# from sklearn.metrics import average_precision_score
+
+
+from ..utils import to_numpy
+
+
+def _unique_sample(ids_dict, num):
+    mask = np.zeros(num, dtype=np.bool)
+    for _, indices in ids_dict.items():
+        i = np.random.choice(indices)
+        mask[i] = True
+    return mask
+
+
+def average_precision_score(y_true, y_score, average="macro",
+                            sample_weight=None):
+    def _binary_average_precision(y_true, y_score, sample_weight=None):
+        precision, recall, thresholds = precision_recall_curve(
+            y_true, y_score, sample_weight=sample_weight)
+        return auc(recall, precision)
+
+    return _average_binary_score(_binary_average_precision, y_true, y_score,
+                                 average, sample_weight=sample_weight)
+
+
+def cmc(distmat, query_ids=None, gallery_ids=None,
+        query_cams=None, gallery_cams=None, topk=100,
+        separate_camera_set=False,
+        single_gallery_shot=False,
+        first_match_break=False):
+    distmat = to_numpy(distmat)
+    m, n = distmat.shape
+    # Fill up default values
+    if query_ids is None:
+        query_ids = np.arange(m)
+    if gallery_ids is None:
+        gallery_ids = np.arange(n)
+    if query_cams is None:
+        query_cams = np.zeros(m).astype(np.int32)
+    if gallery_cams is None:
+        gallery_cams = np.ones(n).astype(np.int32)
+    # Ensure numpy array
+    query_ids = np.asarray(query_ids)
+    gallery_ids = np.asarray(gallery_ids)
+    query_cams = np.asarray(query_cams)
+    gallery_cams = np.asarray(gallery_cams)
+    # Sort and find correct matches
+    indices = np.argsort(distmat, axis=1)
+    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
+    # Compute CMC for each query
+    ret = np.zeros(topk)
+    num_valid_queries = 0
+    for i in range(m):
+        # Filter out the same id and same camera
+        valid = ((gallery_ids[indices[i]] != query_ids[i]) |
+                 (gallery_cams[indices[i]] != query_cams[i]))
+        if separate_camera_set:
+            # Filter out samples from same camera
+            valid &= (gallery_cams[indices[i]] != query_cams[i])
+        if not np.any(matches[i, valid]): continue
+        if single_gallery_shot:
+            repeat = 10
+            gids = gallery_ids[indices[i][valid]]
+            inds = np.where(valid)[0]
+            ids_dict = defaultdict(list)
+            for j, x in zip(inds, gids):
+                ids_dict[x].append(j)
+        else:
+            repeat = 1
+        for _ in range(repeat):
+            if single_gallery_shot:
+                # Randomly choose one instance for each id
+                sampled = (valid & _unique_sample(ids_dict, len(valid)))
+                index = np.nonzero(matches[i, sampled])[0]
+            else:
+                index = np.nonzero(matches[i, valid])[0]
+            delta = 1. / (len(index) * repeat)
+            for j, k in enumerate(index):
+                if k - j >= topk: break
+                if first_match_break:
+                    ret[k - j] += 1
+                    break
+                ret[k - j] += delta
+        num_valid_queries += 1
+    if num_valid_queries == 0:
+        raise RuntimeError("No valid query")
+    return ret.cumsum() / num_valid_queries
+
+
+def mean_ap(distmat, query_ids=None, gallery_ids=None,
+            query_cams=None, gallery_cams=None):
+    distmat = to_numpy(distmat)
+    m, n = distmat.shape
+    # Fill up default values
+    if query_ids is None:
+        query_ids = np.arange(m)
+    if gallery_ids is None:
+        gallery_ids = np.arange(n)
+    if query_cams is None:
+        query_cams = np.zeros(m).astype(np.int32)
+    if gallery_cams is None:
+        gallery_cams = np.ones(n).astype(np.int32)
+    # Ensure numpy array
+    query_ids = np.asarray(query_ids)
+    gallery_ids = np.asarray(gallery_ids)
+    query_cams = np.asarray(query_cams)
+    gallery_cams = np.asarray(gallery_cams)
+    # Sort and find correct matches
+    indices = np.argsort(distmat, axis=1)
+    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
+    # Compute AP for each query
+    aps = []
+    for i in range(m):
+        # Filter out the same id and same camera
+        valid = ((gallery_ids[indices[i]] != query_ids[i]) |
+                 (gallery_cams[indices[i]] != query_cams[i]))
+        y_true = matches[i, valid]
+        y_score = -distmat[i][indices[i]][valid]
+        if not np.any(y_true): continue
+        aps.append(average_precision_score(y_true, y_score))
+    if len(aps) == 0:
+        raise RuntimeError("No valid query")
+    return np.mean(aps)
+

+ 113 - 0
applications/fedureid/reid/evaluators.py

@@ -0,0 +1,113 @@
+from __future__ import print_function, absolute_import
+
+import logging
+import os
+
+import torch
+from torch.backends import cudnn
+
+from .evaluation_metrics import cmc, mean_ap
+from .feature_extraction import extract_cnn_feature
+
+logger = logging.getLogger(__name__)
+
+
+# extract features for fed transform format
+def extract_features(model, data_loader, device, print_freq=1, metric=None):
+    cudnn.benchmark = False
+    model.eval()
+
+    features = []
+    logger.info("extracting features...")
+    for i, (inputs, targets) in enumerate(data_loader):
+        inputs = inputs.to(device)
+        _fcs, pool5s = extract_cnn_feature(model, inputs)
+        features.extend(pool5s)
+    return features
+
+
+def pairwise_distance(query_features, gallery_features, metric=None):
+    x = torch.cat([f.unsqueeze(0) for f in query_features], 0)
+    y = torch.cat([f.unsqueeze(0) for f in gallery_features], 0)
+
+    m, n = x.size(0), y.size(0)
+    x = x.view(m, -1)
+    y = y.view(n, -1)
+    if metric is not None:
+        x = metric.transform(x)
+        y = metric.transform(y)
+    dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
+           torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
+    dist.addmm_(1, -2, x, y.t())
+    return dist
+
+
+def evaluate_all(distmat, query_ids, gallery_ids, query_cams, gallery_cams, cmc_topk=(1, 5, 10, 20)):
+    # Compute mean AP
+    mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
+
+    # Compute all kinds of CMC scores
+    cmc_configs = {
+        'market1501': dict(separate_camera_set=False,
+                           single_gallery_shot=False,
+                           first_match_break=True)}
+    cmc_scores = {name: cmc(distmat, query_ids, gallery_ids,
+                            query_cams, gallery_cams, **params)
+                  for name, params in cmc_configs.items()}
+
+    print('Mean AP: {:4.2%}'.format(mAP))
+    print('CMC Scores:')
+    for k in cmc_topk:
+        print('  top-{:<4}{:12.2%}'
+              .format(k,
+                      cmc_scores['market1501'][k - 1]))
+
+    # Use the allshots cmc top-1 score for validation criterion
+    return cmc_scores['market1501'][0], cmc_scores['market1501'][4], cmc_scores['market1501'][9], mAP
+
+
+class Evaluator(object):
+    def __init__(self, model, test_data, query_id, gallery_id, device, is_print=False):
+        super(Evaluator, self).__init__()
+        self.model = model
+        self.test_data = test_data
+        self.device = device
+
+        gallery_path = [(self.test_data.data[gallery_id]['x'][i],
+                         self.test_data.data[gallery_id]['y'][i])
+                        for i in range(len(self.test_data.data[gallery_id]['y']))]
+        query_path = [(self.test_data.data[query_id]['x'][i],
+                       self.test_data.data[query_id]['y'][i])
+                      for i in range(len(self.test_data.data[query_id]['y']))]
+        gallery_cam, gallery_label = get_id(gallery_path)
+        self.gallery_cam = gallery_cam
+        self.gallery_label = gallery_label
+        query_cam, query_label = get_id(query_path)
+        self.query_cam = query_cam
+        self.query_label = query_label
+
+    def evaluate(self, query_loader, gallery_loader, metric=None):
+        query_features = extract_features(self.model, query_loader, self.device)
+        gallery_features = extract_features(self.model, gallery_loader, self.device)
+        distmat = pairwise_distance(query_features, gallery_features, metric=metric)
+        return evaluate_all(distmat, self.query_label, self.gallery_label, self.query_cam, self.gallery_cam)
+
+
+def get_id(img_path):
+    camera_id = []
+    labels = []
+    for p, v in img_path:
+        filename = os.path.basename(p)
+        if filename[:3] != 'cam':
+            label = filename[0:4]
+            camera = filename.split('c')[1]
+            camera = camera.split('s')[0]
+        else:
+            label = filename.split('_')[2]
+            camera = filename.split('_')[1]
+        if label[0:2] == '-1':
+            labels.append(-1)
+        else:
+            labels.append(int(label))
+        camera_id.append(int(camera[0]))
+    return camera_id, labels

+ 41 - 0
applications/fedureid/reid/exclusive_loss.py

@@ -0,0 +1,41 @@
+from __future__ import absolute_import
+
+import torch
+import torch.nn.functional as F
+from torch import nn, autograd
+
+
+class Exclusive(autograd.Function):
+    # def __init__(ctx, V):
+    #     super(Exclusive, ctx).__init__()
+    #     ctx.V = V
+
+    @staticmethod
+    def forward(ctx, inputs, targets, V):
+        ctx.V = V
+        ctx.save_for_backward(inputs, targets)
+        outputs = inputs.mm(ctx.V.t())
+        return outputs
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        inputs, targets = ctx.saved_tensors
+        grad_inputs = grad_outputs.mm(ctx.V) if ctx.needs_input_grad[0] else None
+        for x, y in zip(inputs, targets):
+            ctx.V[y] = F.normalize( (ctx.V[y] + x) / 2, p=2, dim=0)
+        return grad_inputs, None, None
+
+
+class ExLoss(nn.Module):
+    def __init__(self, num_features, num_classes, t=1.0, weight=None):
+        super(ExLoss, self).__init__()
+        self.num_features = num_features
+        self.t = t
+        self.weight = weight
+        self.register_buffer('V', torch.zeros(num_classes, num_features))
+
+    def forward(self, inputs, targets):
+        outputs = Exclusive.apply(inputs, targets, self.V) * self.t
+        # outputs = Exclusive(self.V)(inputs, targets) * self.t
+        loss = F.cross_entropy(outputs, targets, weight=self.weight)
+        return loss, outputs

+ 7 - 0
applications/fedureid/reid/feature_extraction/__init__.py

@@ -0,0 +1,7 @@
+from __future__ import absolute_import
+
+from .cnn import extract_cnn_feature
+
+__all__ = [
+    'extract_cnn_feature',
+]

+ 30 - 0
applications/fedureid/reid/feature_extraction/cnn.py

@@ -0,0 +1,30 @@
+from __future__ import absolute_import
+from collections import OrderedDict
+
+from torch.autograd import Variable
+
+from ..utils import to_torch
+import torch
+
+
+def extract_cnn_feature(model, inputs, modules=None):
+    with torch.no_grad():
+        model.eval()
+        inputs = to_torch(inputs)
+        inputs = Variable(inputs)
+        if modules is None:
+            fcs, pool5s = model(inputs)
+            fcs = fcs.data.cpu()
+            pool5s = pool5s.data.cpu()
+            return fcs, pool5s
+        # Register forward hook for each module
+        outputs = OrderedDict()
+        handles = []
+        for m in modules:
+            outputs[id(m)] = None
+            def func(m, i, o): outputs[id(m)] = o.data.cpu()
+            handles.append(m.register_forward_hook(func))
+        model(inputs)
+        for h in handles:
+            h.remove()
+    return list(outputs.values())

+ 48 - 0
applications/fedureid/reid/models/__init__.py

@@ -0,0 +1,48 @@
+from __future__ import absolute_import
+
+from .end2end import *
+
+
+__factory = {
+    'avg_pool': End2End_AvgPooling,
+}
+
+
+def names():
+    return sorted(__factory.keys())
+
+
+def create(name, *args, **kwargs):
+    """
+    Create a model instance.
+
+    Parameters
+    ----------
+    name : str
+        Model name. Can be one of 'inception', 'resnet18', 'resnet34',
+        'resnet50', 'resnet101', and 'resnet152'.
+    pretrained : bool, optional
+        Only applied for 'resnet*' models. If True, will use ImageNet pretrained
+        model. Default: True
+    cut_at_pooling : bool, optional
+        If True, will cut the model before the last global pooling layer and
+        ignore the remaining kwargs. Default: False
+    num_features : int, optional
+        If positive, will append a Linear layer after the global pooling layer,
+        with this number of output units, followed by a BatchNorm layer.
+        Otherwise these layers will not be appended. Default: 256 for
+        'inception', 0 for 'resnet*'
+    norm : bool, optional
+        If True, will normalize the feature to be unit L2-norm for each sample.
+        Otherwise will append a ReLU layer after the above Linear layer if
+        num_features > 0. Default: False
+    dropout : float, optional
+        If positive, will append a Dropout layer with this dropout rate.
+        Default: 0
+    num_classes : int, optional
+        If positive, will append a Linear layer at the end as the classifier
+        with this number of output units. Default: 0
+    """
+    if name not in __factory:
+        raise KeyError("Unknown model:", name)
+    return __factory[name](*args, **kwargs)

+ 61 - 0
applications/fedureid/reid/models/end2end.py

@@ -0,0 +1,61 @@
+from __future__ import absolute_import
+
+from torch import nn
+from torch.autograd import Variable
+from torch.nn import functional as F
+from torch.nn import init
+import torch
+import torchvision
+import math
+
+from .resnet import *
+
+__all__ = ["End2End_AvgPooling"]
+
+
+class AvgPooling(nn.Module):
+    def __init__(self, input_feature_size, embedding_fea_size=1024, dropout=0.5):
+        super(self.__class__, self).__init__()
+
+        # embedding
+        self.embedding_fea_size = embedding_fea_size
+        self.embedding = nn.Linear(input_feature_size, embedding_fea_size)
+        self.embedding_bn = nn.BatchNorm1d(embedding_fea_size)
+        init.kaiming_normal_(self.embedding.weight, mode='fan_out')
+        init.constant_(self.embedding.bias, 0)
+        init.constant_(self.embedding_bn.weight, 1)
+        init.constant_(self.embedding_bn.bias, 0)
+        self.drop = nn.Dropout(dropout)
+
+    def forward(self, inputs):
+        net = inputs.mean(dim=1)
+        eval_features = F.normalize(net, p=2, dim=1)
+        net = self.embedding(net)
+        net = self.embedding_bn(net)
+        net = F.normalize(net, p=2, dim=1)
+        net = self.drop(net)
+        return net, eval_features
+
+
+class End2End_AvgPooling(nn.Module):
+
+    def __init__(self, dropout=0, embedding_fea_size=1024, fixed_layer=True):
+        super(self.__class__, self).__init__()
+        self.CNN = resnet50(dropout=dropout, fixed_layer=fixed_layer)
+        self.avg_pooling = AvgPooling(input_feature_size=2048, embedding_fea_size=embedding_fea_size, dropout=dropout)
+
+    def forward(self, x):
+        assert len(x.data.shape) == 5
+        # reshape (batch, samples, ...) ==> (batch * samples, ...)
+        oriShape = x.data.shape
+        x = x.view(-1, oriShape[2], oriShape[3], oriShape[4])
+
+        # resnet encoding
+        resnet_feature = self.CNN(x)
+
+        # reshape back into (batch, samples, ...)
+        resnet_feature = resnet_feature.view(oriShape[0], oriShape[1], -1)
+
+        # avg pooling
+        output = self.avg_pooling(resnet_feature)
+        return output

+ 56 - 0
applications/fedureid/reid/models/model.py

@@ -0,0 +1,56 @@
+from __future__ import absolute_import
+
+from torch import nn
+from torch.nn import functional as F
+from torch.nn import init
+
+from easyfl.models.model import BaseModel
+from .resnet import *
+
+__all__ = ["BUCModel"]
+
+
+class AvgPooling(nn.Module):
+    def __init__(self, input_feature_size, embedding_feature_size=2048, dropout=0.5):
+        super(self.__class__, self).__init__()
+
+        # embedding
+        self.embedding_feature_size = embedding_feature_size
+        self.embedding = nn.Linear(input_feature_size, embedding_feature_size)
+        self.embedding_bn = nn.BatchNorm1d(embedding_feature_size)
+        init.kaiming_normal_(self.embedding.weight, mode='fan_out')
+        init.constant_(self.embedding.bias, 0)
+        init.constant_(self.embedding_bn.weight, 1)
+        init.constant_(self.embedding_bn.bias, 0)
+        self.drop = nn.Dropout(dropout)
+
+    def forward(self, inputs):
+        net = inputs.mean(dim=1)
+        eval_features = F.normalize(net, p=2, dim=1)
+        net = self.embedding(net)
+        net = self.embedding_bn(net)
+        net = F.normalize(net, p=2, dim=1)
+        net = self.drop(net)
+        return net, eval_features
+
+
+class BUCModel(BaseModel):
+    def __init__(self, dropout=0.5, embedding_feature_size=2048):
+        super(self.__class__, self).__init__()
+        self.CNN = resnet50(dropout=dropout)
+        self.avg_pooling = AvgPooling(input_feature_size=2048,
+                                      embedding_feature_size=embedding_feature_size,
+                                      dropout=dropout)
+
+    def forward(self, x):
+        # resnet encoding
+        resnet_feature = self.CNN(x)
+        shape = resnet_feature.shape
+
+        # reshape back into (batch, samples, ...)
+        # samples of video frames, we only use images, so always 1.
+        resnet_feature = resnet_feature.view(shape[0], 1, -1)
+
+        # avg pooling
+        output = self.avg_pooling(resnet_feature)
+        return output

+ 113 - 0
applications/fedureid/reid/models/resnet.py

@@ -0,0 +1,113 @@
+from __future__ import absolute_import
+
+from torch import nn
+from torch.nn import functional as F
+from torch.nn import init
+import torchvision
+
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152']
+
+
+class ResNet(nn.Module):
+    __factory = {
+        18: torchvision.models.resnet18,
+        34: torchvision.models.resnet34,
+        50: torchvision.models.resnet50,
+        101: torchvision.models.resnet101,
+        152: torchvision.models.resnet152,
+    }
+
+    def __init__(self, depth, pretrained=True, cut_at_pooling=False,
+                 num_features=0, norm=False, dropout=0, num_classes=0):
+        super(ResNet, self).__init__()
+
+        self.depth = depth
+        self.pretrained = pretrained
+        self.cut_at_pooling = cut_at_pooling
+
+        # Construct base (pretrained) resnet
+        if depth not in ResNet.__factory:
+            raise KeyError("Unsupported depth:", depth)
+
+        self.base = ResNet.__factory[depth](pretrained=pretrained)
+
+        if not self.cut_at_pooling:
+            self.num_features = num_features
+            self.norm = norm
+            self.dropout = dropout
+            self.has_embedding = num_features > 0
+            self.num_classes = num_classes
+
+            out_planes = self.base.fc.in_features
+
+            # Append new layers
+            if self.has_embedding:
+                self.feat = nn.Linear(out_planes, self.num_features)
+                self.feat_bn = nn.BatchNorm1d(self.num_features)
+                init.kaiming_normal(self.feat.weight, mode='fan_out')
+                init.constant(self.feat.bias, 0)
+                init.constant(self.feat_bn.weight, 1)
+                init.constant(self.feat_bn.bias, 0)
+            else:
+                # Change the num_features to CNN output channels
+                self.num_features = out_planes
+            if self.dropout > 0:
+                self.drop = nn.Dropout(self.dropout)
+            if self.num_classes > 0:
+                self.classifier = nn.Linear(self.num_features, self.num_classes)
+                init.normal(self.classifier.weight, std=0.001)
+                init.constant(self.classifier.bias, 0)
+
+        if not self.pretrained:
+            self.reset_params()
+
+    def forward(self, x):
+
+        for name, module in self.base._modules.items():
+            if name == 'avgpool':
+                break
+            x = module(x)
+
+        if self.cut_at_pooling:
+            return x
+
+        x = F.avg_pool2d(x, x.size()[2:])
+        x = x.view(x.size(0), -1)
+
+        return x
+
+    def reset_params(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal(m.weight, mode='fan_out')
+                if m.bias is not None:
+                    init.constant(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                init.constant(m.weight, 1)
+                init.constant(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                init.normal(m.weight, std=0.001)
+                if m.bias is not None:
+                    init.constant(m.bias, 0)
+
+
+def resnet18(**kwargs):
+    return ResNet(18, **kwargs)
+
+
+def resnet34(**kwargs):
+    return ResNet(34, **kwargs)
+
+
+def resnet50(**kwargs):
+    return ResNet(50, **kwargs)
+
+
+def resnet101(**kwargs):
+    return ResNet(101, **kwargs)
+
+
+def resnet152(**kwargs):
+    return ResNet(152, **kwargs)

+ 87 - 0
applications/fedureid/reid/trainers.py

@@ -0,0 +1,87 @@
+from __future__ import print_function, absolute_import
+
+import logging
+import time
+
+from torch.autograd import Variable
+
+from .evaluation_metrics import accuracy
+from .utils.meters import AverageMeter
+
+logger = logging.getLogger(__name__)
+
+
+class BaseTrainer(object):
+    def __init__(self, model, criterion, device, fixed_layer=False):
+        super(BaseTrainer, self).__init__()
+        self.model = model
+        self.criterion = criterion
+        self.fixed_layer = fixed_layer
+        self.device = device
+
+    def train(self, epoch, data_loader, optimizer, print_freq=1):
+        self.model.train()
+
+        batch_time = AverageMeter()
+        data_time = AverageMeter()
+        losses = AverageMeter()
+        precisions = AverageMeter()
+
+        stop_local_training = False
+        precision_avg = []
+
+        end = time.time()
+        for i, inputs in enumerate(data_loader):
+            data_time.update(time.time() - end)
+
+            inputs, targets = self._parse_data(inputs)
+            loss, prec1 = self._forward(inputs, targets)
+
+            losses.update(loss.item(), targets.size(0))
+            precisions.update(prec1, targets.size(0))
+
+            optimizer.zero_grad()
+            loss.backward()
+
+            optimizer.step()
+
+            batch_time.update(time.time() - end)
+            end = time.time()
+
+            if (i + 1) % print_freq == 0:
+                logger.info('Epoch: [{}][{}/{}]\t'
+                            'Time {:.3f} ({:.3f})\t'
+                            'Data {:.3f} ({:.3f})\t'
+                            'Loss {:.3f} ({:.3f})\t'
+                            'Prec {:.2%} ({:.2%})\t'
+                            .format(epoch, i + 1, len(data_loader),
+                                    batch_time.val, batch_time.avg,
+                                    data_time.val, data_time.avg,
+                                    losses.val, losses.avg,
+                                    precisions.val, precisions.avg))
+            precision_avg.append(precisions.avg)
+            if precisions.val == 1 or precisions.avg > 0.95:
+                stop_local_training = True
+        return stop_local_training
+
+    def _parse_data(self, inputs):
+        raise NotImplementedError
+
+    def _forward(self, inputs, targets):
+        raise NotImplementedError
+
+
+class Trainer(BaseTrainer):
+    def _parse_data(self, inputs):
+        x, y = inputs
+        inputs = Variable(x.to(self.device), requires_grad=False)
+        targets = Variable(y.to(self.device))
+        return inputs, targets
+
+    def _forward(self, inputs, targets):
+        outputs, _ = self.model(inputs)
+        outputs = outputs.to(self.device)
+        loss, outputs = self.criterion(outputs, targets)
+        prec, = accuracy(outputs.data, targets.data)
+        prec = prec[0]
+        return loss, prec

+ 21 - 0
applications/fedureid/reid/utils/__init__.py

@@ -0,0 +1,21 @@
+from __future__ import absolute_import
+
+import torch
+
+
+def to_numpy(tensor):
+    if torch.is_tensor(tensor):
+        return tensor.cpu().numpy()
+    elif type(tensor).__module__ != 'numpy':
+        raise ValueError("Cannot convert {} to numpy array"
+                         .format(type(tensor)))
+    return tensor
+
+
+def to_torch(ndarray):
+    if type(ndarray).__module__ == 'numpy':
+        return torch.from_numpy(ndarray)
+    elif not torch.is_tensor(ndarray):
+        raise ValueError("Cannot convert {} to torch tensor"
+                         .format(type(ndarray)))
+    return ndarray

+ 23 - 0
applications/fedureid/reid/utils/meters.py

@@ -0,0 +1,23 @@
+from __future__ import absolute_import
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count

+ 11 - 0
applications/fedureid/reid/utils/osutils.py

@@ -0,0 +1,11 @@
+from __future__ import absolute_import
+import os
+import errno
+
+
+def mkdir_if_missing(dir_path):
+    try:
+        os.makedirs(dir_path)
+    except OSError as e:
+        if e.errno != errno.EEXIST:
+            raise

+ 60 - 0
applications/fedureid/reid/utils/serialization.py

@@ -0,0 +1,60 @@
+from __future__ import print_function, absolute_import
+import json
+import os.path as osp
+import shutil
+
+import torch
+from torch.nn import Parameter
+
+from .osutils import mkdir_if_missing
+
+
+def read_json(fpath):
+    with open(fpath, 'r') as f:
+        obj = json.load(f)
+    return obj
+
+
+def write_json(obj, fpath):
+    mkdir_if_missing(osp.dirname(fpath))
+    with open(fpath, 'w') as f:
+        json.dump(obj, f, indent=4, separators=(',', ': '))
+
+
+def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
+    mkdir_if_missing(osp.dirname(fpath))
+    torch.save(state, fpath)
+    if is_best:
+        shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))
+
+
+def load_checkpoint(fpath):
+    if osp.isfile(fpath):
+        checkpoint = torch.load(fpath)
+        print("=> Loaded checkpoint '{}'".format(fpath))
+        return checkpoint
+    else:
+        raise ValueError("=> No checkpoint found at '{}'".format(fpath))
+
+
+def copy_state_dict(state_dict, model, strip=None):
+    tgt_state = model.state_dict()
+    copied_names = set()
+    for name, param in state_dict.items():
+        if strip is not None and name.startswith(strip):
+            name = name[len(strip):]
+        if name not in tgt_state:
+            continue
+        if isinstance(param, Parameter):
+            param = param.data
+        if param.size() != tgt_state[name].size():
+            print('mismatch:', name, param.size(), tgt_state[name].size())
+            continue
+        tgt_state[name].copy_(param)
+        copied_names.add(name)
+
+    missing = set(tgt_state.keys()) - copied_names
+    if len(missing) > 0:
+        print("missing keys in state_dict:", missing)
+
+    return model

+ 1 - 0
applications/fedureid/reid/utils/transform/__init__.py

@@ -0,0 +1 @@
+from __future__ import absolute_import

+ 66 - 0
applications/fedureid/reid/utils/transform/transforms.py

@@ -0,0 +1,66 @@
+from __future__ import absolute_import
+
+import math
+import random
+
+from PIL import Image
+from torchvision import transforms
+
+
+class RectScale(object):
+    def __init__(self, height, width, interpolation=Image.BILINEAR):
+        self.height = height
+        self.width = width
+        self.interpolation = interpolation
+
+    def __call__(self, img):
+        w, h = img.size
+        if h == self.height and w == self.width:
+            return img
+        return img.resize((self.width, self.height), self.interpolation)
+
+
+class RandomSizedRectCrop(object):
+    def __init__(self, height, width, interpolation=Image.BILINEAR):
+        self.height = height
+        self.width = width
+        self.interpolation = interpolation
+
+    def __call__(self, img):
+        for attempt in range(10):
+            area = img.size[0] * img.size[1]
+            target_area = random.uniform(0.64, 1.0) * area
+            aspect_ratio = random.uniform(2, 3)
+
+            h = int(round(math.sqrt(target_area * aspect_ratio)))
+            w = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if w <= img.size[0] and h <= img.size[1]:
+                x1 = random.randint(0, img.size[0] - w)
+                y1 = random.randint(0, img.size[1] - h)
+
+                img = img.crop((x1, y1, x1 + w, y1 + h))
+                assert (img.size == (w, h))
+
+                return img.resize((self.width, self.height), self.interpolation)
+
+        # Fallback
+        scale = RectScale(self.height, self.width,
+                          interpolation=self.interpolation)
+        return scale(img)
+
+
+normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+TRANSFORM_TRAIN_LIST = transforms.Compose([
+    RandomSizedRectCrop(256, 128),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    normalizer,
+])
+
+TRANSFORM_VAL_LIST = transformer = transforms.Compose([
+    RectScale(256, 128),
+    transforms.ToTensor(),
+    normalizer,
+])