|
@@ -0,0 +1,139 @@
|
|
|
+import logging
|
|
|
+import os
|
|
|
+import time
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+import torch
|
|
|
+import torch._utils
|
|
|
+import torch.nn as nn
|
|
|
+import torch.optim as optim
|
|
|
+
|
|
|
+from easyfl.client.base import BaseClient
|
|
|
+from easyfl.distributed.distributed import CPU
|
|
|
+from easyfl.pb import common_pb2 as common_pb
|
|
|
+from easyfl.pb import server_service_pb2 as server_pb
|
|
|
+from easyfl.protocol import codec
|
|
|
+from easyfl.tracking import metric
|
|
|
+from evaluate import test_evaluate, extract_feature
|
|
|
+from model import get_classifier
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+class FedReIDClient(BaseClient):
|
|
|
+ def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
|
|
|
+ super(FedReIDClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
|
|
|
+ self.classifier = get_classifier(len(self.train_data.classes[cid])).to(device)
|
|
|
+ self.gallery_cam = None
|
|
|
+ self.gallery_label = None
|
|
|
+ self.query_cam = None
|
|
|
+ self.query_label = None
|
|
|
+ self.test_gallery_loader = None
|
|
|
+ self.test_query_loader = None
|
|
|
+
|
|
|
+ def train(self, conf, device=CPU):
|
|
|
+ self.model.classifier.classifier = self.classifier.to(device)
|
|
|
+ start_time = time.time()
|
|
|
+ loss_fn, optimizer = self.pretrain_setup(conf, device)
|
|
|
+ scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
|
|
|
+ epoch_loss = []
|
|
|
+ for i in range(conf.local_epoch):
|
|
|
+ batch_loss = []
|
|
|
+ for batched_x, batched_y in self.train_loader:
|
|
|
+ x, y = batched_x.to(device), batched_y.to(device)
|
|
|
+ optimizer.zero_grad()
|
|
|
+ out = self.model(x)
|
|
|
+ loss = loss_fn(out, y)
|
|
|
+ loss.backward()
|
|
|
+ optimizer.step()
|
|
|
+ batch_loss.append(loss.item())
|
|
|
+ scheduler.step()
|
|
|
+ current_epoch_loss = sum(batch_loss) / len(batch_loss)
|
|
|
+ epoch_loss.append(float(current_epoch_loss))
|
|
|
+ logger.info("Client {}, local epoch: {}, loss: {}".format(self.cid, i, current_epoch_loss))
|
|
|
+ self.current_round_time = time.time() - start_time
|
|
|
+ self.track(metric.TRAIN_TIME, self.current_round_time)
|
|
|
+ self.track(metric.TRAIN_LOSS, epoch_loss)
|
|
|
+ self.classifier = self.model.classifier.classifier
|
|
|
+ self.model.classifier.classifier = nn.Sequential()
|
|
|
+
|
|
|
+ def load_optimizer(self, conf):
|
|
|
+ ignored_params = list(map(id, self.model.classifier.parameters()))
|
|
|
+ base_params = filter(lambda p: id(p) not in ignored_params, self.model.parameters())
|
|
|
+ optimizer_ft = optim.SGD([
|
|
|
+ {'params': base_params, 'lr': 0.1 * conf.optimizer.lr},
|
|
|
+ {'params': self.model.classifier.parameters(), 'lr': conf.optimizer.lr}
|
|
|
+ ], weight_decay=5e-4, momentum=conf.optimizer.momentum, nesterov=True)
|
|
|
+ return optimizer_ft
|
|
|
+
|
|
|
+ def test(self, conf, device=CPU):
|
|
|
+ self.model = self.model.eval()
|
|
|
+ self.model = self.model.to(device)
|
|
|
+ gallery_id = '{}_{}'.format(self.cid, 'gallery')
|
|
|
+ query_id = '{}_{}'.format(self.cid, 'query')
|
|
|
+ if self.test_gallery_loader is None or self.test_query_loader is None:
|
|
|
+ self.test_gallery_loader = self.test_data.loader(batch_size=128,
|
|
|
+ client_id=gallery_id,
|
|
|
+ shuffle=False,
|
|
|
+ seed=conf.seed)
|
|
|
+ self.test_query_loader = self.test_data.loader(batch_size=128,
|
|
|
+ client_id=query_id,
|
|
|
+ shuffle=False,
|
|
|
+ seed=conf.seed)
|
|
|
+ gallery_path = [(self.test_data.data[gallery_id]['x'][i],
|
|
|
+ self.test_data.data[gallery_id]['y'][i])
|
|
|
+ for i in range(len(self.test_data.data[gallery_id]['y']))]
|
|
|
+ query_path = [(self.test_data.data[query_id]['x'][i],
|
|
|
+ self.test_data.data[query_id]['y'][i])
|
|
|
+ for i in range(len(self.test_data.data[query_id]['y']))]
|
|
|
+ gallery_cam, gallery_label = self._get_id(gallery_path)
|
|
|
+ self.gallery_cam = gallery_cam
|
|
|
+ self.gallery_label = gallery_label
|
|
|
+ query_cam, query_label = self._get_id(query_path)
|
|
|
+ self.query_cam = query_cam
|
|
|
+ self.query_label = query_label
|
|
|
+ with torch.no_grad():
|
|
|
+ gallery_feature = extract_feature(self.model,
|
|
|
+ self.test_gallery_loader,
|
|
|
+ device)
|
|
|
+ query_feature = extract_feature(self.model,
|
|
|
+ self.test_query_loader,
|
|
|
+ device)
|
|
|
+
|
|
|
+ result = {
|
|
|
+ 'gallery_f': gallery_feature.numpy(),
|
|
|
+ 'gallery_label': np.array([self.gallery_label]),
|
|
|
+ 'gallery_cam': np.array([self.gallery_cam]),
|
|
|
+ 'query_f': query_feature.numpy(),
|
|
|
+ 'query_label': np.array([self.query_label]),
|
|
|
+ 'query_cam': np.array([self.query_cam]),
|
|
|
+ }
|
|
|
+
|
|
|
+ logger.info("Evaluating {}".format(self.cid))
|
|
|
+ rank1, rank5, rank10, mAP = test_evaluate(result, device)
|
|
|
+ logger.info("Dataset: {} Rank@1:{:.2%} Rank@5:{:.2%} Rank@10:{:.2%} mAP:{:.2%}".format(
|
|
|
+ self.cid, rank1, rank5, rank10, mAP))
|
|
|
+ self._upload_holder = server_pb.UploadContent(
|
|
|
+ data=codec.marshal(server_pb.Performance(accuracy=rank1, loss=0)), # loss not applicable
|
|
|
+ type=common_pb.DATA_TYPE_PERFORMANCE,
|
|
|
+ data_size=len(self.query_label),
|
|
|
+ )
|
|
|
+
|
|
|
+ def _get_id(self, img_path):
|
|
|
+ camera_id = []
|
|
|
+ labels = []
|
|
|
+ for p, v in img_path:
|
|
|
+ filename = os.path.basename(p)
|
|
|
+ if filename[:3] != 'cam':
|
|
|
+ label = filename[0:4]
|
|
|
+ camera = filename.split('c')[1]
|
|
|
+ camera = camera.split('s')[0]
|
|
|
+ else:
|
|
|
+ label = filename.split('_')[2]
|
|
|
+ camera = filename.split('_')[1]
|
|
|
+ if label[0:2] == '-1':
|
|
|
+ labels.append(-1)
|
|
|
+ else:
|
|
|
+ labels.append(int(label))
|
|
|
+ camera_id.append(int(camera[0]))
|
|
|
+ return camera_id, labels
|