"""Command-line entry point for the FedReID application built on EasyFL."""
import argparse
import logging
import os

import easyfl
from client import FedReIDClient
from dataset import prepare_train_data, prepare_test_data
from easyfl.distributed import slurm
from model import Model

logger = logging.getLogger(__name__)


def run():
    """Parse CLI arguments, register FedReID components, and start EasyFL.

    The function:

    1. Parses ``--task_id``, ``--data_dir``, ``--datasets``, ``--test_every``
       and ``--gpu`` from the command line.
    2. Builds train/test data via ``prepare_train_data`` /
       ``prepare_test_data`` and registers them, together with ``Model`` and
       ``FedReIDClient``, with EasyFL.
    3. Assembles an override config; when more than one GPU is requested it
       adds the distributed settings returned by ``slurm.setup()``.
    4. Loads ``config.yaml`` located next to this file, passing the override
       config to ``easyfl.load_config``, then calls ``easyfl.init`` and
       ``easyfl.run``.
    """
    parser = argparse.ArgumentParser(description='FedReID Application')
    parser.add_argument('--task_id', type=str, default="")
    parser.add_argument('--data_dir', type=str, metavar='PATH', default="datasets/fedreid")
    parser.add_argument("--datasets", nargs="+", default=None, help="list of datasets, e.g., ['ilids']")
    parser.add_argument('--test_every', type=int, default=10)
    parser.add_argument("--gpu", type=int, default=1, help="default number of GPU")
    args = parser.parse_args()

    # The message needs a %s placeholder: logging formats lazily with
    # ``msg % args``, and a placeholder-less message with an extra argument
    # makes the handler report a formatting error instead of logging.
    logger.info("arguments: %s", args)

    train_data = prepare_train_data(args.data_dir, args.datasets)
    test_data = prepare_test_data(args.data_dir, args.datasets)
    easyfl.register_dataset(train_data, test_data)
    easyfl.register_model(Model)
    easyfl.register_client(FedReIDClient)

    # Command-line overrides handed to easyfl.load_config below.
    # The same test interval is applied to both client and server.
    config = {
        "task_id": args.task_id,
        "gpu": args.gpu,
        "client": {
            "test_every": args.test_every,
        },
        "server": {
            "test_every": args.test_every,
        },
    }

    if args.gpu > 1:
        # Multi-GPU run: take rank/world-size information from the slurm
        # helper. Note that "gpu" is replaced by the reported world size
        # rather than the value given on the command line.
        rank, local_rank, world_size, host_addr = slurm.setup()
        distribute_config = {
            "gpu": world_size,
            "distributed": {
                "rank": rank,
                "local_rank": local_rank,
                "world_size": world_size,
                "init_method": host_addr,
            },
        }
        config.update(distribute_config)

    # config.yaml is resolved relative to this file (symlinks resolved), so
    # the script can be launched from any working directory.
    config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "config.yaml")
    # NOTE(review): presumably load_config layers `config` on top of the
    # values read from the YAML file -- confirm against the EasyFL docs.
    config = easyfl.load_config(config_file, config)

    easyfl.init(config)
    easyfl.run()


if __name__ == '__main__':
    run()