# base.py (17 KB) — EasyFL base client implementation.
# Standard library.
import argparse
import copy
import logging
import time

# Third-party.
import torch

# EasyFL framework modules (gRPC service, protobuf messages, tracking).
from easyfl.client.service import ClientService
from easyfl.communication import grpc_wrapper
from easyfl.distributed.distributed import CPU
from easyfl.pb import common_pb2 as common_pb
from easyfl.pb import server_service_pb2 as server_pb
from easyfl.protocol import codec
from easyfl.tracking import metric
from easyfl.tracking.client import init_tracking
from easyfl.tracking.evaluation import model_size

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
  16. def create_argument_parser():
  17. """Create argument parser with arguments/configurations for starting remote client service.
  18. Returns:
  19. argparse.ArgumentParser: Parser with client service arguments.
  20. """
  21. parser = argparse.ArgumentParser(description='Federated Client')
  22. parser.add_argument('--local-port',
  23. type=int,
  24. default=23000,
  25. help='Listen port of the client')
  26. parser.add_argument('--server-addr',
  27. type=str,
  28. default="localhost:22999",
  29. help='Address of server in [IP]:[PORT] format')
  30. parser.add_argument('--tracker-addr',
  31. type=str,
  32. default="localhost:12666",
  33. help='Address of tracking service in [IP]:[PORT] format')
  34. parser.add_argument('--is-remote',
  35. type=bool,
  36. default=False,
  37. help='Whether start as a remote client.')
  38. return parser
class BaseClient(object):
    """Default implementation of federated learning client.

    Args:
        cid (str): Client id.
        conf (omegaconf.dictconfig.DictConfig): Client configurations.
        train_data (:obj:`FederatedDataset`): Training dataset.
        test_data (:obj:`FederatedDataset`): Test dataset.
        device (str): Hardware device for training, cpu or cuda devices.
        sleep_time (float): Duration of on hold after training to simulate stragglers.
        is_remote (bool): Whether start remote training.
        local_port (int): Port of remote client service.
        server_addr (str): Remote server service grpc address.
        tracker_addr (str): Remote tracking service grpc address.

    Override the class and functions to implement customized client.

    Example:
        >>> from easyfl.client import BaseClient
        >>> class CustomizedClient(BaseClient):
        >>>     def __init__(self, cid, conf, train_data, test_data, device, **kwargs):
        >>>         super(CustomizedClient, self).__init__(cid, conf, train_data, test_data, device, **kwargs)
        >>>         pass  # more initialization of attributes.
        >>>
        >>>     def train(self, conf, device=CPU):
        >>>         # Implement customized client training method, which overwrites the default training method.
        >>>         pass
    """

    def __init__(self,
                 id,
                 cid,
                 conf,
                 train_data,
                 test_data,
                 device,
                 sleep_time=0,
                 is_remote=False,
                 local_port=23000,
                 server_addr="localhost:22999",
                 tracker_addr="localhost:12666"):
        self.id = id
        self.cid = cid
        self.conf = conf
        self.train_data = train_data
        self.train_loader = None  # built lazily in pretrain_setup()
        self.test_data = test_data
        self.test_loader = None  # built lazily in test()
        self.device = device
        # Timing and evaluation statistics collected during the round.
        self.round_time = 0
        self.train_time = 0
        self.test_time = 0
        self.train_accuracy = []
        self.train_loss = []
        self.test_accuracy = 0
        self.test_loss = 0
        self.profiled = False
        self._sleep_time = sleep_time  # straggler-simulation delay (seconds)
        # compressed_model holds the downloaded/uploaded copy; model is the
        # working copy used for training/testing (see decompression/compression).
        self.compressed_model = None
        self.model = None
        self._upload_holder = server_pb.UploadContent()
        self.is_remote = is_remote
        self.local_port = local_port
        self._server_addr = server_addr
        self._tracker_addr = tracker_addr
        self._server_stub = None  # gRPC stub, created on first connect_to_server()
        self._tracker = None
        self._is_train = True  # whether the current round is training (vs testing)
        if conf.track:
            self._tracker = init_tracking(init_store=False)

    def run_train(self, model, conf):
        """Conduct training on clients.

        Pipeline: download -> decompression -> pre_train -> train -> post_train
        -> (optional local test) -> compression -> encryption -> upload.

        Args:
            model (nn.Module): Model to train.
            conf (omegaconf.dictconfig.DictConfig): Client configurations.

        Returns:
            :obj:`UploadRequest`: Training contents. Unify the interface for both local and remote operations.
        """
        self.conf = conf
        if conf.track:
            self._tracker.set_client_context(conf.task_id, conf.round_id, self.cid)
        self._is_train = True
        self.download(model)
        self.track(metric.TRAIN_DOWNLOAD_SIZE, model_size(model))
        self.decompression()
        self.pre_train()
        self.train(conf, self.device)
        self.post_train()
        self.track(metric.TRAIN_ACCURACY, self.train_accuracy)
        self.track(metric.TRAIN_LOSS, self.train_loss)
        self.track(metric.TRAIN_TIME, self.train_time)
        if conf.local_test:
            self.test_local()
        self.compression()
        self.track(metric.TRAIN_UPLOAD_SIZE, model_size(self.compressed_model))
        self.encryption()
        return self.upload()

    def run_test(self, model, conf):
        """Conduct testing on clients.

        Args:
            model (nn.Module): Model to test.
            conf (omegaconf.dictconfig.DictConfig): Client configurations.

        Returns:
            :obj:`UploadRequest`: Testing contents. Unify the interface for both local and remote operations.
        """
        self.conf = conf
        if conf.track:
            # Reset the client context only when the previous round was also
            # a test round (i.e., testing not preceded by training).
            reset = not self._is_train
            self._tracker.set_client_context(conf.task_id, conf.round_id, self.cid, reset_client=reset)
        self._is_train = False
        self.download(model)
        self.track(metric.TEST_DOWNLOAD_SIZE, model_size(model))
        self.decompression()
        self.pre_test()
        self.test(conf, self.device)
        self.post_test()
        self.track(metric.TEST_ACCURACY, float(self.test_accuracy))
        self.track(metric.TEST_LOSS, float(self.test_loss))
        self.track(metric.TEST_TIME, self.test_time)
        return self.upload()

    def download(self, model):
        """Download model from the server.

        Args:
            model (nn.Module): Global model distributed from the server.
        """
        if self.compressed_model:
            # Reuse the existing local copy; only the weights need refreshing.
            self.compressed_model.load_state_dict(model.state_dict())
        else:
            self.compressed_model = copy.deepcopy(model)

    def decompression(self):
        """Decompressed model. It can be further implemented when the model is compressed in the server."""
        # Default: no compression is applied, so the working model is the
        # downloaded copy itself.
        self.model = self.compressed_model

    def pre_train(self):
        """Preprocessing before training."""
        pass

    def train(self, conf, device=CPU):
        """Execute client training.

        Args:
            conf (omegaconf.dictconfig.DictConfig): Client configurations.
            device (str): Hardware device for training, cpu or cuda devices.
        """
        start_time = time.time()
        loss_fn, optimizer = self.pretrain_setup(conf, device)
        self.train_loss = []
        for i in range(conf.local_epoch):
            batch_loss = []
            for batched_x, batched_y in self.train_loader:
                x, y = batched_x.to(device), batched_y.to(device)
                optimizer.zero_grad()
                out = self.model(x)
                loss = loss_fn(out, y)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            # Record the mean batch loss of this local epoch.
            current_epoch_loss = sum(batch_loss) / len(batch_loss)
            self.train_loss.append(float(current_epoch_loss))
            logger.debug("Client {}, local epoch: {}, loss: {}".format(self.cid, i, current_epoch_loss))
        self.train_time = time.time() - start_time
        logger.debug("Client {}, Train Time: {}".format(self.cid, self.train_time))

    def post_train(self):
        """Postprocessing after training."""
        pass

    def pretrain_setup(self, conf, device):
        """Setup loss function and optimizer before training.

        Also applies the straggler-simulation sleep, moves the model to the
        target device, and lazily constructs the training data loader.

        Returns:
            tuple: (loss function, optimizer).
        """
        self.simulate_straggler()
        self.model.train()
        self.model.to(device)
        loss_fn = self.load_loss_fn(conf)
        optimizer = self.load_optimizer(conf)
        if self.train_loader is None:
            self.train_loader = self.load_loader(conf)
        return loss_fn, optimizer

    def load_loss_fn(self, conf):
        # Default loss: cross-entropy for classification; override as needed.
        return torch.nn.CrossEntropyLoss()

    def load_optimizer(self, conf):
        """Load training optimizer. Implemented Adam and SGD."""
        if conf.optimizer.type == "Adam":
            optimizer = torch.optim.Adam(self.model.parameters(), lr=conf.optimizer.lr)
        else:
            # default using optimizer SGD
            optimizer = torch.optim.SGD(self.model.parameters(),
                                        lr=conf.optimizer.lr,
                                        momentum=conf.optimizer.momentum,
                                        weight_decay=conf.optimizer.weight_decay)
        return optimizer

    def load_loader(self, conf):
        """Load the training data loader.

        Args:
            conf (omegaconf.dictconfig.DictConfig): Client configurations.

        Returns:
            torch.utils.data.DataLoader: Data loader.
        """
        return self.train_data.loader(conf.batch_size, self.cid, shuffle=True, seed=conf.seed)

    def test_local(self):
        """Test client local model after training."""
        pass

    def pre_test(self):
        """Preprocessing before testing."""
        pass

    def test(self, conf, device=CPU):
        """Execute client testing.

        Args:
            conf (omegaconf.dictconfig.DictConfig): Client configurations.
            device (str): Hardware device for training, cpu or cuda devices.
        """
        begin_test_time = time.time()
        self.model.eval()
        self.model.to(device)
        loss_fn = self.load_loss_fn(conf)
        if self.test_loader is None:
            self.test_loader = self.test_data.loader(conf.test_batch_size, self.cid, shuffle=False, seed=conf.seed)
        # TODO: make evaluation metrics a separate package and apply it here.
        self.test_loss = 0
        correct = 0
        with torch.no_grad():
            for batched_x, batched_y in self.test_loader:
                x = batched_x.to(device)
                y = batched_y.to(device)
                log_probs = self.model(x)
                loss = loss_fn(log_probs, y)
                # Predicted class = argmax over the last dimension.
                _, y_pred = torch.max(log_probs, -1)
                correct += y_pred.eq(y.data.view_as(y_pred)).long().cpu().sum()
                self.test_loss += loss.item()
            test_size = self.test_data.size(self.cid)
            # NOTE(review): loss is summed per batch then divided by the number
            # of samples — confirm this averaging convention is intended.
            self.test_loss /= test_size
            self.test_accuracy = 100.0 * float(correct) / test_size
        logger.debug('Client {}, testing -- Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
            self.cid, self.test_loss, correct, test_size, self.test_accuracy))
        self.test_time = time.time() - begin_test_time
        # Move the model back to CPU to release device memory after testing.
        self.model = self.model.cpu()

    def post_test(self):
        """Postprocessing after testing."""
        pass

    def encryption(self):
        """Encrypt the client local model."""
        # TODO: encryption of model, remember to track encrypted model instead of compressed one after implementation.
        pass

    def compression(self):
        """Compress the client local model after training and before uploading to the server."""
        # Default: no compression; upload the trained model as-is.
        self.compressed_model = self.model

    def upload(self):
        """Upload the messages from client to the server.

        Returns:
            :obj:`UploadRequest`: The upload request defined in protobuf to unify local and remote operations.
                Only applicable for local training as remote training upload through a gRPC request.
        """
        request = self.construct_upload_request()
        if not self.is_remote:
            self.post_upload()
            return request
        # Remote mode: push via gRPC; nothing is returned to the caller.
        self.upload_remotely(request)
        self.post_upload()

    def post_upload(self):
        """Postprocessing after uploading training/testing results."""
        pass

    def construct_upload_request(self):
        """Construct client upload request for training updates and testing results.

        Returns:
            :obj:`UploadRequest`: The upload request defined in protobuf to unify local and remote operations.
        """
        # Defaults correspond to the testing case; overwritten below for training.
        data = codec.marshal(server_pb.Performance(accuracy=self.test_accuracy, loss=self.test_loss))
        typ = common_pb.DATA_TYPE_PERFORMANCE
        try:
            if self._is_train:
                data = codec.marshal(copy.deepcopy(self.compressed_model))
                typ = common_pb.DATA_TYPE_PARAMS
                data_size = self.train_data.size(self.cid)
            else:
                data_size = 1 if not self.test_data else self.test_data.size(self.cid)
        except KeyError:
            # When the data size cannot be obtained from the dataset, default
            # to 1 so the server aggregates with equal weights.
            data_size = 1
        m = self._tracker.get_client_metric().to_proto() if self._tracker else common_pb.ClientMetric()
        return server_pb.UploadRequest(
            task_id=self.conf.task_id,
            round_id=self.conf.round_id,
            client_id=self.cid,
            content=server_pb.UploadContent(
                data=data,
                type=typ,
                data_size=data_size,
                metric=m,
            ),
        )

    def upload_remotely(self, request):
        """Send upload request to remote server via gRPC.

        Args:
            request (:obj:`UploadRequest`): Upload request.
        """
        start_time = time.time()
        self.connect_to_server()
        resp = self._server_stub.Upload(request)
        upload_time = time.time() - start_time
        m = metric.TRAIN_UPLOAD_TIME if self._is_train else metric.TEST_UPLOAD_TIME
        self.track(m, upload_time)
        logger.info("client upload time: {}s".format(upload_time))
        if resp.status.code == common_pb.SC_OK:
            logger.info("Uploaded remotely to the server successfully\n")
        else:
            logger.error("Failed to upload, code: {}, message: {}\n".format(resp.status.code, resp.status.message))

    # Functions for remote services.

    def start_service(self):
        """Start client service."""
        if self.is_remote:
            grpc_wrapper.start_service(grpc_wrapper.TYPE_CLIENT, ClientService(self), self.local_port)

    def connect_to_server(self):
        """Establish connection between the client and the server."""
        # Lazily create the stub once; reused for subsequent uploads.
        if self.is_remote and self._server_stub is None:
            self._server_stub = grpc_wrapper.init_stub(grpc_wrapper.TYPE_SERVER, self._server_addr)
            logger.info("Successfully connected to gRPC server {}".format(self._server_addr))

    def operate(self, model, conf, index, is_train=True):
        """A wrapper over operations (training/testing) on clients.

        Args:
            model (nn.Module): Model for operations.
            conf (omegaconf.dictconfig.DictConfig): Client configurations.
            index (int): Client index in the client list, for retrieving data. TODO: improvement.
            is_train (bool): The flag to indicate whether the operation is training, otherwise testing.
        """
        try:
            # Load the data index depending on server request
            self.cid = self.train_data.users[index]
        except IndexError:
            logger.error("Data index exceed the available data, abort training")
            return
        if self.conf.track and self._tracker is None:
            self._tracker = init_tracking(init_store=False)
        if is_train:
            logger.info("Train on data index {}, client: {}".format(index, self.cid))
            self.run_train(model, conf)
        else:
            logger.info("Test on data index {}, client: {}".format(index, self.cid))
            self.run_test(model, conf)

    # Functions for tracking.

    def track(self, metric_name, value):
        """Track a metric.

        Args:
            metric_name (str): The name of the metric.
            value (str|int|float|bool|dict|list): The value of the metric.
        """
        if not self.conf.track or self._tracker is None:
            logger.debug("Tracker not available, Tracking not supported")
            return
        self._tracker.track_client(metric_name, value)

    def save_metrics(self):
        """Save client metrics to database."""
        # TODO: not tested
        if self._tracker is None:
            logger.debug("Tracker not available, no saving")
            return
        self._tracker.save_client()

    # Functions for simulation.

    def simulate_straggler(self):
        """Simulate straggler effect of system heterogeneity."""
        if self._sleep_time > 0:
            time.sleep(self._sleep_time)