# main.py

  1. import argparse
  2. from client_with_pgfed import FedSSLWithPgFedClient
  3. from server_with_pgfed import FedSSLWithPgFedServer
  4. import easyfl
  5. from client import FedSSLClient
  6. from dataset import get_semi_supervised_dataset
  7. from easyfl.datasets.data import CIFAR100
  8. from easyfl.distributed import slurm
  9. from model import get_model, BYOLNoEMA, BYOL, BYOLNoSG, BYOLNoEMA_NoSG
  10. from server import FedSSLServer
  11. def run():
  12. parser = argparse.ArgumentParser(description='FedSSL')
  13. parser.add_argument("--task_id", type=str, default="")
  14. parser.add_argument("--dataset", type=str, default='cifar10', help='options: cifar10, cifar100')
  15. parser.add_argument("--data_partition", type=str, default='class', help='options: class, iid, dir')
  16. parser.add_argument("--dir_alpha", type=float, default=0.1, help='alpha for dirichlet sampling')
  17. parser.add_argument('--model', default='byol', type=str, help='options: byol, simsiam, simclr, moco, moco_v2')
  18. parser.add_argument('--encoder_network', default='resnet18', type=str,
  19. help='network architecture of encoder, options: resnet18, resnet50')
  20. parser.add_argument('--predictor_network', default='2_layer', type=str,
  21. help='network of predictor, options: 1_layer, 2_layer')
  22. parser.add_argument('--batch_size', default=128, type=int)
  23. parser.add_argument('--local_epoch', default=5, type=int)
  24. parser.add_argument('--rounds', default=100, type=int)
  25. parser.add_argument('--num_of_clients', default=5, type=int)
  26. parser.add_argument('--clients_per_round', default=5, type=int)
  27. parser.add_argument('--class_per_client', default=2, type=int,
  28. help='for non-IID setting, number of classes each client, based on CIFAR10')
  29. parser.add_argument('--optimizer_type', default='SGD', type=str, help='optimizer type')
  30. parser.add_argument('--lr', default=0.032, type=float)
  31. parser.add_argument('--lr_type', default='cosine', type=str, help='cosine decay learning rate')
  32. parser.add_argument('--random_selection', action='store_true', help='whether randomly select clients')
  33. parser.add_argument('--aggregate_encoder', default='online', type=str, help='options: online, target')
  34. parser.add_argument('--update_encoder', default='online', type=str, help='options: online, target, both, none')
  35. parser.add_argument('--update_predictor', default='global', type=str, help='options: global, local, dapu')
  36. parser.add_argument('--dapu_threshold', default=0.4, type=float, help='DAPU threshold value')
  37. parser.add_argument('--weight_scaler', default=1.0, type=float, help='weight scaler for different class per client')
  38. parser.add_argument('--auto_scaler', default='y', type=str, help='use value to compute auto scaler')
  39. parser.add_argument('--auto_scaler_target', default=0.8, type=float,
  40. help='target weight for the first time scaling')
  41. parser.add_argument('--encoder_weight', type=float, default=0,
  42. help='for ema encoder update, apply on local encoder')
  43. parser.add_argument('--predictor_weight', type=float, default=0,
  44. help='for ema predictor update, apply on local predictor')
  45. parser.add_argument('--test_every', default=10, type=int, help='test every x rounds')
  46. parser.add_argument('--save_model_every', default=10, type=int, help='save model every x rounds')
  47. parser.add_argument('--save_predictor', action='store_true', help='whether save predictor')
  48. parser.add_argument('--semi_supervised', action='store_true', help='whether to train with semi-supervised data')
  49. parser.add_argument('--label_ratio', default=0.01, type=float, help='percentage of labeled data')
  50. parser.add_argument('--gpu', default=0, type=int)
  51. parser.add_argument('--run_count', default=0, type=int)
  52. parser.add_argument('--use_pgfed', default=False, type=bool)
  53. args = parser.parse_args()
  54. print("arguments: ", args)
  55. class_per_client = args.class_per_client
  56. if args.dataset == CIFAR100:
  57. class_per_client *= 10
  58. task_id = args.task_id
  59. if task_id == "":
  60. task_id = f"{args.dataset}_{args.model}_{args.encoder_network}_{args.data_partition}_" \
  61. f"aggregate_{args.aggregate_encoder}_update_{args.update_encoder}_predictor_{args.update_predictor}_" \
  62. f"run{args.run_count}"
  63. momentum_update = True
  64. if args.model == BYOLNoEMA:
  65. args.model = BYOL
  66. momentum_update = False
  67. elif args.model == BYOLNoEMA_NoSG:
  68. args.model = BYOLNoSG
  69. momentum_update = False
  70. image_size = 32
  71. config = {
  72. "task_id": task_id,
  73. "data": {
  74. "dataset": args.dataset,
  75. "num_of_clients": args.num_of_clients,
  76. "split_type": args.data_partition,
  77. "class_per_client": class_per_client,
  78. "data_amount": 1,
  79. "iid_fraction": 1,
  80. "min_size": 10,
  81. "alpha": args.dir_alpha,
  82. },
  83. "model": args.model,
  84. "test_mode": "test_in_server",
  85. "server": {
  86. "batch_size": args.batch_size,
  87. "rounds": args.rounds,
  88. "test_every": args.test_every,
  89. "save_model_every": args.save_model_every,
  90. "clients_per_round": args.clients_per_round,
  91. "random_selection": args.random_selection,
  92. "save_predictor": args.save_predictor,
  93. "test_all": True,
  94. },
  95. "client": {
  96. "drop_last": True,
  97. "batch_size": args.batch_size,
  98. "local_epoch": args.local_epoch,
  99. "optimizer": {
  100. "type": args.optimizer_type,
  101. "lr_type": args.lr_type,
  102. "lr": args.lr,
  103. "momentum": 0.9,
  104. "weight_decay": 0.0005,
  105. },
  106. # application specific
  107. "model": args.model,
  108. "rounds": args.rounds,
  109. "gaussian": False,
  110. "image_size": image_size,
  111. "aggregate_encoder": args.aggregate_encoder,
  112. "update_encoder": args.update_encoder,
  113. "update_predictor": args.update_predictor,
  114. "dapu_threshold": args.dapu_threshold,
  115. "weight_scaler": args.weight_scaler,
  116. "auto_scaler": args.auto_scaler,
  117. "auto_scaler_target": args.auto_scaler_target,
  118. "random_selection": args.random_selection,
  119. "encoder_weight": args.encoder_weight,
  120. "predictor_weight": args.predictor_weight,
  121. "momentum_update": momentum_update,
  122. },
  123. 'resource_heterogeneous': {"grouping_strategy": ""}
  124. }
  125. if args.gpu > 1:
  126. rank, local_rank, world_size, host_addr = slurm.setup()
  127. distribute_config = {
  128. "gpu": world_size,
  129. "distributed": {
  130. "rank": rank,
  131. "local_rank": local_rank,
  132. "world_size": world_size,
  133. "init_method": host_addr
  134. },
  135. }
  136. config.update(distribute_config)
  137. else:
  138. config["gpu"] = args.gpu
  139. if args.semi_supervised:
  140. train_data, test_data, _ = get_semi_supervised_dataset(args.dataset,
  141. args.num_of_clients,
  142. args.data_partition,
  143. class_per_client,
  144. args.label_ratio)
  145. easyfl.register_dataset(train_data, test_data)
  146. model = get_model(args.model, args.encoder_network, args.predictor_network)
  147. easyfl.register_model(model)
  148. if args.use_pgfed:
  149. easyfl.register_client(FedSSLWithPgFedClient)
  150. easyfl.register_server(FedSSLWithPgFedServer)
  151. else:
  152. easyfl.register_client(FedSSLClient)
  153. easyfl.register_server(FedSSLServer)
  154. easyfl.init(config, init_all=True)
  155. easyfl.run()
  156. if __name__ == '__main__':
  157. run()