main.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  import argparse
  import copy
  import logging
  import os
  import time

  import torch._utils
  import easyfl
  from dataset import prepare_train_data, prepare_test_data
  from reid.bottomup import *
  from reid.models.model import BUCModel
  from easyfl.client.base import BaseClient
  from easyfl.distributed import slurm
  from easyfl.distributed.distributed import CPU
  from easyfl.pb import common_pb2 as common_pb
  from easyfl.pb import server_service_pb2 as server_pb
  from easyfl.protocol import codec
  from easyfl.tracking import metric
  logger = logging.getLogger(__name__)

  # Checkpoint kinds accepted by FedUReIDClient.save_model(typ=...).
  LOCAL_TEST, GLOBAL_TEST = "local_test", "global_test"

  # Values of conf.buc.relabel (the --relabel flag): which model drives BUC re-labelling,
  # the locally trained one or the globally aggregated one.
  RELABEL_LOCAL, RELABEL_GLOBAL = "local", "global"
  class FedUReIDClient(BaseClient):
      """EasyFL client for federated unsupervised person re-identification (FedUReID).

      On construction the client replaces its ground-truth training labels with one
      pseudo-label per sample and delegates training, evaluation and re-labelling to a
      ``BottomUp`` (bottom-up clustering, BUC) helper. Each call to :meth:`train` runs one
      or more BUC steps before the model is uploaded for aggregation.
      """

      def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0,
                   is_remote=False, local_port=23000, server_addr="localhost:22999", tracker_addr="localhost:12666"):
          """Set up the BUC state for client ``cid``.

          Raises:
              ValueError: if ``conf.buc.merge_percent`` is not in (0, 1]. The step count is
                  ``int(1 / merge_percent) - 1``, which would otherwise divide by zero or
                  become negative.
          """
          super(FedUReIDClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time,
                                               is_remote, local_port, server_addr, tracker_addr)
          logger.info(conf)
          self.conf = conf
          # Index of the last finished BUC step; train() increments it first, so step 0 runs first.
          self.current_step = -1
          self._local_model = None  # for caching local model in testing
          self.gallery_cam = None
          self.gallery_label = None
          self.query_cam = None
          self.query_label = None
          self.test_gallery_loader = None
          self.test_query_loader = None
          self.train_data = train_data
          self.test_data = test_data
          # Keep the original labels; they are passed to BottomUp.relabel_train_data() on every relabel.
          self.labeled_ys = self.train_data.data[self.cid]['y']
          # Start unsupervised: every training sample gets its own pseudo-label.
          self.unlabeled_ys = list(range(len(self.labeled_ys)))
          self.train_data.data[self.cid]['y'] = self.unlabeled_ys
          num_classes = len(self.unlabeled_ys)  # all pseudo-labels are distinct at this point
          merge_percent = conf.buc.merge_percent
          if not 0 < merge_percent <= 1:
              raise ValueError("conf.buc.merge_percent must be in (0, 1], got {}".format(merge_percent))
          # Clusters merged per relabel, and the last step index after which no more merging happens.
          self.nums_to_merge = int(num_classes * merge_percent)
          self.steps = int(1 / merge_percent) - 1
          self.buc = BottomUp(cid=self.cid,
                              model=self.model,  # model is None
                              batch_size=conf.buc.batch_size,
                              eval_batch_size=conf.buc.eval_batch_size,
                              num_classes=num_classes,
                              train_data=self.train_data,
                              test_data=self.test_data,
                              device=device,
                              initial_epochs=conf.buc.initial_epochs,
                              local_epochs=conf.buc.local_epochs,
                              embedding_feature_size=conf.buc.feature_size,
                              seed=conf.seed)

      def train(self, conf, device=CPU):
          """Run local BUC steps until the next upload point.

          Runs steps ``current_step + 1`` .. ``current_step + conf.buc.upload_frequency``.
          The loop stops early once ``self.steps`` is reached, so the final merge is
          aggregated right away. Each step trains through BUC, copies the trained weights
          into ``self.model`` and logs a local evaluation.

          Relabelling happens either at the start of a step with the current (aggregated)
          model (``relabel == "global"``) or at the end of a step with the freshly trained
          model (``relabel == "local"``). In both modes at most ``self.steps`` relabels are
          performed in total.
          """
          logger.info("--------- training -------- cid: {}, on {}".format(self.cid, device))
          start_time = time.time()
          step_to_upload = self.current_step + conf.buc.upload_frequency
          total_steps = self.steps
          while self.current_step < step_to_upload:
              self.current_step += 1
              logger.info("current step: {}".format(self.current_step))
              logger.info("training transform amount: {}".format(len(self.unlabeled_ys)))
              # Global relabelling covers steps 1..total_steps, mirroring the local branch below
              # (steps 0..total_steps-1); past total_steps no further clusters are merged.
              if conf.buc.relabel == RELABEL_GLOBAL and 0 < self.current_step <= total_steps:
                  logger.info("-------- bottom-up clustering: relabel train transform with global aggregated model")
                  self._relabel_train_data(conf, device)
              self.buc.set_model(self.model, self.current_step)
              model = self.buc.train(self.current_step, conf.buc.dynamic_epoch)
              # Cache a copy of self.model before loading the weights returned by buc.train().
              self._local_model = copy.deepcopy(self.model)
              self.model.load_state_dict(model.state_dict())
              rank1, rank5, rank10, mAP = self.buc.evaluate(self.cid)
              logger.info("Local test {}, step {}, mAP: {:4.2%}, Rank@1: {:4.2%}, Rank@5: {:4.2%}, Rank@10: {:4.2%}"
                          .format(self.cid, self.current_step, mAP, rank1, rank5, rank10))
              if self.current_step == total_steps:
                  logger.info("Total steps just reached, force global update")
                  break
              # get new training labels for the next iteration
              if self.current_step > total_steps:
                  logger.info("Total steps reached, skip relabeling")
                  continue
              if conf.buc.relabel == RELABEL_LOCAL:
                  logger.info("-------- bottom-up clustering: relabel train transform with local trained model")
                  self._relabel_train_data(conf, device)
          self.save_model(LOCAL_TEST, device)
          self.current_round_time = time.time() - start_time
          logger.info("Local training time {}".format(self.current_round_time))
          self.track(metric.TRAIN_TIME, self.current_round_time)
          # save_model() moved the model to the CPU; move it back for the rest of the round.
          self.model = self.model.to(device)

      def _relabel_train_data(self, conf, device):
          """Ask BUC for new pseudo-labels and install them as this client's training labels."""
          self.unlabeled_ys = self.buc.relabel_train_data(device,
                                                          self.unlabeled_ys,
                                                          self.labeled_ys,
                                                          self.nums_to_merge,
                                                          size_penalty=conf.buc.size_penalty)
          self.train_data.data[self.cid]['y'] = self.unlabeled_ys

      def test(self, conf, device=CPU):
          """Optionally evaluate ``self.model``, checkpoint it, and stage a performance report.

          When ``conf.buc.global_evaluation`` is set, the model is evaluated through BUC,
          logged as a global test, and Rank@1 is reported as the accuracy. Otherwise an
          accuracy of 0 is reported. The report is stored in ``self._upload_holder``.
          """
          rank1 = 0
          if conf.buc.global_evaluation:
              logger.info("-------- evaluation -------- {}: {}".format(GLOBAL_TEST, self.cid))
              rank1, rank5, rank10, mAP = self.buc.evaluate(self.cid, self.model)
              logger.info("Global test {}, step {}, mAP: {:4.2%}, Rank@1: {:4.2%}, Rank@5: {:4.2%}, Rank@10: {:4.2%}"
                          .format(self.cid, self.current_step, mAP, rank1, rank5, rank10))
          # NOTE(review): the global checkpoint is written whether or not global evaluation ran
          # (save_model itself restricts it to device 0 / CPU); confirm this is intended.
          self.save_model(GLOBAL_TEST, device)
          self._upload_holder = server_pb.UploadContent(
              data=codec.marshal(server_pb.Performance(accuracy=rank1, loss=0)),  # loss not applicable
              type=common_pb.DATA_TYPE_PERFORMANCE,
              data_size=len(self.train_data.data[self.cid]['x']),
          )

      def save_model(self, typ=LOCAL_TEST, device=CPU):
          """Write ``self.model``'s state dict under ``<cwd>/saved_models``.

          Args:
              typ: ``GLOBAL_TEST`` writes ``<step>_global_model_<timestamp>.pth``, but only
                  when ``device`` is 0 or ``CPU`` (presumably so that a single process writes
                  the shared global model). Any other value writes
                  ``<step>_<cid>_local_model_<timestamp>.pth``.
              device: device the caller is running on.

          NOTE: ``self.model.cpu()`` moves the model to the CPU in place
          (``torch.nn.Module.cpu`` semantics). Callers that keep using the model on an
          accelerator must move it back, as ``train()`` does.
          """
          path = os.path.join(os.getcwd(), "saved_models")
          os.makedirs(path, exist_ok=True)
          if typ == GLOBAL_TEST:
              if device not in (0, CPU):
                  return  # nothing written, so nothing to log
              save_path = os.path.join(path, "{}_global_model_{}.pth".format(self.current_step, time.time()))
          else:
              save_path = os.path.join(path, "{}_{}_local_model_{}.pth".format(self.current_step, self.cid, time.time()))
          torch.save(self.model.cpu().state_dict(), save_path)
          logger.info("save model {}".format(save_path))
  def get_merge_percent(num_images, num_identities, rounds):
      """Split the gap between image count and identity count evenly across rounds.

      ``num_images - num_identities`` is divided by ``rounds``. This is presumably the
      number of merges needed to shrink one-cluster-per-image down to the true identity
      count.

      Args:
          num_images: number of training images (must be positive).
          num_identities: number of identities.
          rounds: number of rounds to spread the merges over (must be positive).

      Returns:
          tuple: ``(merge_percent, nums_to_merge)``. ``nums_to_merge`` is the per-round
          merge count, truncated toward zero. ``merge_percent`` is that count as a
          fraction of ``num_images``.

      Raises:
          ValueError: if ``rounds`` or ``num_images`` is not positive, instead of an
              opaque ZeroDivisionError or a negative schedule.
      """
      if rounds <= 0:
          raise ValueError("rounds must be positive, got {}".format(rounds))
      if num_images <= 0:
          raise ValueError("num_images must be positive, got {}".format(num_images))
      nums_to_merge = int((num_images - num_identities) / rounds)
      merge_percent = nums_to_merge / num_images
      return merge_percent, nums_to_merge
  if __name__ == '__main__':
      parser = argparse.ArgumentParser(description='')
      parser.add_argument('--data_dir', type=str, metavar='PATH', default="datasets/fedreid_data")
      parser.add_argument("--datasets", nargs="+", default=["ilids"])
      parser.add_argument('--batch_size', type=int, default=16, help='training batch size')
      parser.add_argument('--upload_frequency', type=int, default=1, help='frequency of upload for aggregation')
      parser.add_argument('--merge_percent', type=float, default=0.05, help='merge percentage of each step')
      parser.add_argument('--steps', type=int, default=0, help='steps to decide merge percent')
      parser.add_argument('--initial_epochs', type=int, default=20, help='local epochs for first step/round')
      parser.add_argument('--local_epochs', type=int, default=1, help='local epochs after first step/round')
      parser.add_argument('--dynamic_epoch', default=False, action='store_true', help='dynamic local epochs')
      # Only these two modes are handled by FedUReIDClient.train(); any other value would
      # silently disable relabelling.
      parser.add_argument('--relabel', type=str, default='local', choices=[RELABEL_LOCAL, RELABEL_GLOBAL],
                          help='use "local" or "global" model to relabel')
      # NOTE(review): --merge is parsed but not read anywhere in this file.
      parser.add_argument('--merge', default=False, action='store_true')
      args = parser.parse_args()
      # merge_percent feeds int(1 / merge_percent) below and in the client; reject values that
      # would divide by zero or yield a non-positive schedule.
      if not 0 < args.merge_percent <= 1:
          parser.error("--merge_percent must be in (0, 1], got {}".format(args.merge_percent))
      if args.steps < 0:
          parser.error("--steps must be non-negative, got {}".format(args.steps))
      # With upload_frequency < 1 the client's training loop would never run.
      if args.upload_frequency < 1:
          parser.error("--upload_frequency must be at least 1, got {}".format(args.upload_frequency))
      print("args:", args)
      # MAIN
      train_data = prepare_train_data(args.datasets, args.data_dir)
      test_data = prepare_test_data(args.datasets, args.data_dir)
      easyfl.register_dataset(train_data, test_data)
      easyfl.register_model(BUCModel)
      easyfl.register_client(FedUReIDClient)
      # configurations
      global_evaluation = False
      # --steps only overrides the number of server rounds; each client still derives its
      # own step count from merge_percent.
      if args.steps:
          rounds = args.steps
      else:
          rounds = int(1 / args.merge_percent)
      config = {
          "server": {
              "rounds": rounds,
          },
          "client": {
              "buc": {
                  "global_evaluation": global_evaluation,
                  "relabel": args.relabel,
                  "initial_epochs": args.initial_epochs,
                  "local_epochs": args.local_epochs,
                  "dynamic_epoch": args.dynamic_epoch,
                  "batch_size": args.batch_size,
                  "upload_frequency": args.upload_frequency,
                  "merge_percent": args.merge_percent,
                  "steps": args.steps,
              },
              "datasets": args.datasets,
          }
      }
      # For distributed training over multiple GPUs only
      try:
          rank, local_rank, world_size, host_addr = slurm.setup()
      except KeyError:
          # A KeyError from slurm.setup() is treated as a non-distributed run
          # (presumably missing SLURM environment variables).
          pass
      else:
          global_evaluation = world_size > 1
          config["client"]["buc"]["global_evaluation"] = global_evaluation
          distributed_config = {
              "gpu": world_size,
              "distributed": {
                  "rank": rank,
                  "local_rank": local_rank,
                  "world_size": world_size,
                  "init_method": host_addr,
                  "backend": "nccl",
              },
          }
          config.update(distributed_config)
      config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.yaml")
      config = easyfl.load_config(config_file, config)
      print("config:", config)
      easyfl.init(config, init_all=True)
      easyfl.run()