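"""Entry point for MAS federated multi-task training built on EasyFL.

Parses command-line options, assembles the EasyFL configuration, optionally
warm-starts from a pretrained checkpoint, and launches distributed training
under SLURM.
"""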
import argparse
import os

import torch

import easyfl
from easyfl.distributed import slurm

from client import MASClient
from server import MASServer
from dataset import get_dataset
from losses import parse_tasks
from models.model import get_model

STANDALONE_CLIENT_FOLDER = "standalone_clients"
DEFAULT_CLIENT_ID = "NA"  # sentinel: no standalone client requested


def construct_parser(parser):
    parser.add_argument("--task_id", type=str, default="")
    parser.add_argument('--tasks', default='s', help='which tasks to train; any subset of "sdnkt"')
    parser.add_argument('--task_groups', default='', help='groups of tasks separated by comma, e.g., "sd,nkt"')
    parser.add_argument("--dataset", type=str, default='taskonomy', help='dataset name')
    parser.add_argument("--arch", type=str, default='xception', help='model architecture')
    parser.add_argument('--data_dir', type=str, help='directory to load data from')
    parser.add_argument('--client_file', type=str, default='clients.txt', help='file listing the client ids')
    parser.add_argument('--client_id', type=str, default=DEFAULT_CLIENT_ID, help='client id for standalone training')
    parser.add_argument('--image_size', default=256, type=int, help='size of image side (images are square)')
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--local_epoch', default=5, type=int)
    parser.add_argument('--rounds', default=100, type=int)
    parser.add_argument('--num_of_clients', default=32, type=int)
    parser.add_argument('--clients_per_round', default=5, type=int)
    parser.add_argument('--optimizer_type', default='SGD', type=str, help='optimizer type')
    parser.add_argument('--random_selection', action='store_true', help='whether to select clients randomly')
    parser.add_argument('--lr', default=0.1, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('--lr_type', default="poly", type=str,
                        help='learning rate schedule type: poly or custom; custom lr requires a stateful client.')
    parser.add_argument('--minimum_learning_rate', default=3e-5, type=float,
                        metavar='LR', help='end training when the learning rate falls below this value')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--test_every', default=10, type=int, help='test every x rounds')
    parser.add_argument('--save_model_every', default=10, type=int, help='save model every x rounds')
    parser.add_argument("--aggregation_content", type=str, default="all", help="aggregation content")
    parser.add_argument("--aggregation_strategy", type=str, default="FedAvg", help="aggregation strategy")

    parser.add_argument('--lookahead', default='y', type=str, help='whether to use the lookahead optimizer (y/n)')
    parser.add_argument('--lookahead_step', default=10, type=int, help='synchronize lookahead weights every x steps')
    parser.add_argument('--num_workers', default=4, type=int, help='number of data loading workers (default: 4)')
    parser.add_argument('--rotate_loss', dest='rotate_loss', action='store_true', help='whether loss rotation should occur')
    parser.add_argument('--pretrained', default='n', help="filename of a pretrained checkpoint under saved_models/mas; 'n' disables")
    parser.add_argument('--pretrained_tasks', default='sdnkt', help='tasks of the pretrained model')
    parser.add_argument('--load_decoder', default='y', help='whether to load the pretrained decoders (y/n)')
    parser.add_argument('--fp16', action='store_true', help='run the model in fp16 mode')
    parser.add_argument('--half', default='n', help="'y' is shorthand for --half_sized_output")
    parser.add_argument('--half_sized_output', action='store_true', help='output 128x128 rather than 256x256')
    parser.add_argument('--no_augment', action='store_true', help='disable data augmentation')
    parser.add_argument('--model_limit', default=None, type=int,
                        help='limit the number of training instances from a single 3d building model')
    parser.add_argument('--task_weights', default=None, type=str,
                        help='comma-separated list of numbers, one per task, to multiply each loss by')
    parser.add_argument('-vb', '--virtual_batch_multiplier', default=1, type=int,
                        metavar='N', help='number of forward/backward passes per parameter update')
    parser.add_argument('--dist_port', default=23344, type=int)
    parser.add_argument('--run_count', default=0, type=int)
    # The argument below has no effect and is slated for removal.
    parser.add_argument('--maximum_loss_tracking_window', default=2000000, type=int,
                        help='maximum loss tracking window (default: 2000000)')
    return parser
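
# Example invocation (hypothetical script name and paths; adjust to your setup):
#   python main.py --tasks sd --data_dir /path/to/taskonomy \
#       --client_file clients.txt --num_of_clients 32 --clients_per_round 5
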
def run(args):
    rank, local_rank, world_size, host_addr = slurm.setup(args.dist_port)
    task_id = args.task_id
    if task_id == "":
        task_id = f"{args.arch}_{args.tasks}_{args.clients_per_round}c{args.num_of_clients}_run{args.run_count}"
    tasks = parse_tasks(args.tasks)
    config = {
        "task_id": task_id,
        "model": args.arch,
        "gpu": world_size,
        "distributed": {"rank": rank, "local_rank": local_rank, "world_size": world_size, "init_method": host_addr},
        "test_mode": "test_in_server",
        "server": {
            "batch_size": args.batch_size,
            "rounds": args.rounds,
            "test_every": args.test_every,
            "save_model_every": args.save_model_every,
            "clients_per_round": args.clients_per_round,
            "test_all": False,  # False: skip testing all clients at the start of training
            "random_selection": args.random_selection,
            "aggregation_content": args.aggregation_content,
            "aggregation_strategy": args.aggregation_strategy,
            "track": False,
        },
        "client": {
            "track": False,
            "drop_last": True,
            "batch_size": args.batch_size,
            "local_epoch": args.local_epoch,
            "rounds": args.rounds,
            "optimizer": {
                "type": args.optimizer_type,
                "lr_type": args.lr_type,
                "lr": args.lr,
                "momentum": args.momentum,
                "weight_decay": args.weight_decay,
            },
            "minimum_learning_rate": args.minimum_learning_rate,
            "tasks": tasks,
            "task_str": args.tasks,
            "task_weights": args.task_weights,
            "rotate_loss": args.rotate_loss,
            "lookahead": args.lookahead,
            "lookahead_step": args.lookahead_step,
            "num_workers": args.num_workers,
            "fp16": args.fp16,
            "virtual_batch_multiplier": args.virtual_batch_multiplier,
            "maximum_loss_tracking_window": args.maximum_loss_tracking_window,
        },
        "tracking": {"database": os.path.join(os.getcwd(), "tracker", task_id)},
    }
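    # Build the model for the selected tasks; optionally warm-start the encoder
    # (and, unless disabled, the decoders) from a pretrained multi-task checkpoint.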
    model = get_model(args.arch, tasks)
    if args.pretrained != "n":
        pretrained_tasks = parse_tasks(args.pretrained_tasks)
        pretrained_model = get_model(args.arch, pretrained_tasks)
        pretrained_path = os.path.join(os.getcwd(), "saved_models", "mas", args.pretrained)
        checkpoint = torch.load(pretrained_path)
        pretrained_model.load_state_dict(checkpoint['state_dict'])
        model.encoder.load_state_dict(pretrained_model.encoder.state_dict())
        if args.load_decoder != "n":
            print("loading pretrained decoders")
            # Match each decoder to the corresponding task in the pretrained model.
            pretrained_decoder_keys = list(pretrained_model.task_to_decoder.keys())
            for i, task in enumerate(model.task_to_decoder.keys()):
                pi = pretrained_decoder_keys.index(task)
                model.decoders[i].load_state_dict(pretrained_model.decoders[pi].state_dict())
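    # Standalone mode: when --client_id is given, train on that single client by
    # writing a one-line client file; the full client file is still used for testing.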
    client_file = args.client_file
    if args.client_id != DEFAULT_CLIENT_ID:
        client_file = f"{STANDALONE_CLIENT_FOLDER}/{args.client_id}.txt"
        with open(client_file, "w") as f:
            f.write(args.client_id)
    if args.half == 'y':
        args.half_sized_output = True
    augment = not args.no_augment
    print("train client file:", client_file)
    print("test client file:", args.client_file)
    train_data, val_data, test_data = get_dataset(args.data_dir,
                                                  train_client_file=client_file,
                                                  test_client_file=args.client_file,
                                                  tasks=tasks,
                                                  image_size=args.image_size,
                                                  model_limit=args.model_limit,
                                                  half_sized_output=args.half_sized_output,
                                                  augment=augment)

    easyfl.register_dataset(train_data, test_data, val_data)
    easyfl.register_model(model)
    easyfl.register_client(MASClient)
    easyfl.register_server(MASServer)
    easyfl.init(config, init_all=True)
    easyfl.run()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MAS')
    parser = construct_parser(parser)
    args = parser.parse_args()
    print("arguments: ", args)
    run(args)
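
# Distributed setup comes from easyfl.distributed.slurm, so this script is meant
# to be launched under SLURM, e.g. (hypothetical resource flags):
#   srun --ntasks=4 --gres=gpu:4 python main.py --tasks sdnkt --data_dir /data/taskonomy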