import argparse
import concurrent.futures
import copy
import logging
import os
import threading
import time

import numpy as np
import torch
import torch.distributed as dist
from omegaconf import OmegaConf

from easyfl.communication import grpc_wrapper
from easyfl.datasets import TEST_IN_SERVER
from easyfl.distributed import grouping, reduce_models, reduce_models_only_params, \
    reduce_value, reduce_values, reduce_weighted_values, gather_value
from easyfl.distributed.distributed import CPU, GREEDY_GROUPING
from easyfl.pb import client_service_pb2 as client_pb
from easyfl.pb import common_pb2 as common_pb
from easyfl.protocol import codec
from easyfl.registry.etcd_client import EtcdClient
from easyfl.server import strategies
from easyfl.server.service import ServerService
from easyfl.tracking import metric
from easyfl.tracking.client import init_tracking
from easyfl.utils.float import rounding

logger = logging.getLogger(__name__)

# train and test params
MODEL = "model"
DATA_SIZE = "data_size"
ACCURACY = "accuracy"
LOSS = "loss"
CLIENT_METRICS = "client_metrics"

FEDERATED_AVERAGE = "FedAvg"
EQUAL_AVERAGE = "equal"

AGGREGATION_CONTENT_ALL = "all"
AGGREGATION_CONTENT_PARAMS = "parameters"


def create_argument_parser():
    """Create argument parser with arguments/configurations for starting server service.

    Returns:
        argparse.ArgumentParser: The parser with server service arguments.
    """
    parser = argparse.ArgumentParser(description='Federated Server')
    parser.add_argument('--local-port',
                        type=int,
                        default=22999,
                        help='Listen port of the server')
    parser.add_argument('--tracker-addr',
                        type=str,
                        default="localhost:12666",
                        help='Address of tracking service in [IP]:[PORT] format')
    parser.add_argument('--is-remote',
                        type=bool,
                        default=False,
                        help='Whether to start as a remote server.')
    return parser
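

# A minimal usage sketch (illustrative, not part of the original module):
# defaults above apply when a flag is omitted. Note that argparse with
# `type=bool` treats any non-empty string as True, so `--is-remote False`
# still parses as True; this mirrors the parser as written.
# >>> parser = create_argument_parser()
# >>> args = parser.parse_args(["--local-port", "23000"])
# >>> args.local_port
# 23000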


class BaseServer(object):
    """Default implementation of federated learning server.

    Args:
        conf (omegaconf.dictconfig.DictConfig): Configurations of EasyFL.
        test_data (:obj:`FederatedDataset`): Test dataset for centralized testing in server, optional.
        val_data (:obj:`FederatedDataset`): Validation dataset for centralized validation in server, optional.
        is_remote (bool): A flag to indicate whether to start remote training.
        local_port (int): The port of remote server service.

    Override the class and its functions to implement a customized server.

    Example:
        >>> from easyfl.server import BaseServer
        >>> class CustomizedServer(BaseServer):
        >>>     def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
        >>>         super(CustomizedServer, self).__init__(conf, test_data, val_data, is_remote, local_port)
        >>>         pass  # more initialization of attributes.
        >>>
        >>>     def aggregation(self):
        >>>         # Implement customized aggregation method, which overwrites the default aggregation method.
        >>>         pass
    """

    def __init__(self,
                 conf,
                 test_data=None,
                 val_data=None,
                 is_remote=False,
                 local_port=22999):
        self.conf = conf
        self.test_data = test_data
        self.val_data = val_data
        self.is_remote = is_remote
        self.local_port = local_port
        self._is_training = False
        self._should_stop = False

        self._current_round = -1
        self._client_uploads = {}
        self._model = None
        self._compressed_model = None
        self._clients = None
        self._etcd = None
        self.selected_clients = []
        self.grouped_clients = []

        self._server_metric = None
        self._round_time = None
        self._begin_train_time = None  # training begin time for a round
        self._start_time = None  # training start time for a task
        self.client_stubs = {}

        if self.conf.is_distributed:
            self.default_time = self.conf.resource_heterogeneous.initial_default_time

        self._cumulative_times = []  # cumulative training time after each test
        self._accuracies = []

        self._condition = threading.Condition()

        self._tracker = None
        self.init_tracker()

    def start(self, model, clients):
        """Start federated learning process, including training and testing.

        Args:
            model (nn.Module): The model to train.
            clients (list[:obj:`BaseClient`]|list[str]): Available clients.
                Clients are actually client grpc addresses when in remote training.
        """
        # Setup
        self._start_time = time.time()
        self._reset()
        self.set_model(model)
        self.set_clients(clients)

        if self._should_track():
            self._tracker.create_task(self.conf.task_id, OmegaConf.to_container(self.conf))

        # Get initial testing accuracies
        if self.conf.server.test_all:
            if self._should_track():
                self._tracker.set_round(self._current_round)
            self.test()
            self.save_tracker()

        while not self.should_stop():
            self._round_time = time.time()

            self._current_round += 1
            self.print_("\n-------- round {} --------".format(self._current_round))

            # Train
            self.pre_train()
            self.train()
            self.post_train()

            # Test
            if self._do_every(self.conf.server.test_every, self._current_round, self.conf.server.rounds):
                self.pre_test()
                self.test()
                self.post_test()

            # Save Model
            self.save_model()
            self.track(metric.ROUND_TIME, time.time() - self._round_time)
            self.save_tracker()

        self.print_("Accuracies: {}".format(rounding(self._accuracies, 4)))
        self.print_("Cumulative training time: {}".format(rounding(self._cumulative_times, 2)))

    def stop(self):
        """Set the flag to indicate training should stop."""
        self._should_stop = True

    def pre_train(self):
        """Preprocessing before training."""
        pass

    def train(self):
        """Training process of federated learning."""
        self.print_("--- start training ---")
        self.selection(self._clients, self.conf.server.clients_per_round)
        self.grouping_for_distributed()
        self.compression()

        begin_train_time = time.time()

        self.distribution_to_train()
        self.aggregation()

        train_time = time.time() - begin_train_time
        self.print_("Server train time: {}".format(train_time))
        self.track(metric.TRAIN_TIME, train_time)

    def post_train(self):
        """Postprocessing after training."""
        pass

    def pre_test(self):
        """Preprocessing before testing."""
        pass

    def test(self):
        """Testing process of federated learning."""
        self.print_("--- start testing ---")

        test_begin_time = time.time()
        test_results = {metric.TEST_ACCURACY: 0, metric.TEST_LOSS: 0, metric.TEST_TIME: 0}
        if self.conf.test_mode == TEST_IN_SERVER:
            if self.is_primary_server():
                test_results = self.test_in_server(self.conf.device)
        else:
            test_results = self.test_in_client()

        test_results[metric.TEST_TIME] = time.time() - test_begin_time
        self.track_test_results(test_results)

    def post_test(self):
        """Postprocessing after testing."""
        pass

    def should_stop(self):
        """Check whether training should stop. Training stops under two conditions:
        1. Reached the max number of training rounds.
        2. TODO: Accuracy higher than a certain amount.

        Returns:
            bool: A flag to indicate whether training should stop.
        """
        if self._should_stop or (self.conf.server.rounds and self._current_round + 1 >= self.conf.server.rounds):
            self._is_training = False
            return True
        return False
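
    # E.g., with server.rounds = 10, training runs through _current_round 0..9;
    # should_stop() returns True once _current_round + 1 >= 10.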

    def test_in_client(self):
        """Conduct testing in clients.
        Currently, it supports testing on the clients selected for training.
        TODO: Add options to select clients for testing.

        Returns:
            dict: Test metrics, {"test_loss": value, "test_accuracy": value, "test_time": value}.
        """
        self.compression()
        self.distribution_to_test()
        return self.aggregation_test()

    def test_in_server(self, device=CPU):
        """Conduct testing in the server.

        Args:
            device (str): The hardware device to conduct testing, either cpu or cuda devices.

        Returns:
            dict: Test metrics, {"test_loss": value, "test_accuracy": value, "test_time": value}.
        """
        self._model.eval()
        self._model.to(device)
        test_loss = 0
        correct = 0
        loss_fn = torch.nn.CrossEntropyLoss().to(device)
        with torch.no_grad():
            for batched_x, batched_y in self.test_data.loader(self.conf.server.batch_size, seed=self.conf.seed):
                x = batched_x.to(device)
                y = batched_y.to(device)
                log_probs = self._model(x)
                loss = loss_fn(log_probs, y)
                _, y_pred = torch.max(log_probs, -1)
                correct += y_pred.eq(y.data.view_as(y_pred)).long().cpu().sum()
                test_loss += loss.item()

        test_data_size = self.test_data.size()
        test_loss /= test_data_size
        accuracy = 100.00 * correct / test_data_size

        test_results = {
            metric.TEST_ACCURACY: float(accuracy),
            metric.TEST_LOSS: float(test_loss)
        }
        return test_results

    # Client selection
    def selection(self, clients, clients_per_round):
        """Select a fraction of total clients for training.
        Two selection strategies are implemented: 1. random selection; 2. select the first K clients.

        Args:
            clients (list[:obj:`BaseClient`]|list[str]): Available clients.
            clients_per_round (int): Number of clients to participate in training each round.

        Returns:
            (list[:obj:`BaseClient`]|list[str]): The selected clients.
        """
        if clients_per_round > len(clients):
            logger.warning("Available clients for selection are smaller than required clients for each round")

        clients_per_round = min(clients_per_round, len(clients))
        if self.conf.server.random_selection:
            np.random.seed(self._current_round)
            self.selected_clients = np.random.choice(clients, clients_per_round, replace=False)
        else:
            self.selected_clients = clients[:clients_per_round]

        return self.selected_clients
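
    # Illustrative sketch (not part of the original module): seeding with the
    # round number makes the draw reproducible, so in distributed training
    # every rank selects the identical subset for a given round:
    # >>> np.random.seed(3)  # _current_round == 3
    # >>> np.random.choice(["c1", "c2", "c3", "c4"], 2, replace=False)  # same result on every rank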

    def grouping_for_distributed(self):
        """Divide the selected clients into groups for distributed training.
        Each group of clients is assigned to conduct training in one GPU. The number of groups = the number of gpus.

        When not in distributed training, all selected clients are in one group.
        In distributed training, selected clients are grouped with different strategies: greedy and random.
        """
        if self.conf.is_distributed:
            groups = grouping(self.selected_clients,
                              self.conf.distributed.world_size,
                              self.default_time,
                              self.conf.resource_heterogeneous.grouping_strategy,
                              self._current_round)
            # Assign a group to each rank to train with the current device.
            self.grouped_clients = groups[self.conf.distributed.rank]
            grouping_info = [(c.cid, c.round_time) for c in self.grouped_clients]
            logger.info("Grouping Result for rank {}: {}".format(self.conf.distributed.rank, grouping_info))
        else:
            self.grouped_clients = self.selected_clients

        rank = 0 if len(self.grouped_clients) == len(self.selected_clients) else self.conf.distributed.rank

    def compression(self):
        """Model compression to reduce communication cost."""
        self._compressed_model = self._model

    def distribution_to_train(self):
        """Distribute model and configurations to selected clients to train."""
        if self.is_remote:
            self.distribution_to_train_remotely()
        else:
            self.distribution_to_train_locally()

            # Adaptively update the training time of clients for greedy grouping.
            if self.conf.is_distributed and self.conf.resource_heterogeneous.grouping_strategy == GREEDY_GROUPING:
                self.profile_training_speed()
                self.update_default_time()

    def distribution_to_train_locally(self):
        """Conduct training sequentially for selected clients in the group."""
        uploaded_models = {}
        uploaded_weights = {}
        uploaded_metrics = []
        for client in self.grouped_clients:
            # Update client config before training
            self.conf.client.task_id = self.conf.task_id
            self.conf.client.round_id = self._current_round

            uploaded_request = client.run_train(self._compressed_model, self.conf.client)
            uploaded_content = uploaded_request.content

            model = self.decompression(codec.unmarshal(uploaded_content.data))
            uploaded_models[client.cid] = model
            uploaded_weights[client.cid] = uploaded_content.data_size
            uploaded_metrics.append(metric.ClientMetric.from_proto(uploaded_content.metric))

        self.set_client_uploads_train(uploaded_models, uploaded_weights, uploaded_metrics)

    def distribution_to_train_remotely(self):
        """Distribute training requests to remote clients through multiple threads.
        The main thread waits for a signal to proceed. The signal can be triggered via notification,
        as in the example below.

        Example to trigger signal:
            >>> with self.condition():
            >>>     self.notify_all()
        """
        start_time = time.time()
        should_track = self._tracker is not None and self.conf.client.track
        with concurrent.futures.ThreadPoolExecutor() as executor:
            for client in self.grouped_clients:
                request = client_pb.OperateRequest(
                    type=client_pb.OP_TYPE_TRAIN,
                    model=codec.marshal(self._compressed_model),
                    data_index=client.index,
                    config=client_pb.OperateConfig(
                        batch_size=self.conf.client.batch_size,
                        local_epoch=self.conf.client.local_epoch,
                        seed=self.conf.seed,
                        local_test=self.conf.client.local_test,
                        optimizer=client_pb.Optimizer(
                            type=self.conf.client.optimizer.type,
                            lr=self.conf.client.optimizer.lr,
                            momentum=self.conf.client.optimizer.momentum,
                        ),
                        task_id=self.conf.task_id,
                        round_id=self._current_round,
                        track=should_track,
                    ),
                )
                executor.submit(self._distribution_remotely, client.client_id, request)

        distribute_time = time.time() - start_time
        self.track(metric.TRAIN_DISTRIBUTE_TIME, distribute_time)
        logger.info("Distribute to clients, time: {}".format(distribute_time))
        with self._condition:
            self._condition.wait()

    def distribution_to_test(self):
        """Distribute to conduct testing on clients."""
        if self.is_remote:
            self.distribution_to_test_remotely()
        else:
            self.distribution_to_test_locally()

    def distribution_to_test_locally(self):
        """Conduct testing sequentially for selected testing clients."""
        uploaded_accuracies = []
        uploaded_losses = []
        uploaded_data_sizes = []
        uploaded_metrics = []

        test_clients = self.get_test_clients()
        for client in test_clients:
            # Update client config before testing
            self.conf.client.task_id = self.conf.task_id
            self.conf.client.round_id = self._current_round

            uploaded_request = client.run_test(self._compressed_model, self.conf.client)
            uploaded_content = uploaded_request.content
            performance = codec.unmarshal(uploaded_content.data)

            uploaded_accuracies.append(performance.accuracy)
            uploaded_losses.append(performance.loss)
            uploaded_data_sizes.append(uploaded_content.data_size)
            uploaded_metrics.append(metric.ClientMetric.from_proto(uploaded_content.metric))

        self.set_client_uploads_test(uploaded_accuracies, uploaded_losses, uploaded_data_sizes, uploaded_metrics)

    def distribution_to_test_remotely(self):
        """Distribute testing requests to remote clients through multiple threads.
        The main thread waits for a signal to proceed. The signal can be triggered via notification,
        as in the example below.

        Example to trigger signal:
            >>> with self.condition():
            >>>     self.notify_all()
        """
        start_time = time.time()
        should_track = self._tracker is not None and self.conf.client.track

        test_clients = self.get_test_clients()
        with concurrent.futures.ThreadPoolExecutor() as executor:
            for client in test_clients:
                request = client_pb.OperateRequest(
                    type=client_pb.OP_TYPE_TEST,
                    model=codec.marshal(self._compressed_model),
                    data_index=client.index,
                    config=client_pb.OperateConfig(
                        batch_size=self.conf.client.batch_size,
                        test_batch_size=self.conf.client.test_batch_size,
                        seed=self.conf.seed,
                        task_id=self.conf.task_id,
                        round_id=self._current_round,
                        track=should_track,
                    )
                )
                executor.submit(self._distribution_remotely, client.client_id, request)

        distribute_time = time.time() - start_time
        self.track(metric.TEST_DISTRIBUTE_TIME, distribute_time)
        logger.info("Distribute to test clients, time: {}".format(distribute_time))
        with self._condition:
            self._condition.wait()

    def get_test_clients(self):
        """Get clients to run testing.

        Returns:
            (list[:obj:`BaseClient`]|list[str]): Clients to test.
        """
        if self.conf.server.test_all:
            if self.conf.is_distributed:
                # Group and assign clients to different hardware devices to test.
                test_clients = grouping(self._clients,
                                        self.conf.distributed.world_size,
                                        default_time=self.default_time,
                                        strategy=self.conf.resource_heterogeneous.grouping_strategy)
                test_clients = test_clients[self.conf.distributed.rank]
            else:
                test_clients = self._clients
        else:
            # For the initial testing, if no clients are selected, test all clients
            test_clients = self.grouped_clients if self.grouped_clients is not None else self._clients
        return test_clients

    def _distribution_remotely(self, cid, request):
        """Distribute request to the assigned client to conduct operations.

        Args:
            cid (str): Client id.
            request (:obj:`OperateRequest`): gRPC request of specific operations.
        """
        resp = self.client_stubs[cid].Operate(request)
        if resp.status.code != common_pb.SC_OK:
            logger.error("Failed to train/test in client {}, error: {}".format(cid, resp.status.message))
        else:
            logger.info("Distribute to train/test remotely successfully, client: {}".format(cid))

    def aggregation_test(self):
        """Aggregate testing results from clients.

        Returns:
            dict: Test metrics, format in {"test_loss": value, "test_accuracy": value}
        """
        accuracies = self._client_uploads[ACCURACY]
        losses = self._client_uploads[LOSS]
        test_sizes = self._client_uploads[DATA_SIZE]
        if self.conf.test_method == "average":
            loss = self._mean_value(losses)
            accuracy = self._mean_value(accuracies)
        elif self.conf.test_method == "weighted":
            loss = self._weighted_value(losses, test_sizes)
            accuracy = self._weighted_value(accuracies, test_sizes)
        else:
            raise ValueError("test_method not supported, please use average or weighted")

        test_results = {
            metric.TEST_ACCURACY: float(accuracy),
            metric.TEST_LOSS: float(loss)
        }
        return test_results

    def _mean_value(self, values):
        if self.conf.is_distributed:
            return reduce_values(values, self.conf.device)
        else:
            return np.mean(values)

    def _weighted_value(self, values, weights):
        if self.conf.is_distributed:
            return reduce_weighted_values(values, weights, self.conf.device)
        else:
            return np.average(values, weights=weights)
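
    # E.g., in the non-distributed branch, np.average computes the
    # data-size-weighted mean directly (illustrative numbers):
    # >>> np.average([0.8, 0.6], weights=[100, 300])
    # 0.65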

    def decompression(self, model):
        """Decompress the model uploaded from clients."""
        return model

    def aggregation(self):
        """Aggregate training updates from clients.
        Server aggregates trained models from clients via federated averaging.
        """
        uploaded_content = self.get_client_uploads()
        models = list(uploaded_content[MODEL].values())
        weights = list(uploaded_content[DATA_SIZE].values())

        model = self.aggregate(models, weights)
        self.set_model(model, load_dict=True)

    def aggregate(self, models, weights):
        """Aggregate models uploaded from clients via federated averaging.

        Args:
            models (list[nn.Module]): List of models.
            weights (list[float]): List of weights, corresponding to each model.
                Weights are dataset sizes of clients by default.

        Returns:
            nn.Module: Aggregated model.
        """
        if self.conf.server.aggregation_strategy == EQUAL_AVERAGE:
            weights = [1 for _ in range(len(models))]

        fn_average = strategies.federated_averaging
        fn_sum = strategies.weighted_sum
        fn_reduce = reduce_models
        if self.conf.server.aggregation_content == AGGREGATION_CONTENT_PARAMS:
            fn_average = strategies.federated_averaging_only_params
            fn_sum = strategies.weighted_sum_only_params
            fn_reduce = reduce_models_only_params

        if self.conf.is_distributed:
            dist.barrier()
            model, sample_sum = fn_sum(models, weights)
            fn_reduce(model, torch.tensor(sample_sum).to(self.conf.device))
        else:
            model = fn_average(models, weights)
        return model
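
    # A worked sketch of the weighted average (illustrative numbers, scalars
    # standing in for model parameters): with uploaded data sizes [10, 30, 60]
    # as weights, the aggregated parameter is
    #   w = (10 * w1 + 30 * w2 + 60 * w3) / (10 + 30 + 60)
    # so a client holding 60% of the data contributes 60% of the update.
    # With aggregation_strategy == "equal", all weights become 1 and the
    # result is a plain mean.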

    def _reset(self):
        self._current_round = -1
        self._should_stop = False
        self._is_training = True

    def is_training(self):
        """Check whether the server is in training or has stopped training.

        Returns:
            bool: A flag to indicate whether server is in training.
        """
        return self._is_training

    def set_model(self, model, load_dict=False):
        """Update the universal model in the server.

        Args:
            model (nn.Module): New model.
            load_dict (bool): A flag to indicate whether to load the state dict or copy the model.
        """
        if load_dict:
            self._model.load_state_dict(model.state_dict())
        else:
            self._model = copy.deepcopy(model)

    def set_clients(self, clients):
        self._clients = clients

    def num_of_clients(self):
        return len(self._clients)

    def save_model(self):
        """Save the model in the server."""
        if self._do_every(self.conf.server.save_model_every, self._current_round, self.conf.server.rounds) and \
                self.is_primary_server():
            save_path = self.conf.server.save_model_path
            if save_path == "":
                save_path = os.path.join(os.getcwd(), "saved_models")
            os.makedirs(save_path, exist_ok=True)
            save_path = os.path.join(save_path,
                                     "{}_global_model_r_{}.pth".format(self.conf.task_id, self._current_round))
            torch.save(self._model.cpu().state_dict(), save_path)
            self.print_("Model saved at {}".format(save_path))
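
    # E.g., with task_id "task1" at round 5 and an empty save_model_path, the
    # checkpoint lands at ./saved_models/task1_global_model_r_5.pth.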

    def set_client_uploads_train(self, models, weights, metrics=None):
        """Set training updates uploaded from clients.

        Args:
            models (dict): A collection of models.
            weights (dict): A collection of weights.
            metrics (dict): Client training metrics.
        """
        self.set_client_uploads(MODEL, models)
        self.set_client_uploads(DATA_SIZE, weights)
        if self._should_gather_metrics():
            metrics = self.gather_client_train_metrics()
        self.set_client_uploads(CLIENT_METRICS, metrics)

    def set_client_uploads_test(self, accuracies, losses, test_sizes, metrics=None):
        """Set testing results uploaded from clients.

        Args:
            accuracies (list[float]): Testing accuracies of clients.
            losses (list[float]): Testing losses of clients.
            test_sizes (list[float]): Test dataset sizes of clients.
            metrics (dict): Client testing metrics.
        """
        self.set_client_uploads(ACCURACY, accuracies)
        self.set_client_uploads(LOSS, losses)
        self.set_client_uploads(DATA_SIZE, test_sizes)
        if self._should_gather_metrics() and CLIENT_METRICS in self._client_uploads:
            train_metrics = self.get_client_uploads()[CLIENT_METRICS]
            metrics = metric.ClientMetric.merge_train_to_test_metrics(train_metrics, metrics)
        self.set_client_uploads(CLIENT_METRICS, metrics)

    def set_client_uploads(self, key, value):
        """A general function to set uploaded content from clients.

        Args:
            key (str): Dictionary key.
            value (*): Uploaded content.
        """
        self._client_uploads[key] = value

    def get_client_uploads(self):
        """Get client uploaded contents.

        Returns:
            dict: A dictionary that contains client uploaded contents.
        """
        return self._client_uploads

    def _do_every(self, every, current_round, rounds):
        return (current_round + 1) % every == 0 or (current_round + 1) == rounds
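
    # E.g., with test_every = 5 and rounds = 12, this fires at rounds 4 and 9
    # ((round + 1) % 5 == 0) and at the final round 11 ((round + 1) == rounds).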

    def print_(self, content):
        """Log the content only when the server is the primary server.

        Args:
            content (str): The content to log.
        """
        if self.is_primary_server():
            logger.info(content)

    def is_primary_server(self):
        """Check whether the current process is the primary server.
        In standalone or remote training, the server is primary.
        In distributed training, the server on rank0 is primary.

        Returns:
            bool: A flag to indicate whether the current process is the primary server.
        """
        return not self.conf.is_distributed or self.conf.distributed.rank == 0

    # Functions for remote training

    def start_service(self):
        """Start federated learning server gRPC service."""
        if self.is_remote:
            grpc_wrapper.start_service(grpc_wrapper.TYPE_SERVER, ServerService(self), self.local_port)
            logger.info("GRPC server started at :{}".format(self.local_port))

    def connect_remote_clients(self, clients):
        # TODO: This client should be consistent with client started separately.
        for client in clients:
            if client.client_id not in self.client_stubs:
                self.client_stubs[client.client_id] = grpc_wrapper.init_stub(grpc_wrapper.TYPE_CLIENT, client.address)
                logger.info("Successfully connected to gRPC client {}".format(client.address))

    def init_etcd(self, addresses):
        """Initialize etcd as the registry for client registration.

        Args:
            addresses (str): The etcd addresses split by ","
        """
        self._etcd = EtcdClient("server", addresses, "backends")

    def start_remote_training(self, model, clients):
        """Start federated learning in the remote training mode.
        The server establishes gRPC connections with clients that are not yet connected before training.

        Args:
            model (nn.Module): The model to train.
            clients (list[str]): Client addresses.
        """
        self.connect_remote_clients(clients)
        self.start(model, clients)

    # Functions for tracking

    def init_tracker(self):
        """Initialize tracking."""
        if self.conf.server.track:
            self._tracker = init_tracking(self.conf.tracking.database, self.conf.tracker_addr)

    def track(self, metric_name, value):
        """Track a metric.

        Args:
            metric_name (str): Name of the metric of a round.
            value (str|int|float|bool|dict|list): Value of the metric.
        """
        if not self._should_track():
            return
        self._tracker.track_round(metric_name, value)

    def track_test_results(self, results):
        """Track test results collected from clients.

        Args:
            results (dict): Test metrics, format in {"test_loss": value, "test_accuracy": value, "test_time": value}
        """
        self._cumulative_times.append(time.time() - self._start_time)
        self._accuracies.append(results[metric.TEST_ACCURACY])

        for metric_name in results:
            self.track(metric_name, results[metric_name])

        self.print_('Test time {:.2f}s, Test loss: {:.2f}, Test accuracy: {:.2f}%'.format(
            results[metric.TEST_TIME], results[metric.TEST_LOSS], results[metric.TEST_ACCURACY]))

    def save_tracker(self):
        """Save metrics in the tracker to database."""
        if self._tracker:
            self.track_communication_cost()
            if self.is_primary_server():
                self._tracker.save_round()
            # In distributed training, each server saves their clients separately.
            self._tracker.save_clients(self._client_uploads[CLIENT_METRICS])

    def track_communication_cost(self):
        """Track communication cost among server and clients.
        Communication cost occurs in `training` and `testing` with downlink and uplink costs.
        """
        train_upload_size = 0
        train_download_size = 0
        test_upload_size = 0
        test_download_size = 0
        for client_metric in self._client_uploads[CLIENT_METRICS]:
            if client_metric.round_id == self._current_round and client_metric.task_id == self.conf.task_id:
                train_upload_size += client_metric.train_upload_size
                train_download_size += client_metric.train_download_size
                test_upload_size += client_metric.test_upload_size
                test_download_size += client_metric.test_download_size
        if self.conf.is_distributed:
            train_upload_size = reduce_value(train_upload_size, self.conf.device).item()
            train_download_size = reduce_value(train_download_size, self.conf.device).item()
            test_upload_size = reduce_value(test_upload_size, self.conf.device).item()
            test_download_size = reduce_value(test_download_size, self.conf.device).item()
        self._tracker.track_round(metric.TRAIN_UPLOAD_SIZE, train_upload_size)
        self._tracker.track_round(metric.TRAIN_DOWNLOAD_SIZE, train_download_size)
        self._tracker.track_round(metric.TEST_UPLOAD_SIZE, test_upload_size)
        self._tracker.track_round(metric.TEST_DOWNLOAD_SIZE, test_download_size)

    def _should_track(self):
        """Check whether the server should track metrics.
        The server tracks metrics only when tracking is enabled and it is the primary server.

        Returns:
            bool: A flag to indicate whether the server should track metrics.
        """
        return self._tracker is not None and self.is_primary_server()

    def _should_gather_metrics(self):
        """Check whether the server should gather metrics from GPUs.
        Gather metrics only when testing all in `distributed` training.

        Testing all resets clients' training metrics, thus,
        the server needs to gather train metrics to construct full client metrics.

        Returns:
            bool: A flag to indicate whether the server should gather metrics.
        """
        return self.conf.is_distributed and self.conf.server.test_all and self._tracker

    def gather_client_train_metrics(self):
        """Gather client train metrics from other ranks in distributed training, when testing all clients (test_all).
        When testing all clients, the train metrics may be overridden by the test metrics
        because clients may be placed in different GPUs in training and testing, leading to loss of train metrics.
        So we gather train metrics and set them in test metrics.
        TODO: gather is not progressing. Need fix.
        """
        world_size = self.conf.distributed.world_size
        device = self.conf.device
        uploads = self.get_client_uploads()
        client_id_list = []
        train_accuracy_list = []
        train_loss_list = []
        train_time_list = []
        train_upload_time_list = []
        train_upload_size_list = []
        train_download_size_list = []

        for m in uploads[CLIENT_METRICS]:
            # client_id_list += gather_value(m.client_id, world_size, device).tolist()
            train_accuracy_list += gather_value(m.train_accuracy, world_size, device)
            train_loss_list += gather_value(m.train_loss, world_size, device)
            train_time_list += gather_value(m.train_time, world_size, device)
            train_upload_time_list += gather_value(m.train_upload_time, world_size, device)
            train_upload_size_list += gather_value(m.train_upload_size, world_size, device)
            train_download_size_list += gather_value(m.train_download_size, world_size, device)

        metrics = []
        # Note: Client id may not match with its training stats because all_gather on strings is not supported.
        client_id_list = [c.cid for c in self.selected_clients]
        for i, client_id in enumerate(client_id_list):
            m = metric.ClientMetric(self.conf.task_id, self._current_round, client_id)
            m.add(metric.TRAIN_ACCURACY, train_accuracy_list[i])
            m.add(metric.TRAIN_LOSS, train_loss_list[i])
            m.add(metric.TRAIN_TIME, train_time_list[i])
            m.add(metric.TRAIN_UPLOAD_TIME, train_upload_time_list[i])
            m.add(metric.TRAIN_UPLOAD_SIZE, train_upload_size_list[i])
            m.add(metric.TRAIN_DOWNLOAD_SIZE, train_download_size_list[i])
            metrics.append(m)
        return metrics

    # Functions for remote training.

    def condition(self):
        return self._condition

    def notify_all(self):
        self._condition.notify_all()

    # Functions for distributed training optimization.

    def profile_training_speed(self):
        """Manage profiling of client training speeds for distributed training optimization."""
        profile_required = []
        for client in self.selected_clients:
            if not client.profiled:
                profile_required.append(client)
        if len(profile_required) > 0:
            original = torch.FloatTensor([c.round_time for c in profile_required]).to(self.conf.device)
            time_update = torch.FloatTensor([c.train_time for c in profile_required]).to(self.conf.device)
            dist.barrier()
            dist.all_reduce(time_update)
            for i in range(len(profile_required)):
                old_round_time = original[i]
                current_round_time = time_update[i]
                if old_round_time == 0 or self._should_update_round_time(old_round_time, current_round_time):
                    profile_required[i].round_time = float(current_round_time)
                    profile_required[i].train_time = 0
                else:
                    profile_required[i].profiled = True

    def update_default_time(self):
        """Update the estimated default training time of clients using actual training time from profiled clients."""
        default_momentum = self.conf.resource_heterogeneous.default_time_momentum
        current_round_average = np.mean([float(c.round_time) for c in self.selected_clients])
        self.default_time = default_momentum * current_round_average + self.default_time * (1 - default_momentum)
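
    # E.g., with default_time_momentum = 0.5, default_time = 10.0, and a
    # current round average of 14.0, the new estimate is
    # 0.5 * 14.0 + 0.5 * 10.0 = 12.0.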

    def _should_update_round_time(self, old_round_time, new_round_time, threshold=0.3):
        """Check whether to assign a new estimated round time to the client or mark it as profiled.

        Args:
            old_round_time (float): Previously estimated round time.
            new_round_time (float): Currently profiled round time.
            threshold (float): Tolerance threshold of difference between old and new times.

        Returns:
            bool: A flag to indicate whether to update the round time.
        """
        if new_round_time < old_round_time:
            return ((old_round_time - new_round_time) / new_round_time) >= threshold
        else:
            return ((new_round_time - old_round_time) / old_round_time) >= threshold
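
    # E.g., with threshold = 0.3: old = 10.0, new = 6.0 gives (10 - 6) / 6 ≈ 0.67 >= 0.3,
    # so the estimate is refreshed; old = 10.0, new = 11.0 gives (11 - 10) / 10 = 0.1 < 0.3,
    # so the client is marked as profiled instead.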