service.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import os
  2. import easyfl
  3. from easyfl.client import base as client_base
  4. from easyfl.server import base as server_base
  5. def start_remote_client(conf=None, train_data=None, test_data=None, model=None, client=None):
  6. """Start a remote client.
  7. Args:
  8. conf (dict): Configurations. optional, Use the configuration loaded from file if not provided. It overwrites the
  9. configurations from file.
  10. train_data (:obj:`FederatedDataset`): Training dataset.
  11. test_data (:obj:`FederatedDataset`): Testing dataset.
  12. model (nn.Module): Model used in client training.
  13. client (:obj:`BaseClient`): Customized federated learning client class.
  14. """
  15. parser = client_base.create_argument_parser()
  16. parser.add_argument('--index', type=int, default=0, help='Client index for quick testing')
  17. parser.add_argument('--config', type=str, default="client_config.yaml", help='Client config file')
  18. args = parser.parse_args()
  19. if os.path.isfile(args.config):
  20. conf = easyfl.load_config(args.config, conf)
  21. if train_data and test_data:
  22. easyfl.register_dataset(train_data, test_data)
  23. elif train_data:
  24. easyfl.register_dataset(train_data, None)
  25. elif test_data:
  26. easyfl.register_dataset(None, test_data)
  27. if model:
  28. easyfl.register_model(model)
  29. if client:
  30. easyfl.register_client(client)
  31. easyfl.init(conf, init_all=False)
  32. easyfl.start_client(args)
  33. def start_remote_server(conf=None, test_data=None, model=None, server=None):
  34. """Start a remote server.
  35. Args:
  36. conf (dict): Configurations. optional, Use the configuration loaded from file if not provided. It overwrites the
  37. configurations from file.
  38. test_data (:obj:`FederatedDataset`): Test dataset for centralized testing on server.
  39. model (nn.Module): Model used in client training.
  40. server (:obj:`BaseServer`): Customized federated learning server class.
  41. """
  42. parser = server_base.create_argument_parser()
  43. parser.add_argument('--config', type=str, default="server_config.yaml", help='Server config file')
  44. args = parser.parse_args()
  45. if os.path.isfile(args.config):
  46. conf = easyfl.load_config(args.config, conf)
  47. if test_data:
  48. easyfl.register_dataset(None, test_data)
  49. if model:
  50. easyfl.register_model(model)
  51. if server:
  52. easyfl.register_server(server)
  53. easyfl.init(conf, init_all=False)
  54. easyfl.start_server(args)