# base.py
  1. import argparse
  2. import copy
  3. import logging
  4. import time
  5. import torch
  6. from easyfl.client.service import ClientService
  7. from easyfl.communication import grpc_wrapper
  8. from easyfl.distributed.distributed import CPU
  9. from easyfl.pb import common_pb2 as common_pb
  10. from easyfl.pb import server_service_pb2 as server_pb
  11. from easyfl.protocol import codec
  12. from easyfl.tracking import metric
  13. from easyfl.tracking.client import init_tracking
  14. from easyfl.tracking.evaluation import model_size
# Module-level logger, named after the module per the stdlib convention so
# handlers/levels can be configured hierarchically by the application.
logger = logging.getLogger(__name__)
  16. def create_argument_parser():
  17. """Create argument parser with arguments/configurations for starting remote client service.
  18. Returns:
  19. argparse.ArgumentParser: Parser with client service arguments.
  20. """
  21. parser = argparse.ArgumentParser(description='Federated Client')
  22. parser.add_argument('--local-port',
  23. type=int,
  24. default=23000,
  25. help='Listen port of the client')
  26. parser.add_argument('--server-addr',
  27. type=str,
  28. default="localhost:22999",
  29. help='Address of server in [IP]:[PORT] format')
  30. parser.add_argument('--tracker-addr',
  31. type=str,
  32. default="localhost:12666",
  33. help='Address of tracking service in [IP]:[PORT] format')
  34. parser.add_argument('--is-remote',
  35. type=bool,
  36. default=False,
  37. help='Whether start as a remote client.')
  38. return parser
  39. class BaseClient(object):
  40. """Default implementation of federated learning client.
  41. Args:
  42. cid (str): Client id.
  43. conf (omegaconf.dictconfig.DictConfig): Client configurations.
  44. train_data (:obj:`FederatedDataset`): Training dataset.
  45. test_data (:obj:`FederatedDataset`): Test dataset.
  46. device (str): Hardware device for training, cpu or cuda devices.
  47. sleep_time (float): Duration of on hold after training to simulate stragglers.
  48. is_remote (bool): Whether start remote training.
  49. local_port (int): Port of remote client service.
  50. server_addr (str): Remote server service grpc address.
  51. tracker_addr (str): Remote tracking service grpc address.
  52. Override the class and functions to implement customized client.
  53. Example:
  54. >>> from easyfl.client import BaseClient
  55. >>> class CustomizedClient(BaseClient):
  56. >>> def __init__(self, cid, conf, train_data, test_data, device, **kwargs):
  57. >>> super(CustomizedClient, self).__init__(cid, conf, train_data, test_data, device, **kwargs)
  58. >>> pass # more initialization of attributes.
  59. >>>
  60. >>> def train(self, conf, device=CPU):
  61. >>> # Implement customized client training method, which overwrites the default training method.
  62. >>> pass
  63. """
  64. def __init__(self,
  65. cid,
  66. conf,
  67. train_data,
  68. test_data,
  69. device,
  70. sleep_time=0,
  71. is_remote=False,
  72. local_port=23000,
  73. server_addr="localhost:22999",
  74. tracker_addr="localhost:12666"):
  75. self.cid = cid
  76. self.conf = conf
  77. self.train_data = train_data
  78. self.train_loader = None
  79. self.test_data = test_data
  80. self.test_loader = None
  81. self.device = device
  82. self.round_time = 0
  83. self.train_time = 0
  84. self.test_time = 0
  85. self.train_accuracy = []
  86. self.train_loss = []
  87. self.test_accuracy = 0
  88. self.test_loss = 0
  89. self.profiled = False
  90. self._sleep_time = sleep_time
  91. self.compressed_model = None
  92. self.model = None
  93. self._upload_holder = server_pb.UploadContent()
  94. self.is_remote = is_remote
  95. self.local_port = local_port
  96. self._server_addr = server_addr
  97. self._tracker_addr = tracker_addr
  98. self._server_stub = None
  99. self._tracker = None
  100. self._is_train = True
  101. if conf.track:
  102. self._tracker = init_tracking(init_store=False)
  103. def run_train(self, model, conf):
  104. """Conduct training on clients.
  105. Args:
  106. model (nn.Module): Model to train.
  107. conf (omegaconf.dictconfig.DictConfig): Client configurations.
  108. Returns:
  109. :obj:`UploadRequest`: Training contents. Unify the interface for both local and remote operations.
  110. """
  111. self.conf = conf
  112. if conf.track:
  113. self._tracker.set_client_context(conf.task_id, conf.round_id, self.cid)
  114. self._is_train = True
  115. self.download(model)
  116. self.track(metric.TRAIN_DOWNLOAD_SIZE, model_size(model))
  117. self.decompression()
  118. self.pre_train()
  119. self.train(conf, self.device)
  120. self.post_train()
  121. self.track(metric.TRAIN_ACCURACY, self.train_accuracy)
  122. self.track(metric.TRAIN_LOSS, self.train_loss)
  123. self.track(metric.TRAIN_TIME, self.train_time)
  124. if conf.local_test:
  125. self.test_local()
  126. self.compression()
  127. self.track(metric.TRAIN_UPLOAD_SIZE, model_size(self.compressed_model))
  128. self.encryption()
  129. return self.upload()
  130. def run_test(self, model, conf):
  131. """Conduct testing on clients.
  132. Args:
  133. model (nn.Module): Model to test.
  134. conf (omegaconf.dictconfig.DictConfig): Client configurations.
  135. Returns:
  136. :obj:`UploadRequest`: Testing contents. Unify the interface for both local and remote operations.
  137. """
  138. self.conf = conf
  139. if conf.track:
  140. reset = not self._is_train
  141. self._tracker.set_client_context(conf.task_id, conf.round_id, self.cid, reset_client=reset)
  142. self._is_train = False
  143. self.download(model)
  144. self.track(metric.TEST_DOWNLOAD_SIZE, model_size(model))
  145. self.decompression()
  146. self.pre_test()
  147. self.test(conf, self.device)
  148. self.post_test()
  149. self.track(metric.TEST_ACCURACY, float(self.test_accuracy))
  150. self.track(metric.TEST_LOSS, float(self.test_loss))
  151. self.track(metric.TEST_TIME, self.test_time)
  152. return self.upload()
  153. def download(self, model):
  154. """Download model from the server.
  155. Args:
  156. model (nn.Module): Global model distributed from the server.
  157. """
  158. if self.compressed_model:
  159. self.compressed_model.load_state_dict(model.state_dict())
  160. else:
  161. self.compressed_model = copy.deepcopy(model)
  162. def decompression(self):
  163. """Decompressed model. It can be further implemented when the model is compressed in the server."""
  164. self.model = self.compressed_model
  165. def pre_train(self):
  166. """Preprocessing before training."""
  167. pass
  168. def train(self, conf, device=CPU):
  169. """Execute client training.
  170. Args:
  171. conf (omegaconf.dictconfig.DictConfig): Client configurations.
  172. device (str): Hardware device for training, cpu or cuda devices.
  173. """
  174. start_time = time.time()
  175. loss_fn, optimizer = self.pretrain_setup(conf, device)
  176. self.train_loss = []
  177. for i in range(conf.local_epoch):
  178. batch_loss = []
  179. for batched_x, batched_y in self.train_loader:
  180. x, y = batched_x.to(device), batched_y.to(device)
  181. optimizer.zero_grad()
  182. out = self.model(x)
  183. loss = loss_fn(out, y)
  184. loss.backward()
  185. optimizer.step()
  186. batch_loss.append(loss.item())
  187. current_epoch_loss = sum(batch_loss) / len(batch_loss)
  188. self.train_loss.append(float(current_epoch_loss))
  189. logger.debug("Client {}, local epoch: {}, loss: {}".format(self.cid, i, current_epoch_loss))
  190. self.train_time = time.time() - start_time
  191. logger.debug("Client {}, Train Time: {}".format(self.cid, self.train_time))
  192. def post_train(self):
  193. """Postprocessing after training."""
  194. pass
  195. def pretrain_setup(self, conf, device):
  196. """Setup loss function and optimizer before training."""
  197. self.simulate_straggler()
  198. self.model.train()
  199. self.model.to(device)
  200. loss_fn = self.load_loss_fn(conf)
  201. optimizer = self.load_optimizer(conf)
  202. if self.train_loader is None:
  203. self.train_loader = self.load_loader(conf)
  204. return loss_fn, optimizer
  205. def load_loss_fn(self, conf):
  206. return torch.nn.CrossEntropyLoss()
  207. def load_optimizer(self, conf):
  208. """Load training optimizer. Implemented Adam and SGD."""
  209. if conf.optimizer.type == "Adam":
  210. optimizer = torch.optim.Adam(self.model.parameters(), lr=conf.optimizer.lr)
  211. else:
  212. # default using optimizer SGD
  213. optimizer = torch.optim.SGD(self.model.parameters(),
  214. lr=conf.optimizer.lr,
  215. momentum=conf.optimizer.momentum,
  216. weight_decay=conf.optimizer.weight_decay)
  217. return optimizer
  218. def load_loader(self, conf):
  219. """Load the training data loader.
  220. Args:
  221. conf (omegaconf.dictconfig.DictConfig): Client configurations.
  222. Returns:
  223. torch.utils.data.DataLoader: Data loader.
  224. """
  225. return self.train_data.loader(conf.batch_size, self.cid, shuffle=True, seed=conf.seed)
  226. def test_local(self):
  227. """Test client local model after training."""
  228. pass
  229. def pre_test(self):
  230. """Preprocessing before testing."""
  231. pass
  232. def test(self, conf, device=CPU):
  233. """Execute client testing.
  234. Args:
  235. conf (omegaconf.dictconfig.DictConfig): Client configurations.
  236. device (str): Hardware device for training, cpu or cuda devices.
  237. """
  238. begin_test_time = time.time()
  239. self.model.eval()
  240. self.model.to(device)
  241. loss_fn = self.load_loss_fn(conf)
  242. if self.test_loader is None:
  243. self.test_loader = self.test_data.loader(conf.test_batch_size, self.cid, shuffle=False, seed=conf.seed)
  244. # TODO: make evaluation metrics a separate package and apply it here.
  245. self.test_loss = 0
  246. correct = 0
  247. with torch.no_grad():
  248. for batched_x, batched_y in self.test_loader:
  249. x = batched_x.to(device)
  250. y = batched_y.to(device)
  251. log_probs = self.model(x)
  252. loss = loss_fn(log_probs, y)
  253. _, y_pred = torch.max(log_probs, -1)
  254. correct += y_pred.eq(y.data.view_as(y_pred)).long().cpu().sum()
  255. self.test_loss += loss.item()
  256. test_size = self.test_data.size(self.cid)
  257. self.test_loss /= test_size
  258. self.test_accuracy = 100.0 * float(correct) / test_size
  259. logger.debug('Client {}, testing -- Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
  260. self.cid, self.test_loss, correct, test_size, self.test_accuracy))
  261. self.test_time = time.time() - begin_test_time
  262. self.model = self.model.cpu()
  263. def post_test(self):
  264. """Postprocessing after testing."""
  265. pass
  266. def encryption(self):
  267. """Encrypt the client local model."""
  268. # TODO: encryption of model, remember to track encrypted model instead of compressed one after implementation.
  269. pass
  270. def compression(self):
  271. """Compress the client local model after training and before uploading to the server."""
  272. self.compressed_model = self.model
  273. def upload(self):
  274. """Upload the messages from client to the server.
  275. Returns:
  276. :obj:`UploadRequest`: The upload request defined in protobuf to unify local and remote operations.
  277. Only applicable for local training as remote training upload through a gRPC request.
  278. """
  279. request = self.construct_upload_request()
  280. if not self.is_remote:
  281. self.post_upload()
  282. return request
  283. self.upload_remotely(request)
  284. self.post_upload()
  285. def post_upload(self):
  286. """Postprocessing after uploading training/testing results."""
  287. pass
  288. def construct_upload_request(self):
  289. """Construct client upload request for training updates and testing results.
  290. Returns:
  291. :obj:`UploadRequest`: The upload request defined in protobuf to unify local and remote operations.
  292. """
  293. data = codec.marshal(server_pb.Performance(accuracy=self.test_accuracy, loss=self.test_loss))
  294. typ = common_pb.DATA_TYPE_PERFORMANCE
  295. try:
  296. if self._is_train:
  297. data = codec.marshal(copy.deepcopy(self.compressed_model))
  298. typ = common_pb.DATA_TYPE_PARAMS
  299. data_size = self.train_data.size(self.cid)
  300. else:
  301. data_size = 1 if not self.test_data else self.test_data.size(self.cid)
  302. except KeyError:
  303. # When the datasize cannot be get from dataset, default to use equal aggregate
  304. data_size = 1
  305. m = self._tracker.get_client_metric().to_proto() if self._tracker else common_pb.ClientMetric()
  306. return server_pb.UploadRequest(
  307. task_id=self.conf.task_id,
  308. round_id=self.conf.round_id,
  309. client_id=self.cid,
  310. content=server_pb.UploadContent(
  311. data=data,
  312. type=typ,
  313. data_size=data_size,
  314. metric=m,
  315. ),
  316. )
  317. def upload_remotely(self, request):
  318. """Send upload request to remote server via gRPC.
  319. Args:
  320. request (:obj:`UploadRequest`): Upload request.
  321. """
  322. start_time = time.time()
  323. self.connect_to_server()
  324. resp = self._server_stub.Upload(request)
  325. upload_time = time.time() - start_time
  326. m = metric.TRAIN_UPLOAD_TIME if self._is_train else metric.TEST_UPLOAD_TIME
  327. self.track(m, upload_time)
  328. logger.info("client upload time: {}s".format(upload_time))
  329. if resp.status.code == common_pb.SC_OK:
  330. logger.info("Uploaded remotely to the server successfully\n")
  331. else:
  332. logger.error("Failed to upload, code: {}, message: {}\n".format(resp.status.code, resp.status.message))
  333. # Functions for remote services.
  334. def start_service(self):
  335. """Start client service."""
  336. if self.is_remote:
  337. grpc_wrapper.start_service(grpc_wrapper.TYPE_CLIENT, ClientService(self), self.local_port)
  338. def connect_to_server(self):
  339. """Establish connection between the client and the server."""
  340. if self.is_remote and self._server_stub is None:
  341. self._server_stub = grpc_wrapper.init_stub(grpc_wrapper.TYPE_SERVER, self._server_addr)
  342. logger.info("Successfully connected to gRPC server {}".format(self._server_addr))
  343. def operate(self, model, conf, index, is_train=True):
  344. """A wrapper over operations (training/testing) on clients.
  345. Args:
  346. model (nn.Module): Model for operations.
  347. conf (omegaconf.dictconfig.DictConfig): Client configurations.
  348. index (int): Client index in the client list, for retrieving data. TODO: improvement.
  349. is_train (bool): The flag to indicate whether the operation is training, otherwise testing.
  350. """
  351. try:
  352. # Load the data index depending on server request
  353. self.cid = self.train_data.users[index]
  354. except IndexError:
  355. logger.error("Data index exceed the available data, abort training")
  356. return
  357. if self.conf.track and self._tracker is None:
  358. self._tracker = init_tracking(init_store=False)
  359. if is_train:
  360. logger.info("Train on data index {}, client: {}".format(index, self.cid))
  361. self.run_train(model, conf)
  362. else:
  363. logger.info("Test on data index {}, client: {}".format(index, self.cid))
  364. self.run_test(model, conf)
  365. # Functions for tracking.
  366. def track(self, metric_name, value):
  367. """Track a metric.
  368. Args:
  369. metric_name (str): The name of the metric.
  370. value (str|int|float|bool|dict|list): The value of the metric.
  371. """
  372. if not self.conf.track or self._tracker is None:
  373. logger.debug("Tracker not available, Tracking not supported")
  374. return
  375. self._tracker.track_client(metric_name, value)
  376. def save_metrics(self):
  377. """Save client metrics to database."""
  378. # TODO: not tested
  379. if self._tracker is None:
  380. logger.debug("Tracker not available, no saving")
  381. return
  382. self._tracker.save_client()
  383. # Functions for simulation.
  384. def simulate_straggler(self):
  385. """Simulate straggler effect of system heterogeneity."""
  386. if self._sleep_time > 0:
  387. time.sleep(self._sleep_time)