# main.py (7.6 KB)
  1. import argparse
  2. import easyfl
  3. from client import FedSSLClient
  4. from dataset import get_semi_supervised_dataset
  5. from easyfl.datasets.data import CIFAR100
  6. from easyfl.distributed import slurm
  7. from model import get_model, BYOLNoEMA, BYOL, BYOLNoSG, BYOLNoEMA_NoSG
  8. from server import FedSSLServer
  9. def run():
  10. parser = argparse.ArgumentParser(description='FedSSL')
  11. parser.add_argument("--task_id", type=str, default="")
  12. parser.add_argument("--dataset", type=str, default='cifar10', help='options: cifar10, cifar100')
  13. parser.add_argument("--data_partition", type=str, default='class', help='options: class, iid, dir')
  14. parser.add_argument("--dir_alpha", type=float, default=0.1, help='alpha for dirichlet sampling')
  15. parser.add_argument('--model', default='byol', type=str, help='options: byol, simsiam, simclr, moco, moco_v2')
  16. parser.add_argument('--encoder_network', default='resnet18', type=str,
  17. help='network architecture of encoder, options: resnet18, resnet50')
  18. parser.add_argument('--predictor_network', default='2_layer', type=str,
  19. help='network of predictor, options: 1_layer, 2_layer')
  20. parser.add_argument('--batch_size', default=128, type=int)
  21. parser.add_argument('--local_epoch', default=5, type=int)
  22. parser.add_argument('--rounds', default=100, type=int)
  23. parser.add_argument('--num_of_clients', default=5, type=int)
  24. parser.add_argument('--clients_per_round', default=5, type=int)
  25. parser.add_argument('--class_per_client', default=2, type=int,
  26. help='for non-IID setting, number of classes each client, based on CIFAR10')
  27. parser.add_argument('--optimizer_type', default='SGD', type=str, help='optimizer type')
  28. parser.add_argument('--lr', default=0.032, type=float)
  29. parser.add_argument('--lr_type', default='cosine', type=str, help='cosine decay learning rate')
  30. parser.add_argument('--random_selection', action='store_true', help='whether randomly select clients')
  31. parser.add_argument('--aggregate_encoder', default='online', type=str, help='options: online, target')
  32. parser.add_argument('--update_encoder', default='online', type=str, help='options: online, target, both, none')
  33. parser.add_argument('--update_predictor', default='global', type=str, help='options: global, local, dapu')
  34. parser.add_argument('--dapu_threshold', default=0.4, type=float, help='DAPU threshold value')
  35. parser.add_argument('--weight_scaler', default=1.0, type=float, help='weight scaler for different class per client')
  36. parser.add_argument('--auto_scaler', default='y', type=str, help='use value to compute auto scaler')
  37. parser.add_argument('--auto_scaler_target', default=0.8, type=float,
  38. help='target weight for the first time scaling')
  39. parser.add_argument('--encoder_weight', type=float, default=0,
  40. help='for ema encoder update, apply on local encoder')
  41. parser.add_argument('--predictor_weight', type=float, default=0,
  42. help='for ema predictor update, apply on local predictor')
  43. parser.add_argument('--test_every', default=10, type=int, help='test every x rounds')
  44. parser.add_argument('--save_model_every', default=10, type=int, help='save model every x rounds')
  45. parser.add_argument('--save_predictor', action='store_true', help='whether save predictor')
  46. parser.add_argument('--semi_supervised', action='store_true', help='whether to train with semi-supervised data')
  47. parser.add_argument('--label_ratio', default=0.01, type=float, help='percentage of labeled data')
  48. parser.add_argument('--gpu', default=0, type=int)
  49. parser.add_argument('--run_count', default=0, type=int)
  50. args = parser.parse_args()
  51. print("arguments: ", args)
  52. class_per_client = args.class_per_client
  53. if args.dataset == CIFAR100:
  54. class_per_client *= 10
  55. task_id = args.task_id
  56. if task_id == "":
  57. task_id = f"{args.dataset}_{args.model}_{args.encoder_network}_{args.data_partition}_" \
  58. f"aggregate_{args.aggregate_encoder}_update_{args.update_encoder}_predictor_{args.update_predictor}_" \
  59. f"run{args.run_count}"
  60. momentum_update = True
  61. if args.model == BYOLNoEMA:
  62. args.model = BYOL
  63. momentum_update = False
  64. elif args.model == BYOLNoEMA_NoSG:
  65. args.model = BYOLNoSG
  66. momentum_update = False
  67. image_size = 32
  68. config = {
  69. "task_id": task_id,
  70. "data": {
  71. "dataset": args.dataset,
  72. "num_of_clients": args.num_of_clients,
  73. "split_type": args.data_partition,
  74. "class_per_client": class_per_client,
  75. "data_amount": 1,
  76. "iid_fraction": 1,
  77. "min_size": 10,
  78. "alpha": args.dir_alpha,
  79. },
  80. "model": args.model,
  81. "test_mode": "test_in_server",
  82. "server": {
  83. "batch_size": args.batch_size,
  84. "rounds": args.rounds,
  85. "test_every": args.test_every,
  86. "save_model_every": args.save_model_every,
  87. "clients_per_round": args.clients_per_round,
  88. "random_selection": args.random_selection,
  89. "save_predictor": args.save_predictor,
  90. "test_all": True,
  91. },
  92. "client": {
  93. "drop_last": True,
  94. "batch_size": args.batch_size,
  95. "local_epoch": args.local_epoch,
  96. "optimizer": {
  97. "type": args.optimizer_type,
  98. "lr_type": args.lr_type,
  99. "lr": args.lr,
  100. "momentum": 0.9,
  101. "weight_decay": 0.0005,
  102. },
  103. # application specific
  104. "model": args.model,
  105. "rounds": args.rounds,
  106. "gaussian": False,
  107. "image_size": image_size,
  108. "aggregate_encoder": args.aggregate_encoder,
  109. "update_encoder": args.update_encoder,
  110. "update_predictor": args.update_predictor,
  111. "dapu_threshold": args.dapu_threshold,
  112. "weight_scaler": args.weight_scaler,
  113. "auto_scaler": args.auto_scaler,
  114. "auto_scaler_target": args.auto_scaler_target,
  115. "random_selection": args.random_selection,
  116. "encoder_weight": args.encoder_weight,
  117. "predictor_weight": args.predictor_weight,
  118. "momentum_update": momentum_update,
  119. },
  120. 'resource_heterogeneous': {"grouping_strategy": ""}
  121. }
  122. if args.gpu > 1:
  123. rank, local_rank, world_size, host_addr = slurm.setup()
  124. distribute_config = {
  125. "gpu": world_size,
  126. "distributed": {
  127. "rank": rank,
  128. "local_rank": local_rank,
  129. "world_size": world_size,
  130. "init_method": host_addr
  131. },
  132. }
  133. config.update(distribute_config)
  134. else:
  135. config["gpu"] = args.gpu
  136. if args.semi_supervised:
  137. train_data, test_data, _ = get_semi_supervised_dataset(args.dataset,
  138. args.num_of_clients,
  139. args.data_partition,
  140. class_per_client,
  141. args.label_ratio)
  142. easyfl.register_dataset(train_data, test_data)
  143. model = get_model(args.model, args.encoder_network, args.predictor_network)
  144. easyfl.register_model(model)
  145. easyfl.register_client(FedSSLClient)
  146. easyfl.register_server(FedSSLServer)
  147. easyfl.init(config, init_all=True)
  148. easyfl.run()
  149. if __name__ == '__main__':
  150. run()