import logging
import os
import time

import numpy as np
import torch
import torch._utils
import torch.nn as nn
import torch.optim as optim

from easyfl.client.base import BaseClient
from easyfl.distributed.distributed import CPU
from easyfl.pb import common_pb2 as common_pb
from easyfl.pb import server_service_pb2 as server_pb
from easyfl.protocol import codec
from easyfl.tracking import metric
from evaluate import test_evaluate, extract_feature
from model import get_classifier

logger = logging.getLogger(__name__)


class FedReIDClient(BaseClient):
    """Federated re-identification client that keeps a client-local classifier head.

    Each client owns a classifier in ``self.classifier``, built by ``get_classifier``
    from the number of classes in this client's training split. The head is plugged
    into ``self.model.classifier.classifier`` only while ``train`` runs. Afterwards
    that slot is reset to an empty ``nn.Sequential``, so the model object carries no
    client-specific head outside of training. Presumably this is so that only the
    shared part of the model is exchanged with the server; confirm against the server.
    """

    # Step-decay schedule applied once per local epoch in ``train``.
    LR_STEP_SIZE = 40
    LR_GAMMA = 0.1
    # Learning rate of non-classifier parameters, relative to ``conf.optimizer.lr``.
    BACKBONE_LR_RATIO = 0.1
    WEIGHT_DECAY = 5e-4
    # Batch size of the gallery/query loaders built by ``test``.
    TEST_BATCH_SIZE = 128

    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
        super().__init__(cid, conf, train_data, test_data, device, sleep_time)
        # Client-local head, sized from the classes of this client's training split.
        self.classifier = get_classifier(len(self.train_data.classes[cid])).to(device)
        # Camera ids / identity labels parsed from test filenames (filled lazily by ``test``).
        self.gallery_cam = None
        self.gallery_label = None
        self.query_cam = None
        self.query_label = None
        # Test loaders, built on the first ``test`` call and reused afterwards.
        self.test_gallery_loader = None
        self.test_query_loader = None

    def train(self, conf, device=CPU):
        """Train locally for ``conf.local_epoch`` epochs with the client head attached.

        Tracks ``metric.TRAIN_TIME`` (seconds spent in the epochs) and
        ``metric.TRAIN_LOSS`` (list with the mean batch loss of each epoch).
        The trained head is saved back into ``self.classifier`` and detached from
        the model even if training raises.
        """
        # Attach the head before ``pretrain_setup`` runs. Presumably the optimizer is
        # built there (via ``load_optimizer``), so it must see the head's parameters.
        self.model.classifier.classifier = self.classifier.to(device)
        try:
            start_time = time.time()
            epoch_loss = self._train_epochs(conf, device)
            self.current_round_time = time.time() - start_time
        finally:
            # Keep the (possibly updated) head locally and strip it from the model.
            self.classifier = self.model.classifier.classifier
            self.model.classifier.classifier = nn.Sequential()
        self.track(metric.TRAIN_TIME, self.current_round_time)
        self.track(metric.TRAIN_LOSS, epoch_loss)

    def _train_epochs(self, conf, device):
        """Run the local epochs and return the mean loss of each epoch as floats."""
        loss_fn, optimizer = self.pretrain_setup(conf, device)
        # A new scheduler is created on every ``train`` call, so the decay restarts each call.
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.LR_STEP_SIZE, gamma=self.LR_GAMMA)
        epoch_loss = []
        for epoch in range(conf.local_epoch):
            batch_loss = []
            for batched_x, batched_y in self.train_loader:
                x, y = batched_x.to(device), batched_y.to(device)
                optimizer.zero_grad()
                loss = loss_fn(self.model(x), y)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            scheduler.step()
            if batch_loss:
                current_epoch_loss = sum(batch_loss) / len(batch_loss)
            else:
                # An empty loader used to crash with ZeroDivisionError. Record NaN
                # instead so the missing data stays visible in the tracked loss.
                logger.warning("Client {} has no training batches in local epoch {}".format(self.cid, epoch))
                current_epoch_loss = float("nan")
            epoch_loss.append(float(current_epoch_loss))
            logger.info("Client {}, local epoch: {}, loss: {}".format(self.cid, epoch, current_epoch_loss))
        return epoch_loss

    def load_optimizer(self, conf):
        """Build an SGD optimizer with a lower learning rate for non-classifier parameters.

        Parameters under ``self.model.classifier`` use ``conf.optimizer.lr``. All
        other model parameters use ``BACKBONE_LR_RATIO * conf.optimizer.lr``. Both
        groups share weight decay ``WEIGHT_DECAY``, ``conf.optimizer.momentum`` and
        Nesterov momentum.
        """
        classifier_params = list(self.model.classifier.parameters())
        # Parameters are matched by identity. A set gives O(1) membership checks.
        classifier_ids = {id(p) for p in classifier_params}
        base_params = [p for p in self.model.parameters() if id(p) not in classifier_ids]
        return optim.SGD([
            {'params': base_params, 'lr': self.BACKBONE_LR_RATIO * conf.optimizer.lr},
            {'params': classifier_params, 'lr': conf.optimizer.lr},
        ], weight_decay=self.WEIGHT_DECAY, momentum=conf.optimizer.momentum, nesterov=True)

    def test(self, conf, device=CPU):
        """Evaluate retrieval on this client's gallery/query test split.

        Extracts features for both splits without gradients and scores them with
        ``test_evaluate``. Logs Rank@1/5/10 and mAP, then stores Rank@1 as the
        reported accuracy in ``self._upload_holder``.
        """
        self.model = self.model.eval()
        self.model = self.model.to(device)
        self._prepare_test_data(conf)

        with torch.no_grad():
            gallery_feature = extract_feature(self.model, self.test_gallery_loader, device)
            query_feature = extract_feature(self.model, self.test_query_loader, device)

        # Labels/cams are wrapped in an extra list, giving arrays of shape (1, N);
        # NOTE(review): presumably the layout ``test_evaluate`` expects, confirm there.
        result = {
            'gallery_f': gallery_feature.numpy(),
            'gallery_label': np.array([self.gallery_label]),
            'gallery_cam': np.array([self.gallery_cam]),
            'query_f': query_feature.numpy(),
            'query_label': np.array([self.query_label]),
            'query_cam': np.array([self.query_cam]),
        }

        logger.info("Evaluating {}".format(self.cid))
        rank1, rank5, rank10, mAP = test_evaluate(result, device)
        logger.info("Dataset: {} Rank@1:{:.2%} Rank@5:{:.2%} Rank@10:{:.2%} mAP:{:.2%}".format(
            self.cid, rank1, rank5, rank10, mAP))

        self._upload_holder = server_pb.UploadContent(
            data=codec.marshal(server_pb.Performance(accuracy=rank1, loss=0)),  # loss not applicable
            type=common_pb.DATA_TYPE_PERFORMANCE,
            data_size=len(self.query_label),
        )

    def _prepare_test_data(self, conf):
        """Lazily build the gallery/query loaders and parse their camera ids and labels.

        Both are computed once from ``self.test_data`` and reused, so the parsed
        labels always describe the same data the cached loaders iterate.
        """
        gallery_id = '{}_{}'.format(self.cid, 'gallery')
        query_id = '{}_{}'.format(self.cid, 'query')
        if self.test_gallery_loader is None or self.test_query_loader is None:
            # shuffle=False keeps loader order aligned with the parsed label lists.
            self.test_gallery_loader = self.test_data.loader(
                batch_size=self.TEST_BATCH_SIZE, client_id=gallery_id, shuffle=False, seed=conf.seed)
            self.test_query_loader = self.test_data.loader(
                batch_size=self.TEST_BATCH_SIZE, client_id=query_id, shuffle=False, seed=conf.seed)
        if self.gallery_label is None or self.query_label is None:
            self.gallery_cam, self.gallery_label = self._get_id(self._image_records(gallery_id))
            self.query_cam, self.query_label = self._get_id(self._image_records(query_id))

    def _image_records(self, data_id):
        """Return ``(x, y)`` pairs from ``self.test_data.data[data_id]``."""
        entry = self.test_data.data[data_id]
        return list(zip(entry['x'], entry['y']))

    def _get_id(self, img_path):
        """Parse camera ids and identity labels from image filenames.

        Args:
            img_path: iterable of ``(path, y)`` pairs; only ``path`` is used.

        Returns:
            ``(camera_id, labels)``: two int lists aligned with ``img_path``. A label
            whose text starts with ``'-1'`` is mapped to ``-1``.

        Filename layouts:
          * Not starting with ``'cam'``: the label is the first four characters, and
            the camera id is the character right after the first ``'c'``. For example,
            ``0002_c1s1_000451_03.jpg`` gives label 2 and camera 1 (presumably
            Market-1501-style names).
          * Starting with ``'cam'``: ``'_'``-separated, with the camera in field 1
            and the label in field 2.

        Only the first character of the camera text is converted, so camera ids are
        assumed to be single digits. NOTE(review): confirm this for every dataset used.
        """
        camera_id = []
        labels = []
        for path, _ in img_path:
            filename = os.path.basename(path)
            if filename.startswith('cam'):
                fields = filename.split('_')
                label, camera = fields[2], fields[1]
            else:
                label = filename[0:4]
                camera = filename.split('c')[1].split('s')[0]
            labels.append(-1 if label.startswith('-1') else int(label))
            camera_id.append(int(camera[0]))
        return camera_id, labels