main.py

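"""Entry point for federated multi-task training with EasyFL.

Builds the argument parser, assembles the EasyFL configuration, optionally
initializes the model from a pretrained checkpoint, registers the dataset,
model, MASClient, and MASServer, and launches training.
"""
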
import argparse
import os

import torch

import easyfl
from easyfl.distributed import slurm

from client import MASClient
from server import MASServer
from dataset import get_dataset
from losses import parse_tasks
from models.model import get_model

STANDALONE_CLIENT_FOLDER = "standalone_clients"
DEFAULT_CLIENT_ID = "NA"


def construct_parser(parser):
    parser.add_argument("--task_id", type=str, default="")
    parser.add_argument('--tasks', default='s', help='which tasks to train, options: sdnkt')
    parser.add_argument('--task_groups', default='', help='groups of tasks separated by comma, e.g., "sd,nkt"')
    parser.add_argument("--dataset", type=str, default='taskonomy', help='dataset name')
    parser.add_argument("--arch", type=str, default='xception', help='model architecture')
    parser.add_argument('--data_dir', type=str, help='directory to load data from')
    parser.add_argument('--client_file', type=str, default='clients.txt', help='path to the file listing client ids')
    parser.add_argument('--client_id', type=str, default=DEFAULT_CLIENT_ID, help='client id for standalone training')
    parser.add_argument('--image_size', default=256, type=int, help='size of image side (images are square)')
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--local_epoch', default=5, type=int)
    parser.add_argument('--rounds', default=100, type=int)
    parser.add_argument('--num_of_clients', default=32, type=int)
    parser.add_argument('--clients_per_round', default=5, type=int)
    parser.add_argument('--optimizer_type', default='SGD', type=str, help='optimizer type')
    parser.add_argument('--random_selection', action='store_true', help='whether to randomly select clients')
    parser.add_argument('--lr', default=0.1, type=float, metavar='LR', help='initial learning rate')
    parser.add_argument('--lr_type', default="poly", type=str,
                        help='learning rate schedule type: poly or custom; custom lr requires a stateful client')
    parser.add_argument('--minimum_learning_rate', default=3e-5, type=float,
                        metavar='LR', help='end training when the learning rate falls below this value')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--test_every', default=10, type=int, help='test every x rounds')
    parser.add_argument('--save_model_every', default=10, type=int, help='save model every x rounds')
    parser.add_argument("--aggregation_content", type=str, default="all", help="aggregation content")
    parser.add_argument("--aggregation_strategy", type=str, default="FedAvg", help="aggregation strategy")
    parser.add_argument('--lookahead', default='y', type=str, help='whether to use the Lookahead optimizer')
    parser.add_argument('--lookahead_step', default=10, type=int, help='lookahead every x steps')
    parser.add_argument('--num_workers', default=4, type=int, help='number of data loading workers (default: 4)')
    parser.add_argument('--rotate_loss', dest='rotate_loss', action='store_true', help='whether loss rotation should occur')
    parser.add_argument('--pretrained', default='n', help='filename of a pretrained checkpoint, or "n" to train from scratch')
    parser.add_argument('--pretrained_tasks', default='sdnkt', help='tasks of the pretrained model')
    parser.add_argument('--load_decoder', default='y', help='whether to load the pretrained decoders')
    parser.add_argument('--fp16', action='store_true', help='run model in fp16 mode')
    parser.add_argument('--half', default='n', help='whether to use half-sized output ("y" implies --half_sized_output)')
    parser.add_argument('--half_sized_output', action='store_true', help='output 128x128 rather than 256x256')
    parser.add_argument('--no_augment', action='store_true', help='disable training data augmentation')
    parser.add_argument('--model_limit', default=None, type=int,
                        help='limit the number of training instances from a single 3d building model')
    parser.add_argument('--task_weights', default=None, type=str,
                        help='a comma-separated list of numbers, one per task, to multiply the loss by')
    parser.add_argument('-vb', '--virtual_batch_multiplier', default=1, type=int,
                        metavar='N', help='number of forward/backward passes per parameter update')
    parser.add_argument('--dist_port', default=23344, type=int)
    parser.add_argument('--run_count', default=0, type=int)
    # The argument below has no effect; to be deleted.
    parser.add_argument('--maximum_loss_tracking_window', default=2000000, type=int,
                        help='maximum loss tracking window (default: 2000000)')
    return parser
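

# Distributed settings (rank, world size, init address) for run() come from
# slurm.setup(), which appears to derive them from the SLURM environment (an
# assumption based on the easyfl.distributed.slurm import); see the example
# launch command at the bottom of this file.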


def run(args):
    rank, local_rank, world_size, host_addr = slurm.setup(args.dist_port)
    task_id = args.task_id
    if task_id == "":
        task_id = f"{args.arch}_{args.tasks}_{args.clients_per_round}c{args.num_of_clients}_run{args.run_count}"
    tasks = parse_tasks(args.tasks)
    config = {
        "task_id": task_id,
        "model": args.arch,
        "gpu": world_size,
        "distributed": {"rank": rank, "local_rank": local_rank, "world_size": world_size, "init_method": host_addr},
        "test_mode": "test_in_server",
        "server": {
            "batch_size": args.batch_size,
            "rounds": args.rounds,
            "test_every": args.test_every,
            "save_model_every": args.save_model_every,
            "clients_per_round": args.clients_per_round,
            "test_all": False,  # False means do not test clients at the start of training
            "random_selection": args.random_selection,
            "aggregation_content": args.aggregation_content,
            "aggregation_strategy": args.aggregation_strategy,
            "track": False,
        },
        "client": {
            "track": False,
            "drop_last": True,
            "batch_size": args.batch_size,
            "local_epoch": args.local_epoch,
            "rounds": args.rounds,
            "optimizer": {
                "type": args.optimizer_type,
                "lr_type": args.lr_type,
                "lr": args.lr,
                "momentum": args.momentum,
                "weight_decay": args.weight_decay,
            },
            "minimum_learning_rate": args.minimum_learning_rate,
            "tasks": tasks,
            "task_str": args.tasks,
            "task_weights": args.task_weights,
            "rotate_loss": args.rotate_loss,
            "lookahead": args.lookahead,
            "lookahead_step": args.lookahead_step,
            "num_workers": args.num_workers,
            "fp16": args.fp16,
            "virtual_batch_multiplier": args.virtual_batch_multiplier,
            "maximum_loss_tracking_window": args.maximum_loss_tracking_window,
        },
        "tracking": {"database": os.path.join(os.getcwd(), "tracker", task_id)},
    }
    model = get_model(args.arch, tasks)
    if args.pretrained != "n":
        # Initialize from a pretrained checkpoint: always load the encoder,
        # and optionally the task decoders.
        pretrained_tasks = parse_tasks(args.pretrained_tasks)
        pretrained_model = get_model(args.arch, pretrained_tasks)
        pretrained_path = os.path.join(os.getcwd(), "saved_models", "mas", args.pretrained)
        checkpoint = torch.load(pretrained_path)
        pretrained_model.load_state_dict(checkpoint['state_dict'])
        model.encoder.load_state_dict(pretrained_model.encoder.state_dict())
        if args.load_decoder != "n":
            print("load decoder")
            # Match decoders by task, since task order may differ between the
            # pretrained model and the current one.
            pretrained_decoder_keys = list(pretrained_model.task_to_decoder.keys())
            for i, task in enumerate(model.task_to_decoder.keys()):
                pi = pretrained_decoder_keys.index(task)
                model.decoders[i].load_state_dict(pretrained_model.decoders[pi].state_dict())
    augment = not args.no_augment
    client_file = args.client_file
    if args.client_id != DEFAULT_CLIENT_ID:
        # Standalone training: write a single-client file containing only this client id.
        client_file = f"{STANDALONE_CLIENT_FOLDER}/{args.client_id}.txt"
        with open(client_file, "w") as f:
            f.write(args.client_id)
    if args.half == 'y':
        args.half_sized_output = True
    print("train client file:", client_file)
    print("test client file:", args.client_file)
    train_data, val_data, test_data = get_dataset(
        args.data_dir,
        train_client_file=client_file,
        test_client_file=args.client_file,
        tasks=tasks,
        image_size=args.image_size,
        model_limit=args.model_limit,
        half_sized_output=args.half_sized_output,
        augment=augment,
    )
    easyfl.register_dataset(train_data, test_data, val_data)
    easyfl.register_model(model)
    easyfl.register_client(MASClient)
    easyfl.register_server(MASServer)
    easyfl.init(config, init_all=True)
    easyfl.run()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MAS')
    parser = construct_parser(parser)
    args = parser.parse_args()
    print("arguments: ", args)
    run(args)
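
# Example launch (a sketch, not taken from the original repo: the flags shown
# are the defaults defined above, the data path is a placeholder, and the
# srun/SLURM invocation is an assumption based on the slurm.setup() call):
#
#   srun -n 1 python main.py \
#       --tasks sd \
#       --data_dir /path/to/taskonomy \
#       --num_of_clients 32 \
#       --clients_per_round 5 \
#       --batch_size 64 \
#       --rounds 100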