Browse source

[Feature] Federated person re-identification (#2)

Zhuang Weiming, 2 years ago
Parent
commit
1e5ae55ee4

+ 5 - 1
README.md

@@ -56,7 +56,11 @@ For more advanced usage, we provide a list of tutorials on:
 
 ## Projects & Papers
 
-The following publications are developed using EasyFL.
+We have released the source code for the following papers under the `applications` folder:
+
+- FedReID: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/fedreid) for "Performance Optimization for Federated Person Re-identification via Benchmark Analysis", _ACMMM'2020_.
+
+The following publications are developed using EasyFL:
 
 - Divergence-aware Federated Self-Supervised Learning, _ICLR'2022_. [[paper]](https://openreview.net/forum?id=oVE1z8NlNe)
 - Collaborative Unsupervised Visual Representation Learning From Decentralized Data, _ICCV'2021_. [[paper]](https://openaccess.thecvf.com/content/ICCV2021/html/Zhuang_Collaborative_Unsupervised_Visual_Representation_Learning_From_Decentralized_Data_ICCV_2021_paper.html)

+ 60 - 0
applications/fedreid/README.md

@@ -0,0 +1,60 @@
+# Federated Person Re-identification (FedReID)
+
+Person re-identification is an important computer vision task, but its development is constrained by increasing privacy concerns. Federated learning is a privacy-preserving machine learning technique that learns a shared model across decentralized clients. We apply federated learning to person re-identification (**FedReID**) and optimize its performance under the **statistical heterogeneity** of real-world scenarios.
+
+This is the code for the ACMMM 2020 oral paper **[Performance Optimization for Federated Person Re-identification via Benchmark Analysis](https://arxiv.org/abs/2008.11560)**.
+
+Algorithm: Federated Partial Averaging (FedPav)
+
+<img src="images/fedpav.png" width="700">
+
+## Prerequisites
+
+This application requires the following Python libraries:
+```
+torch
+torchvision
+easyfl
+```
+
+Please refer to the [documentation](https://easyfl.readthedocs.io/en/latest/get_started.html#installation) to install `easyfl`.
+
+## Datasets
+
+**We use 9 popular ReID datasets for the benchmark.**
+<img src="images/datasets.png" width="700">
+
+> **🎉 We are now releasing the processed datasets.** (April 2022)
+>
+> Please [email us](mailto:weiming001@e.ntu.edu.sg) to request the datasets, providing:
+> 1. A short self-introduction.
+> 2. Your purpose for using the datasets.
+>
+> *⚠️ Further distribution of the datasets is prohibited.*
+
+## Run the experiments
+
+Put the processed datasets in `data_dir` and run the experiments with the following script:
+
+```
+python main.py --data_dir ${data_dir}
+```
+
+You can refer to `main.py` to run experiments with more options and configurations.
+
+> Note: you can run experiments with multiple GPUs by setting `--gpu`. The default implementation supports multi-GPU training on a _slurm cluster_; outside slurm, you may need to modify `main.py` to use `multiprocess`.
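+
+For reference, here is a condensed sketch of how `main.py` (included in this commit) wires up a multi-GPU run: `slurm.setup()` supplies the rank and world size, which are merged into the configuration before initialization. Dataset, model, and client registration are omitted for brevity.
+
+```
+from easyfl.distributed import slurm
+
+import easyfl
+
+config = {"task_id": "fedreid", "gpu": 4}  # e.g., request 4 GPUs
+if config["gpu"] > 1:
+    rank, local_rank, world_size, host_addr = slurm.setup()
+    config.update({
+        "gpu": world_size,
+        "distributed": {"rank": rank, "local_rank": local_rank,
+                        "world_size": world_size, "init_method": host_addr},
+    })
+easyfl.init(easyfl.load_config("config.yaml", config))
+easyfl.run()
+```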
+
+You may refer to the [original implementation](https://github.com/cap-ntu/FedReID) for the optimization methods: knowledge distillation and weight adjustment.
+
+## Citation
+```
+@inproceedings{zhuang2020performance,
+  title={Performance Optimization of Federated Person Re-identification via Benchmark Analysis},
+  author={Zhuang, Weiming and Wen, Yonggang and Zhang, Xuesen and Gan, Xin and Yin, Daiying and Zhou, Dongzhan and Zhang, Shuai and Yi, Shuai},
+  booktitle={Proceedings of the 28th ACM International Conference on Multimedia},
+  pages={955--963},
+  year={2020}
+}
+```
+

+ 0 - 0
applications/fedreid/__init__.py


+ 139 - 0
applications/fedreid/client.py

@@ -0,0 +1,139 @@
+import logging
+import os
+import time
+
+import numpy as np
+import torch
+import torch._utils
+import torch.nn as nn
+import torch.optim as optim
+
+from easyfl.client.base import BaseClient
+from easyfl.distributed.distributed import CPU
+from easyfl.pb import common_pb2 as common_pb
+from easyfl.pb import server_service_pb2 as server_pb
+from easyfl.protocol import codec
+from easyfl.tracking import metric
+from evaluate import test_evaluate, extract_feature
+from model import get_classifier
+
+logger = logging.getLogger(__name__)
+
+
+class FedReIDClient(BaseClient):
+    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
+        super(FedReIDClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
+        self.classifier = get_classifier(len(self.train_data.classes[cid])).to(device)
+        self.gallery_cam = None
+        self.gallery_label = None
+        self.query_cam = None
+        self.query_label = None
+        self.test_gallery_loader = None
+        self.test_query_loader = None
+
+    def train(self, conf, device=CPU):
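+        # FedPav: attach this client's private classifier to the shared backbone before local training.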
+        self.model.classifier.classifier = self.classifier.to(device)
+        start_time = time.time()
+        loss_fn, optimizer = self.pretrain_setup(conf, device)
+        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
+        epoch_loss = []
+        for i in range(conf.local_epoch):
+            batch_loss = []
+            for batched_x, batched_y in self.train_loader:
+                x, y = batched_x.to(device), batched_y.to(device)
+                optimizer.zero_grad()
+                out = self.model(x)
+                loss = loss_fn(out, y)
+                loss.backward()
+                optimizer.step()
+                batch_loss.append(loss.item())
+            scheduler.step()
+            current_epoch_loss = sum(batch_loss) / len(batch_loss)
+            epoch_loss.append(float(current_epoch_loss))
+            logger.info("Client {}, local epoch: {}, loss: {}".format(self.cid, i, current_epoch_loss))
+        self.current_round_time = time.time() - start_time
+        self.track(metric.TRAIN_TIME, self.current_round_time)
+        self.track(metric.TRAIN_LOSS, epoch_loss)
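+        # FedPav: keep the trained classifier locally and strip it so only the shared backbone is uploaded.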
+        self.classifier = self.model.classifier.classifier
+        self.model.classifier.classifier = nn.Sequential()
+
+    def load_optimizer(self, conf):
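+        # The pretrained backbone uses a 10x smaller learning rate than the freshly initialized classifier head.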
+        ignored_params = list(map(id, self.model.classifier.parameters()))
+        base_params = filter(lambda p: id(p) not in ignored_params, self.model.parameters())
+        optimizer_ft = optim.SGD([
+            {'params': base_params, 'lr': 0.1 * conf.optimizer.lr},
+            {'params': self.model.classifier.parameters(), 'lr': conf.optimizer.lr}
+        ], weight_decay=5e-4, momentum=conf.optimizer.momentum, nesterov=True)
+        return optimizer_ft
+
+    def test(self, conf, device=CPU):
+        self.model = self.model.eval()
+        self.model = self.model.to(device)
+        gallery_id = '{}_{}'.format(self.cid, 'gallery')
+        query_id = '{}_{}'.format(self.cid, 'query')
+        if self.test_gallery_loader is None or self.test_query_loader is None:
+            self.test_gallery_loader = self.test_data.loader(batch_size=128,
+                                                             client_id=gallery_id,
+                                                             shuffle=False,
+                                                             seed=conf.seed)
+            self.test_query_loader = self.test_data.loader(batch_size=128,
+                                                           client_id=query_id,
+                                                           shuffle=False,
+                                                           seed=conf.seed)
+            gallery_path = [(self.test_data.data[gallery_id]['x'][i],
+                             self.test_data.data[gallery_id]['y'][i])
+                            for i in range(len(self.test_data.data[gallery_id]['y']))]
+            query_path = [(self.test_data.data[query_id]['x'][i],
+                           self.test_data.data[query_id]['y'][i])
+                          for i in range(len(self.test_data.data[query_id]['y']))]
+            gallery_cam, gallery_label = self._get_id(gallery_path)
+            self.gallery_cam = gallery_cam
+            self.gallery_label = gallery_label
+            query_cam, query_label = self._get_id(query_path)
+            self.query_cam = query_cam
+            self.query_label = query_label
+        with torch.no_grad():
+            gallery_feature = extract_feature(self.model,
+                                              self.test_gallery_loader,
+                                              device)
+            query_feature = extract_feature(self.model,
+                                            self.test_query_loader,
+                                            device)
+
+        result = {
+            'gallery_f': gallery_feature.numpy(),
+            'gallery_label': np.array([self.gallery_label]),
+            'gallery_cam': np.array([self.gallery_cam]),
+            'query_f': query_feature.numpy(),
+            'query_label': np.array([self.query_label]),
+            'query_cam': np.array([self.query_cam]),
+        }
+
+        logger.info("Evaluating {}".format(self.cid))
+        rank1, rank5, rank10, mAP = test_evaluate(result, device)
+        logger.info("Dataset: {} Rank@1:{:.2%} Rank@5:{:.2%} Rank@10:{:.2%} mAP:{:.2%}".format(
+            self.cid, rank1, rank5, rank10, mAP))
+        self._upload_holder = server_pb.UploadContent(
+            data=codec.marshal(server_pb.Performance(accuracy=rank1, loss=0)),  # loss not applicable
+            type=common_pb.DATA_TYPE_PERFORMANCE,
+            data_size=len(self.query_label),
+        )
+
+    def _get_id(self, img_path):
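+        # Parse camera ID and person ID from image filenames. Two conventions are
+        # handled: Market-1501-style names such as "0002_c1s1_000451_03.jpg"
+        # (4-digit ID, camera after 'c') and "cam_<camera>_<id>"-style names.
+        # A label of -1 marks distractor images.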
+        camera_id = []
+        labels = []
+        for p, v in img_path:
+            filename = os.path.basename(p)
+            if filename[:3] != 'cam':
+                label = filename[0:4]
+                camera = filename.split('c')[1]
+                camera = camera.split('s')[0]
+            else:
+                label = filename.split('_')[2]
+                camera = filename.split('_')[1]
+            if label[0:2] == '-1':
+                labels.append(-1)
+            else:
+                labels.append(int(label))
+            camera_id.append(int(camera[0]))
+        return camera_id, labels

+ 23 - 0
applications/fedreid/config.yaml

@@ -0,0 +1,23 @@
+tid: "fedreid"
+server:
+  test_all: False
+  clients_per_round: 9  # all 9 dataset clients participate in every round
+  test_every: 10  # evaluate every 10 rounds
+  rounds: 300  # total communication rounds
+  batch_size: 32
+  aggregation_content: "parameters"
+resource_heterogeneous:
+  simulate: False
+  pre_profile: False
+  total_time: 0
+  grouping_strategy: "greedy"
+client:
+  local_epoch: 1
+  track: False
+  batch_size: 32
+  optimizer:
+    type: "SGD"
+    lr: 0.05
+    momentum: 0.9
+test_mode: "test_in_client"
+test_method: "average"

+ 56 - 0
applications/fedreid/dataset.py

@@ -0,0 +1,56 @@
+import os
+
+from torchvision import transforms
+
+from easyfl.datasets import FederatedImageDataset
+
+DB_NAMES = ["MSMT17", "Duke", "Market", "cuhk03", "prid", "cuhk01", "viper", "3dpes", "ilids"]
+
+TRANSFORM_TRAIN_LIST = transforms.Compose([
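+    # interpolation=3 corresponds to PIL.Image.BICUBIC (older torchvision accepts the integer code).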
+    transforms.Resize((256, 128), interpolation=3),
+    transforms.Pad(10),
+    transforms.RandomCrop((256, 128)),
+    transforms.RandomHorizontalFlip(),
+    transforms.ToTensor(),
+    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+TRANSFORM_VAL_LIST = transforms.Compose([
+    transforms.Resize(size=(256, 128), interpolation=3),
+    transforms.ToTensor(),
+    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+
+
+def prepare_train_data(data_dir, db_names=None):
+    if db_names is None:
+        db_names = DB_NAMES
+    client_ids = []
+    roots = []
+    for db in db_names:
+        client_ids.append(db)
+        data_path = os.path.join(data_dir, db, 'pytorch')
+        roots.append(os.path.join(data_path, 'train_all'))
+    data = FederatedImageDataset(root=roots,
+                                 simulated=True,
+                                 do_simulate=False,
+                                 transform=TRANSFORM_TRAIN_LIST,
+                                 client_ids=client_ids)
+    return data
+
+
+def prepare_test_data(data_dir, db_names=None):
+    if db_names is None:
+        db_names = DB_NAMES
+    roots = []
+    client_ids = []
+    for db in db_names:
+        test_gallery = os.path.join(data_dir, db, 'pytorch', 'gallery')
+        test_query = os.path.join(data_dir, db, 'pytorch', 'query')
+        roots.extend([test_gallery, test_query])
+        client_ids.extend(["{}_{}".format(db, "gallery"), "{}_{}".format(db, "query")])
+    data = FederatedImageDataset(root=roots,
+                                 simulated=True,
+                                 do_simulate=False,
+                                 transform=TRANSFORM_VAL_LIST,
+                                 client_ids=client_ids)
+    return data

+ 109 - 0
applications/fedreid/evaluate.py

@@ -0,0 +1,109 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+
+
+def extract_feature(model, dataloaders, device, ms=(1,)):  # tuple default avoids a shared mutable default
+    features = torch.FloatTensor()
+    model = model.to(device)
+    for data in dataloaders:
+        img, label = data
+        n, c, h, w = img.size()
+        ff = torch.FloatTensor(n, 512).zero_().to(device)
+        for i in range(2):
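+            # Second pass extracts features from the horizontally flipped image; both are summed.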
+            if i == 1:
+                img = fliplr(img)
+            input_img = Variable(img.to(device))
+            for scale in ms:
+                if scale != 1:
+                    # bicubic is only available in pytorch >= 1.1
+                    input_img = nn.functional.interpolate(input_img, scale_factor=scale, mode='bicubic',
+                                                          align_corners=False)
+                outputs = model(input_img)
+                ff += outputs
+        # L2-normalize the summed feature vectors
+        fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
+        ff = ff.div(fnorm.expand_as(ff))
+        features = torch.cat((features, ff.data.cpu()), 0)
+    return features
+
+
+def test_evaluate(result, device):
+    query_feature = torch.FloatTensor(result['query_f'])
+    query_cam = result['query_cam'][0]
+    query_label = result['query_label'][0]
+    gallery_feature = torch.FloatTensor(result['gallery_f'])
+    gallery_cam = result['gallery_cam'][0]
+    gallery_label = result['gallery_label'][0]
+    query_feature = query_feature.to(device)
+    gallery_feature = gallery_feature.to(device)
+    CMC = torch.IntTensor(len(gallery_label)).zero_()
+    ap = 0.0
+    for i in range(len(query_label)):
+        ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], query_cam[i], gallery_feature, gallery_label,
+                                   gallery_cam)
+        if CMC_tmp[0] == -1:
+            continue
+        CMC = CMC + CMC_tmp
+        ap += ap_tmp
+    CMC = CMC.float()
+    CMC = CMC / len(query_label)  # average CMC
+    return CMC[0], CMC[4], CMC[9], ap / len(query_label)  # Rank-1, Rank-5, Rank-10, mAP
+
+
+def evaluate(qf, ql, qc, gf, gl, gc):
+    query = qf.view(-1, 1)
+    score = torch.mm(gf, query)
+    score = score.squeeze(1).cpu()
+    score = score.numpy()
+    # predict index
+    index = np.argsort(score)  # from small to large
+    index = index[::-1]
+    # good index: same identity captured by a different camera
+    query_index = np.argwhere(gl == ql)
+    camera_index = np.argwhere(gc == qc)
+
+    good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
+    # junk index: distractor images (label -1) and same-identity images from the same camera
+    junk_index1 = np.argwhere(gl == -1)
+    junk_index2 = np.intersect1d(query_index, camera_index)
+    junk_index = np.append(junk_index2, junk_index1)
+
+    CMC_tmp = compute_mAP(index, good_index, junk_index)
+    return CMC_tmp
+
+
+def compute_mAP(index, good_index, junk_index):
+    ap = 0
+    cmc = torch.IntTensor(len(index)).zero_()
+    if good_index.size == 0:  # if empty
+        cmc[0] = -1
+        return ap, cmc
+
+    # remove junk_index
+    mask = np.in1d(index, junk_index, invert=True)
+    index = index[mask]
+
+    # positions of the good matches in the ranked list
+    ngood = len(good_index)
+    mask = np.in1d(index, good_index)
+    rows_good = np.argwhere(mask).flatten()
+
+    cmc[rows_good[0]:] = 1
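+    # average precision via trapezoidal interpolation of precision at each good match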
+    for i in range(ngood):
+        d_recall = 1.0 / ngood
+        precision = (i + 1) * 1.0 / (rows_good[i] + 1)
+        if rows_good[i] != 0:
+            old_precision = i * 1.0 / rows_good[i]
+        else:
+            old_precision = 1.0
+        ap = ap + d_recall * (old_precision + precision) / 2
+
+    return ap, cmc
+
+
+def fliplr(img):
+    inv_idx = torch.arange(img.size(3) - 1, -1, -1).long()  # N x C x H x W
+    img_flip = img.index_select(3, inv_idx)
+    return img_flip

BIN
applications/fedreid/images/datasets.png


BIN
applications/fedreid/images/fedpav.png


+ 60 - 0
applications/fedreid/main.py

@@ -0,0 +1,60 @@
+import argparse
+import logging
+import os
+
+import easyfl
+from client import FedReIDClient
+from dataset import prepare_train_data, prepare_test_data
+from easyfl.distributed import slurm
+from model import Model
+
+logger = logging.getLogger(__name__)
+
+
+def run():
+    parser = argparse.ArgumentParser(description='FedReID Application')
+    parser.add_argument('--task_id', type=str, default="")
+    parser.add_argument('--data_dir', type=str, metavar='PATH', default="datasets/fedreid")
+    parser.add_argument("--datasets", nargs="+", default=None, help="list of datasets, e.g., ['ilids']")
+    parser.add_argument('--test_every', type=int, default=10)
+    parser.add_argument("--gpu", type=int, default=1, help="default number of GPU")
+    args = parser.parse_args()
+    logger.info("arguments: ", args)
+
+    train_data = prepare_train_data(args.data_dir, args.datasets)
+    test_data = prepare_test_data(args.data_dir, args.datasets)
+    easyfl.register_dataset(train_data, test_data)
+    easyfl.register_model(Model)
+    easyfl.register_client(FedReIDClient)
+
+    config = {
+        "task_id": args.task_id,
+        "gpu": args.gpu,
+        "client": {
+            "test_every": args.test_every,
+        },
+        "server": {
+            "test_every": args.test_every
+        }
+    }
+    if args.gpu > 1:
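+        # Multi-GPU runs fetch rank and world size from the slurm environment.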
+        rank, local_rank, world_size, host_addr = slurm.setup()
+        distribute_config = {
+            "gpu": world_size,
+            "distributed": {
+                "rank": rank,
+                "local_rank": local_rank,
+                "world_size": world_size,
+                "init_method": host_addr
+            },
+        }
+        config.update(distribute_config)
+    config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.yaml")
+    config = easyfl.load_config(config_file, config)
+
+    easyfl.init(config)
+    easyfl.run()
+
+
+if __name__ == '__main__':
+    run()

+ 103 - 0
applications/fedreid/model.py

@@ -0,0 +1,103 @@
+import torch.nn as nn
+from torch.nn import init
+from torchvision import models
+
+from easyfl.models import BaseModel
+
+
+def weights_init_kaiming(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')  # For old pytorch, you may use kaiming_normal.
+    elif classname.find('Linear') != -1:
+        init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
+        init.constant_(m.bias.data, 0.0)
+    elif classname.find('BatchNorm1d') != -1:
+        init.normal_(m.weight.data, 1.0, 0.02)
+        init.constant_(m.bias.data, 0.0)
+
+
+def weights_init_classifier(m):
+    classname = m.__class__.__name__
+    if classname.find('Linear') != -1:
+        init.normal_(m.weight.data, std=0.001)
+        init.constant_(m.bias.data, 0.0)
+
+
+class ClassBlock(nn.Module):
+    def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, num_bottleneck=512, linear=True,
+                 return_f=False):
+        super(ClassBlock, self).__init__()
+        self.return_f = return_f
+        add_block = []
+        if linear:
+            add_block += [nn.Linear(input_dim, num_bottleneck)]
+        else:
+            num_bottleneck = input_dim
+        if bnorm:
+            add_block += [nn.BatchNorm1d(num_bottleneck)]
+        if relu:
+            add_block += [nn.LeakyReLU(0.1)]
+        if droprate > 0:
+            add_block += [nn.Dropout(p=droprate)]
+        add_block = nn.Sequential(*add_block)
+        add_block.apply(weights_init_kaiming)
+
+        classifier = []
+        classifier += [nn.Linear(num_bottleneck, class_num)]
+        classifier = nn.Sequential(*classifier)
+        classifier.apply(weights_init_classifier)
+
+        self.add_block = add_block
+        self.classifier = classifier
+
+    def forward(self, x):
+        x = self.add_block(x)
+        if self.return_f:
+            f = x
+            x = self.classifier(x)
+            return x, f
+        else:
+            x = self.classifier(x)
+            return x
+
+
+# Define the ResNet50-based Model
+class Model(BaseModel):
+
+    def __init__(self, class_num=0, droprate=0.5, stride=2):
+        super(Model, self).__init__()
+        model_ft = models.resnet50(pretrained=True)
+        self.class_num = class_num
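+        # stride=1 removes the final down-sampling in layer4 (the "last stride" trick) for higher-resolution features.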
+        if stride == 1:
+            model_ft.layer4[0].downsample[0].stride = (1, 1)
+            model_ft.layer4[0].conv2.stride = (1, 1)
+        model_ft.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.model = model_ft
+        if self.class_num != 0:
+            self.classifier = ClassBlock(2048, class_num, droprate)
+        else:
+            self.classifier = ClassBlock(2048, 10, droprate)  # 10 is not effective because classifier is replaced below
+            self.classifier.classifier = nn.Sequential()
+
+    def forward(self, x):
+        x = self.model.conv1(x)
+        x = self.model.bn1(x)
+        x = self.model.relu(x)
+        x = self.model.maxpool(x)
+        x = self.model.layer1(x)
+        x = self.model.layer2(x)
+        x = self.model.layer3(x)
+        x = self.model.layer4(x)
+        x = self.model.avgpool(x)
+        x = x.view(x.size(0), x.size(1))
+        x = self.classifier(x)
+        return x
+
+
+def get_classifier(class_num, num_bottleneck=512):
+    classifier = []
+    classifier += [nn.Linear(num_bottleneck, class_num)]
+    classifier = nn.Sequential(*classifier)
+    classifier.apply(weights_init_classifier)
+    return classifier

+ 9 - 0
applications/fedreid/remote_client.py

@@ -0,0 +1,9 @@
+import easyfl
+from client import FedReIDClient
+from dataset import prepare_train_data, prepare_test_data, DB_NAMES
+from model import Model
+
+data_dir = "datasets/fedreid"  # same default dataset directory as main.py
+train_data = prepare_train_data(data_dir, DB_NAMES)
+test_data = prepare_test_data(data_dir, DB_NAMES)
+
+easyfl.start_remote_client(train_data=train_data, test_data=test_data, client=FedReIDClient, model=Model)

+ 4 - 0
applications/fedreid/remote_server.py

@@ -0,0 +1,4 @@
+import easyfl
+from model import Model
+
+easyfl.start_remote_server(model=Model)

+ 12 - 2
docs/en/projects.md

@@ -1,7 +1,17 @@
 # Projects based on EasyFL
 
-We have built several projects based on EasyFL and published four papers in top-tier conferences and journals. 
-We list them as examples of how to extend EasyFL for your projects.
+We have been doing research on federated learning for several years and have published [several papers](https://weiming.me/#publications) in top-tier conferences and journals. EasyFL is built on deep insights from this research, and it has in turn helped us build several other federated learning projects.
+
+## Applications
+
+We have released the following implementations of federated learning applications:
+
+- FedReID: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/fedreid) for "Performance Optimization for Federated Person Re-identification via Benchmark Analysis", _ACMMM'2020_.
+
+
+## Papers
+
+The following projects and papers were built on EasyFL:
 
 - EasyFL: A Low-code Federated Learning Platform For Dummies, _IEEE Internet-of-Things Journal_. [[paper]](https://arxiv.org/abs/2105.07603)
 - Divergence-aware Federated Self-Supervised Learning, _ICLR'2022_. [[paper]](https://openreview.net/forum?id=oVE1z8NlNe)