client.py

import logging
import os
import time

import numpy as np
import torch
import torch._utils
import torch.nn as nn
import torch.optim as optim

from easyfl.client.base import BaseClient
from easyfl.distributed.distributed import CPU
from easyfl.pb import common_pb2 as common_pb
from easyfl.pb import server_service_pb2 as server_pb
from easyfl.protocol import codec
from easyfl.tracking import metric

from evaluate import test_evaluate, extract_feature
from model import get_classifier

logger = logging.getLogger(__name__)


class FedReIDClient(BaseClient):
    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
        super(FedReIDClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
        # Each client keeps its own classification head, sized to the number of
        # identities in its local training partition.
        self.classifier = get_classifier(len(self.train_data.classes[cid])).to(device)
        self.gallery_cam = None
        self.gallery_label = None
        self.query_cam = None
        self.query_label = None
        self.test_gallery_loader = None
        self.test_query_loader = None

    def train(self, conf, device=CPU):
        # Attach this client's local classification head before local training.
        self.model.classifier.classifier = self.classifier.to(device)
        start_time = time.time()
        loss_fn, optimizer = self.pretrain_setup(conf, device)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
        epoch_loss = []
        for i in range(conf.local_epoch):
            batch_loss = []
            for batched_x, batched_y in self.train_loader:
                x, y = batched_x.to(device), batched_y.to(device)
                optimizer.zero_grad()
                out = self.model(x)
                loss = loss_fn(out, y)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            scheduler.step()
            current_epoch_loss = sum(batch_loss) / len(batch_loss)
            epoch_loss.append(float(current_epoch_loss))
            logger.info("Client {}, local epoch: {}, loss: {}".format(self.cid, i, current_epoch_loss))
        self.current_round_time = time.time() - start_time
        self.track(metric.TRAIN_TIME, self.current_round_time)
        self.track(metric.TRAIN_LOSS, epoch_loss)
        # Keep the trained head locally and detach it from the model so that only
        # the shared backbone is uploaded for aggregation.
        self.classifier = self.model.classifier.classifier
        self.model.classifier.classifier = nn.Sequential()

    def load_optimizer(self, conf):
        # Use a 10x smaller learning rate for the backbone than for the classifier head.
        ignored_params = list(map(id, self.model.classifier.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params, self.model.parameters())
        optimizer_ft = optim.SGD([
            {'params': base_params, 'lr': 0.1 * conf.optimizer.lr},
            {'params': self.model.classifier.parameters(), 'lr': conf.optimizer.lr},
        ], weight_decay=5e-4, momentum=conf.optimizer.momentum, nesterov=True)
        return optimizer_ft

    def test(self, conf, device=CPU):
        self.model = self.model.eval()
        self.model = self.model.to(device)

        gallery_id = '{}_{}'.format(self.cid, 'gallery')
        query_id = '{}_{}'.format(self.cid, 'query')
        # Build the gallery/query loaders and camera/label metadata once, then reuse them.
        if self.test_gallery_loader is None or self.test_query_loader is None:
            self.test_gallery_loader = self.test_data.loader(batch_size=128,
                                                             client_id=gallery_id,
                                                             shuffle=False,
                                                             seed=conf.seed)
            self.test_query_loader = self.test_data.loader(batch_size=128,
                                                           client_id=query_id,
                                                           shuffle=False,
                                                           seed=conf.seed)
            gallery_path = [(self.test_data.data[gallery_id]['x'][i],
                             self.test_data.data[gallery_id]['y'][i])
                            for i in range(len(self.test_data.data[gallery_id]['y']))]
            query_path = [(self.test_data.data[query_id]['x'][i],
                           self.test_data.data[query_id]['y'][i])
                          for i in range(len(self.test_data.data[query_id]['y']))]
            gallery_cam, gallery_label = self._get_id(gallery_path)
            self.gallery_cam = gallery_cam
            self.gallery_label = gallery_label
            query_cam, query_label = self._get_id(query_path)
            self.query_cam = query_cam
            self.query_label = query_label

        # Extract features for gallery and query images, then run re-ID evaluation.
        with torch.no_grad():
            gallery_feature = extract_feature(self.model,
                                              self.test_gallery_loader,
                                              device)
            query_feature = extract_feature(self.model,
                                            self.test_query_loader,
                                            device)

        result = {
            'gallery_f': gallery_feature.numpy(),
            'gallery_label': np.array([self.gallery_label]),
            'gallery_cam': np.array([self.gallery_cam]),
            'query_f': query_feature.numpy(),
            'query_label': np.array([self.query_label]),
            'query_cam': np.array([self.query_cam]),
        }
        logger.info("Evaluating {}".format(self.cid))
        rank1, rank5, rank10, mAP = test_evaluate(result, device)
        logger.info("Dataset: {} Rank@1:{:.2%} Rank@5:{:.2%} Rank@10:{:.2%} mAP:{:.2%}".format(
            self.cid, rank1, rank5, rank10, mAP))
        # Report rank-1 accuracy back to the server; loss is not applicable for re-ID evaluation.
        self._upload_holder = server_pb.UploadContent(
            data=codec.marshal(server_pb.Performance(accuracy=rank1, loss=0)),
            type=common_pb.DATA_TYPE_PERFORMANCE,
            data_size=len(self.query_label),
        )

    def _get_id(self, img_path):
        # Parse camera ids and person-id labels from re-ID image filenames.
        camera_id = []
        labels = []
        for p, v in img_path:
            filename = os.path.basename(p)
            if filename[:3] != 'cam':
                # Market-1501-style names (e.g. "0002_c1s1_000451_03.jpg"): the first four
                # characters are the person id; the digit after 'c' (before 's') is the camera id.
                label = filename[0:4]
                camera = filename.split('c')[1]
                camera = camera.split('s')[0]
            else:
                # Names starting with "cam" use underscore-separated fields:
                # field 1 is the camera id and field 2 is the person id.
                label = filename.split('_')[2]
                camera = filename.split('_')[1]
            if label[0:2] == '-1':
                # Distractor/junk images are labeled -1.
                labels.append(-1)
            else:
                labels.append(int(label))
            camera_id.append(int(camera[0]))
        return camera_id, labels
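

# --- Usage sketch (not part of the original file) ---
# A minimal sketch of how this client might be wired into an EasyFL run, assuming the
# standard easyfl registration API (register_dataset, register_model, register_client,
# init, run). `prepare_fedreid_datasets` and `build_model` are hypothetical helpers
# standing in for this application's data and model setup, and the config values are
# purely illustrative; consult the application's own entry script for the actual setup.
#
#   import easyfl
#
#   train_data, test_data = prepare_fedreid_datasets()   # hypothetical helper
#   easyfl.register_dataset(train_data, test_data)
#   easyfl.register_model(build_model())                  # hypothetical helper
#   easyfl.register_client(FedReIDClient)
#   easyfl.init(config)                                   # config dict/file assumed
#   easyfl.run()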