123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- import logging
- import os
- import time
- import numpy as np
- import torch
- import torch._utils
- import torch.nn as nn
- import torch.optim as optim
- from easyfl.client.base import BaseClient
- from easyfl.distributed.distributed import CPU
- from easyfl.pb import common_pb2 as common_pb
- from easyfl.pb import server_service_pb2 as server_pb
- from easyfl.protocol import codec
- from easyfl.tracking import metric
- from evaluate import test_evaluate, extract_feature
- from model import get_classifier
# Module-level logger for training and evaluation progress messages.
logger = logging.getLogger(name=__name__)
class FedReIDClient(BaseClient):
    """Federated person re-identification client with a client-local classifier.

    The identity classifier is created per client, sized from this client's
    training classes. It is attached to ``self.model.classifier.classifier``
    only while :meth:`train` runs. Outside of training that slot holds an
    empty ``nn.Sequential``. Presumably this is so the model emits embedding
    features for retrieval in :meth:`test` and keeps the classifier out of
    whatever is exchanged with the server. NOTE(review): the upload path lives
    in BaseClient and is not visible here; confirm.
    """

    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
        super().__init__(cid, conf, train_data, test_data, device, sleep_time)
        # Local classifier sized by len(train_data.classes[cid]). Presumably
        # this is the number of identities in this client's training split.
        self.classifier = get_classifier(len(self.train_data.classes[cid])).to(device)
        # Camera ids / identity labels parsed from test image file names by test().
        self.gallery_cam = None
        self.gallery_label = None
        self.query_cam = None
        self.query_label = None
        # Test loaders, created on the first call to test() and reused afterwards.
        self.test_gallery_loader = None
        self.test_query_loader = None

    def train(self, conf, device=CPU):
        """Run ``conf.local_epoch`` epochs of local training.

        The client's classifier is attached to the model for the duration of
        training and detached again afterwards. This also happens when
        training raises, so the model never keeps the classifier attached once
        this method returns.

        Args:
            conf: client configuration. ``conf.local_epoch`` is read here, and
                the whole object is passed on to ``pretrain_setup``.
            device: device the classifier and the training batches are moved to.

        Raises:
            ValueError: if an epoch yields no training batches, so no mean loss
                can be computed.
        """
        self.model.classifier.classifier = self.classifier.to(device)
        try:
            start_time = time.time()
            loss_fn, optimizer = self.pretrain_setup(conf, device)
            # Stepped once per local epoch, so the lr is multiplied by 0.1 every
            # 40 epochs of this call. The scheduler is rebuilt on every call.
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
            epoch_loss = []
            for epoch in range(conf.local_epoch):
                batch_loss = []
                for batched_x, batched_y in self.train_loader:
                    x, y = batched_x.to(device), batched_y.to(device)
                    optimizer.zero_grad()
                    out = self.model(x)
                    loss = loss_fn(out, y)
                    loss.backward()
                    optimizer.step()
                    batch_loss.append(loss.item())
                if not batch_loss:
                    # Without this check the mean below is a bare ZeroDivisionError.
                    raise ValueError("Client {} has no training batches".format(self.cid))
                scheduler.step()
                current_epoch_loss = sum(batch_loss) / len(batch_loss)
                epoch_loss.append(float(current_epoch_loss))
                logger.info("Client {}, local epoch: {}, loss: {}".format(self.cid, epoch, current_epoch_loss))
            self.current_round_time = time.time() - start_time
            self.track(metric.TRAIN_TIME, self.current_round_time)
            self.track(metric.TRAIN_LOSS, epoch_loss)
        finally:
            # Keep the (possibly partially) trained classifier locally and leave
            # an empty module in the model's classifier slot.
            self.classifier = self.model.classifier.classifier
            self.model.classifier.classifier = nn.Sequential()

    def load_optimizer(self, conf):
        """Build a Nesterov SGD optimizer with a 10x lower lr for non-classifier params.

        Parameters of ``self.model.classifier`` use ``conf.optimizer.lr``. All
        other model parameters use ``0.1 * conf.optimizer.lr``. Both groups
        share weight decay 5e-4 and ``conf.optimizer.momentum``.
        """
        classifier_params = list(self.model.classifier.parameters())
        # Set of ids: O(1) membership instead of scanning a list per parameter.
        classifier_ids = {id(p) for p in classifier_params}
        base_params = [p for p in self.model.parameters() if id(p) not in classifier_ids]
        return optim.SGD([
            {'params': base_params, 'lr': 0.1 * conf.optimizer.lr},
            {'params': classifier_params, 'lr': conf.optimizer.lr},
        ], weight_decay=5e-4, momentum=conf.optimizer.momentum, nesterov=True)

    def test(self, conf, device=CPU):
        """Evaluate retrieval on this client's query and gallery test splits.

        Features for the ``"<cid>_gallery"`` and ``"<cid>_query"`` splits of
        ``self.test_data`` are extracted with the current model and scored with
        ``test_evaluate``. The result is stored in ``self._upload_holder`` as a
        ``Performance`` message with ``accuracy`` set to Rank@1, ``loss`` set
        to 0, and ``data_size`` set to the number of query images.

        Args:
            conf: configuration. ``conf.seed`` is passed to the (unshuffled)
                test loaders.
            device: device used for feature extraction and evaluation.
        """
        self.model = self.model.eval()
        self.model = self.model.to(device)
        gallery_id = '{}_{}'.format(self.cid, 'gallery')
        query_id = '{}_{}'.format(self.cid, 'query')
        # Loaders are built once and reused by later evaluation rounds.
        if self.test_gallery_loader is None or self.test_query_loader is None:
            self.test_gallery_loader = self.test_data.loader(batch_size=128,
                                                             client_id=gallery_id,
                                                             shuffle=False,
                                                             seed=conf.seed)
            self.test_query_loader = self.test_data.loader(batch_size=128,
                                                           client_id=query_id,
                                                           shuffle=False,
                                                           seed=conf.seed)
        self.gallery_cam, self.gallery_label = self._get_id(self._split_records(gallery_id))
        self.query_cam, self.query_label = self._get_id(self._split_records(query_id))
        with torch.no_grad():
            gallery_feature = extract_feature(self.model, self.test_gallery_loader, device)
            query_feature = extract_feature(self.model, self.test_query_loader, device)
        # .cpu() is a no-op for CPU tensors and makes .numpy() safe for device
        # tensors. Labels/cams are wrapped in a list, giving arrays of shape (1, N).
        result = {
            'gallery_f': gallery_feature.cpu().numpy(),
            'gallery_label': np.array([self.gallery_label]),
            'gallery_cam': np.array([self.gallery_cam]),
            'query_f': query_feature.cpu().numpy(),
            'query_label': np.array([self.query_label]),
            'query_cam': np.array([self.query_cam]),
        }
        logger.info("Evaluating {}".format(self.cid))
        rank1, rank5, rank10, mAP = test_evaluate(result, device)
        logger.info("Dataset: {} Rank@1:{:.2%} Rank@5:{:.2%} Rank@10:{:.2%} mAP:{:.2%}".format(
            self.cid, rank1, rank5, rank10, mAP))
        self._upload_holder = server_pb.UploadContent(
            data=codec.marshal(server_pb.Performance(accuracy=rank1, loss=0)),
            type=common_pb.DATA_TYPE_PERFORMANCE,
            data_size=len(self.query_label),
        )

    def _split_records(self, split_id):
        """Return the ``(x, y)`` pairs of ``self.test_data.data[split_id]``."""
        split = self.test_data.data[split_id]
        return list(zip(split['x'], split['y']))

    @staticmethod
    def _leading_int(text):
        """Return the integer formed by the leading ASCII digits of ``text``.

        Raises:
            ValueError: if ``text`` does not start with a digit.
        """
        n_digits = len(text) - len(text.lstrip('0123456789'))
        if n_digits == 0:
            raise ValueError("no camera id digits in {!r}".format(text))
        return int(text[:n_digits])

    def _get_id(self, img_path):
        """Parse camera ids and identity labels from image file names.

        Two naming schemes are handled, based on the file's basename:

        * Names not starting with ``"cam"``, e.g. ``0002_c1s1_000451_03.jpg``.
          The label is the first four characters. The camera is the digit run
          after the first ``'c'``, ending at the first ``'s'``.
        * Names starting with ``"cam"``. These are split on ``'_'``: field 1
          is the camera and field 2 is the label.

        A label starting with ``"-1"`` is recorded as -1. NOTE(review):
        presumably these are junk/distractor images; confirm.

        The camera id is the whole leading digit run, so multi-digit cameras
        (e.g. ``c14s1``) are not truncated to their first digit.

        Args:
            img_path: iterable of ``(path, y)`` pairs. Only ``path`` is used.

        Returns:
            ``(camera_id, labels)``: two lists of ints, in input order.
        """
        camera_ids = []
        labels = []
        for path, _ in img_path:
            filename = os.path.basename(path)
            if filename[:3] != 'cam':
                label = filename[0:4]
                camera = filename.split('c')[1].split('s')[0]
            else:
                fields = filename.split('_')
                label = fields[2]
                camera = fields[1]
            labels.append(-1 if label[0:2] == '-1' else int(label))
            camera_ids.append(self._leading_int(camera))
        return camera_ids, labels
|