remote_run.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import argparse
  2. import easyfl
  3. from easyfl.pb import common_pb2 as common_pb
  4. from easyfl.pb import server_service_pb2 as server_pb
  5. from easyfl.protocol import codec
  6. from easyfl.communication import grpc_wrapper
  7. from easyfl.registry import get_clients, SOURCES
  8. parser = argparse.ArgumentParser(description='Federated Server')
  9. parser.add_argument('--server-addr',
  10. type=str,
  11. default="172.18.0.1:23501",
  12. help='Server address')
  13. parser.add_argument('--etcd-addrs',
  14. type=str,
  15. default="172.17.0.1:2379",
  16. help='Etcd address, or list of etcd addrs separated by ","')
  17. parser.add_argument('--source',
  18. type=str,
  19. default="manual",
  20. choices=SOURCES,
  21. help='Source to get the clients')
  22. args = parser.parse_args()
  23. def send_run_request():
  24. config = {
  25. "data": {"dataset": "femnist"},
  26. "model": "lenet",
  27. "test_mode": "test_in_client"
  28. }
  29. print("Server address: {}".format(args.server_addr))
  30. print("Etcd address: {}".format(args.etcd_addrs))
  31. easyfl.init(config)
  32. model = easyfl.init_model()
  33. stub = grpc_wrapper.init_stub(grpc_wrapper.TYPE_SERVER, args.server_addr)
  34. request = server_pb.RunRequest(
  35. model=codec.marshal(model),
  36. )
  37. clients = get_clients(args.source, args.etcd_addrs)
  38. for c in clients:
  39. request.clients.append(server_pb.Client(client_id=c.id, index=c.index, address=c.address))
  40. response = stub.Run(request)
  41. if response.status.code == common_pb.SC_OK:
  42. print("Success")
  43. else:
  44. print(response)
  45. if __name__ == '__main__':
  46. send_run_request()