main.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import argparse
  2. import logging
  3. import os
  4. import easyfl
  5. from client import FedReIDClient
  6. from dataset import prepare_train_data, prepare_test_data
  7. from easyfl.distributed import slurm
  8. from model import Model
  9. logger = logging.getLogger(__name__)
  10. def run():
  11. parser = argparse.ArgumentParser(description='FedReID Application')
  12. parser.add_argument('--task_id', type=str, default="")
  13. parser.add_argument('--data_dir', type=str, metavar='PATH', default="datasets/fedreid")
  14. parser.add_argument("--datasets", nargs="+", default=None, help="list of datasets, e.g., ['ilids']")
  15. parser.add_argument('--test_every', type=int, default=10)
  16. parser.add_argument("--gpu", type=int, default=1, help="default number of GPU")
  17. args = parser.parse_args()
  18. logger.info("arguments: ", args)
  19. train_data = prepare_train_data(args.data_dir, args.datasets)
  20. test_data = prepare_test_data(args.data_dir, args.datasets)
  21. easyfl.register_dataset(train_data, test_data)
  22. easyfl.register_model(Model)
  23. easyfl.register_client(FedReIDClient)
  24. config = {
  25. "task_id": args.task_id,
  26. "gpu": args.gpu,
  27. "client": {
  28. "test_every": args.test_every,
  29. },
  30. "server": {
  31. "test_every": args.test_every
  32. }
  33. }
  34. if args.gpu > 1:
  35. rank, local_rank, world_size, host_addr = slurm.setup()
  36. distribute_config = {
  37. "gpu": world_size,
  38. "distributed": {
  39. "rank": rank,
  40. "local_rank": local_rank,
  41. "world_size": world_size,
  42. "init_method": host_addr
  43. },
  44. }
  45. config.update(distribute_config)
  46. config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.yaml")
  47. config = easyfl.load_config(config_file, config)
  48. easyfl.init(config)
  49. easyfl.run()
  50. if __name__ == '__main__':
  51. run()