# coordinator.py
  1. import logging
  2. import os
  3. import random
  4. import sys
  5. import time
  6. from os import path
  7. import numpy as np
  8. import torch
  9. from omegaconf import OmegaConf
  10. from easyfl.client.base import BaseClient
  11. from easyfl.datasets import TEST_IN_SERVER
  12. from easyfl.datasets.data import construct_datasets
  13. from easyfl.distributed import dist_init, get_device
  14. from easyfl.models.model import load_model
  15. from easyfl.server.base import BaseServer
  16. from easyfl.simulation.system_hetero import resource_hetero_simulation
# Module-level logger named after this module, per the standard logging convention.
logger = logging.getLogger(__name__)
class Coordinator(object):
    """Coordinator manages the federated learning server and clients.

    A single instance of the coordinator is initialized for each federated
    learning task when the package is imported.
    """

    def __init__(self):
        # Flags recording whether the user registered custom components;
        # when False the corresponding default is constructed on demand.
        self.registered_model = False
        self.registered_dataset = False
        self.registered_server = False
        self.registered_client = False
        # Federated datasets, set by init_dataset() or register_dataset().
        self.train_data = None
        self.test_data = None
        self.val_data = None
        # Task configuration (omegaconf DictConfig), set by init_conf().
        self.conf = None
        # Model instance and the class used to build it.
        self.model = None
        self._model_class = None
        # Server instance/class and the list of client instances/class.
        self.server = None
        self._server_class = None
        self.clients = None
        self._client_class = None
        self.tracker = None

    def init(self, conf, init_all=True):
        """Initialize coordinator.

        Args:
            conf (omegaconf.dictconfig.DictConfig): Internal configurations for federated learning.
            init_all (bool): Whether to initialize dataset, model, server, and client
                in addition to the configuration.
        """
        self.init_conf(conf)
        _set_random_seed(conf.seed)
        if init_all:
            self.init_dataset()
            self.init_model()
            self.init_server()
            self.init_clients()

    def run(self):
        """Run the coordinator and the federated learning process.

        Initializes `torch.distributed` if distributed training is configured.
        """
        start_time = time.time()
        if self.conf.is_distributed:
            dist_init(
                self.conf.distributed.backend,
                self.conf.distributed.init_method,
                self.conf.distributed.world_size,
                self.conf.distributed.rank,
                self.conf.distributed.local_rank,
            )
        self.server.start(self.model, self.clients)
        self.print_("Total training time {:.1f}s".format(time.time() - start_time))

    def init_conf(self, conf):
        """Initialize coordinator configuration.

        Derives `conf.is_distributed` and `conf.device` from the configured GPU count.

        Args:
            conf (omegaconf.dictconfig.DictConfig): Configurations.
        """
        self.conf = conf
        # More than one GPU means distributed training.
        self.conf.is_distributed = (self.conf.gpu > 1)
        if self.conf.gpu == 0:
            self.conf.device = "cpu"
        elif self.conf.gpu == 1:
            self.conf.device = 0
        else:
            # Distributed run: select the device assigned to this rank.
            self.conf.device = get_device(self.conf.gpu, self.conf.distributed.world_size,
                                          self.conf.distributed.local_rank)
        self.print_("Configurations: {}".format(self.conf))

    def init_dataset(self):
        """Initialize datasets. Use provided datasets if not registered."""
        if self.registered_dataset:
            return
        self.train_data, self.test_data = construct_datasets(self.conf.data.root,
                                                             self.conf.data.dataset,
                                                             self.conf.data.num_of_clients,
                                                             self.conf.data.split_type,
                                                             self.conf.data.min_size,
                                                             self.conf.data.class_per_client,
                                                             self.conf.data.data_amount,
                                                             self.conf.data.iid_fraction,
                                                             self.conf.data.user,
                                                             self.conf.data.train_test_split,
                                                             self.conf.data.weights,
                                                             self.conf.data.alpha)
        self.print_(f"Total training data amount: {self.train_data.total_size()}")
        self.print_(f"Total testing data amount: {self.test_data.total_size()}")

    def init_model(self):
        """Initialize model instance."""
        if not self.registered_model:
            self._model_class = load_model(self.conf.model)
        # A None model_class means the model was registered as an instance,
        # so no instantiation is needed.
        if self._model_class:
            self.model = self._model_class()

    def init_server(self):
        """Initialize a server instance."""
        if not self.registered_server:
            self._server_class = BaseServer
        kwargs = {
            "is_remote": self.conf.is_remote,
            "local_port": self.conf.local_port
        }
        # The server only holds evaluation data when testing happens server-side.
        if self.conf.test_mode == TEST_IN_SERVER:
            kwargs["test_data"] = self.test_data
            if self.val_data:
                kwargs["val_data"] = self.val_data
        self.server = self._server_class(self.conf, **kwargs)

    def init_clients(self):
        """Initialize client instances, each representing a federated learning client."""
        if not self.registered_client:
            self._client_class = BaseClient
        # Enforce system heterogeneity of clients: assign each client a
        # simulated sleep time (0 when simulation is disabled).
        sleep_time = [0 for _ in self.train_data.users]
        if self.conf.resource_heterogeneous.simulate:
            sleep_time = resource_hetero_simulation(self.conf.resource_heterogeneous.fraction,
                                                    self.conf.resource_heterogeneous.hetero_type,
                                                    self.conf.resource_heterogeneous.sleep_group_num,
                                                    self.conf.resource_heterogeneous.level,
                                                    self.conf.resource_heterogeneous.total_time,
                                                    len(self.train_data.users))
        client_test_data = self.test_data
        # When testing runs in the server, clients do not keep test data.
        if self.conf.test_mode == TEST_IN_SERVER:
            client_test_data = None
        self.clients = [self._client_class(u,
                                           self.conf.client,
                                           self.train_data,
                                           client_test_data,
                                           self.conf.device,
                                           **{"sleep_time": sleep_time[i]})
                        for i, u in enumerate(self.train_data.users)]
        self.print_("Clients in total: {}".format(len(self.clients)))

    def init_client(self):
        """Initialize a single client instance (used for remote training).

        Returns:
            :obj:`BaseClient`: The initialized client instance.
        """
        if not self.registered_client:
            self._client_class = BaseClient
        # Get a random client if not specified.
        # NOTE(review): a falsy index (0) also falls through to random selection —
        # confirm whether index 0 should select the first user instead.
        if self.conf.index:
            user = self.train_data.users[self.conf.index]
        else:
            user = random.choice(self.train_data.users)
        return self._client_class(user,
                                  self.conf.client,
                                  self.train_data,
                                  self.test_data,
                                  self.conf.device,
                                  is_remote=self.conf.is_remote,
                                  local_port=self.conf.local_port,
                                  server_addr=self.conf.server_addr,
                                  tracker_addr=self.conf.tracker_addr)

    def start_server(self, args):
        """Start a server service for remote training.

        Server controls the model, and the testing dataset if configured to test in server.

        Args:
            args (argparse.Namespace): Configurations passed in as arguments,
                merged into the existing configurations.
        """
        if args:
            self.conf = OmegaConf.merge(self.conf, args.__dict__)
        if self.conf.test_mode == TEST_IN_SERVER:
            self.init_dataset()
        self.init_model()
        self.init_server()
        self.server.start_service()

    def start_client(self, args):
        """Start a client service for remote training.

        Client controls training datasets.

        Args:
            args (argparse.Namespace): Configurations passed in as arguments,
                merged into the existing configurations.
        """
        if args:
            self.conf = OmegaConf.merge(self.conf, args.__dict__)
        self.init_dataset()
        client = self.init_client()
        client.start_service()

    def register_dataset(self, train_data, test_data, val_data=None):
        """Register datasets.

        Datasets should inherit from :obj:`FederatedDataset`, e.g., :obj:`FederatedTensorDataset`.

        Args:
            train_data (:obj:`FederatedDataset`): Training dataset.
            test_data (:obj:`FederatedDataset`): Testing dataset.
            val_data (:obj:`FederatedDataset`): Validation dataset.
        """
        self.registered_dataset = True
        self.train_data = train_data
        self.test_data = test_data
        self.val_data = val_data

    def register_model(self, model):
        """Register customized model for federated learning.

        Args:
            model (nn.Module): PyTorch model; both class and instance are acceptable.
                Use the model class when no specific arguments are needed to initialize it.
        """
        self.registered_model = True
        # An instance is stored directly; a class is instantiated later in init_model().
        if not isinstance(model, type):
            self.model = model
        else:
            self._model_class = model

    def register_server(self, server):
        """Register a customized federated learning server.

        Args:
            server (:obj:`BaseServer`): Customized federated learning server.
        """
        self.registered_server = True
        self._server_class = server

    def register_client(self, client):
        """Register a customized federated learning client.

        Args:
            client (:obj:`BaseClient`): Customized federated learning client.
        """
        self.registered_client = True
        self._client_class = client

    def print_(self, content):
        """Log the content only when the current server is the primary server.

        Args:
            content (str): The content to log.
        """
        if self._is_primary_server():
            logger.info(content)

    def _is_primary_server(self):
        """Check whether the current running server is the primary server.

        In standalone or remote training, the server is primary.
        In distributed training, the server on `rank 0` is primary.
        """
        return not self.conf.is_distributed or self.conf.distributed.rank == 0
  239. def _set_random_seed(seed):
  240. random.seed(seed)
  241. np.random.seed(seed)
  242. torch.manual_seed(seed)
  243. torch.cuda.manual_seed(seed)
# Initialize the single global coordinator object shared by the module-level helpers.
_global_coord = Coordinator()
  246. def init_conf(conf=None):
  247. """Initialize configuration for EasyFL. It overrides and supplements default configuration loaded from config.yaml
  248. with the provided configurations.
  249. Args:
  250. conf (dict): Configurations.
  251. Returns:
  252. omegaconf.dictconfig.DictConfig: Internal configurations managed by OmegaConf.
  253. """
  254. here = path.abspath(path.dirname(__file__))
  255. config_file = path.join(here, 'config.yaml')
  256. return load_config(config_file, conf)
  257. def load_config(file, conf=None):
  258. """Load and merge configuration from file and input
  259. Args:
  260. file (str): filename of the configuration.
  261. conf (dict): Configurations.
  262. Returns:
  263. omegaconf.dictconfig.DictConfig: Internal configurations managed by OmegaConf.
  264. """
  265. config = OmegaConf.load(file)
  266. if conf is not None:
  267. config = OmegaConf.merge(config, conf)
  268. return config
  269. def init_logger(log_level):
  270. """Initialize internal logger of EasyFL.
  271. Args:
  272. log_level (int): Logger level, e.g., logging.INFO, logging.DEBUG
  273. """
  274. log_formatter = logging.Formatter("%(asctime)s [%(threadName)s] [%(levelname)-5.5s] %(message)s")
  275. root_logger = logging.getLogger()
  276. log_level = logging.INFO if not log_level else log_level
  277. root_logger.setLevel(log_level)
  278. file_path = os.path.join(os.getcwd(), "logs")
  279. if not os.path.exists(file_path):
  280. os.makedirs(file_path)
  281. file_path = path.join(file_path, "train" + time.strftime(".%m_%d_%H_%M_%S") + ".log")
  282. file_handler = logging.FileHandler(file_path)
  283. file_handler.setFormatter(log_formatter)
  284. root_logger.addHandler(file_handler)
  285. console_handler = logging.StreamHandler(sys.stdout)
  286. console_handler.setFormatter(log_formatter)
  287. root_logger.addHandler(console_handler)
  288. def init(conf=None, init_all=True):
  289. """Initialize EasyFL.
  290. Args:
  291. conf (dict, optional): Configurations.
  292. init_all (bool, optional): Whether initialize dataset, model, server, and client other than configuration.
  293. """
  294. global _global_coord
  295. config = init_conf(conf)
  296. init_logger(config.tracking.log_level)
  297. _set_random_seed(config.seed)
  298. _global_coord.init(config, init_all)
  299. def run():
  300. """Run federated learning process."""
  301. global _global_coord
  302. _global_coord.run()
  303. def init_dataset():
  304. """Initialize dataset, either using registered dataset or out-of-the-box datasets set in config."""
  305. global _global_coord
  306. _global_coord.init_dataset()
  307. def init_model():
  308. """Initialize model, either using registered model or out-of–the-box model set in config.
  309. Returns:
  310. nn.Module: Model used in federated learning.
  311. """
  312. global _global_coord
  313. _global_coord.init_model()
  314. return _global_coord.model
  315. def start_server(args=None):
  316. """Start federated learning server service for remote training.
  317. Args:
  318. args (argparse.Namespace): Configurations passed in as arguments.
  319. """
  320. global _global_coord
  321. _global_coord.start_server(args)
  322. def start_client(args=None):
  323. """Start federated learning client service for remote training.
  324. Args:
  325. args (argparse.Namespace): Configurations passed in as arguments.
  326. """
  327. global _global_coord
  328. _global_coord.start_client(args)
  329. def get_coordinator():
  330. """Get the global coordinator instance.
  331. Returns:
  332. :obj:`Coordinator`: global coordinator instance.
  333. """
  334. return _global_coord
  335. def register_dataset(train_data, test_data, val_data=None):
  336. """Register datasets for federated learning training.
  337. Args:
  338. train_data (:obj:`FederatedDataset`): Training dataset.
  339. test_data (:obj:`FederatedDataset`): Testing dataset.
  340. val_data (:obj:`FederatedDataset`): Validation dataset.
  341. """
  342. global _global_coord
  343. _global_coord.register_dataset(train_data, test_data, val_data)
  344. def register_model(model):
  345. """Register model for federated learning training.
  346. Args:
  347. model (nn.Module): PyTorch model, both class and instance are acceptable.
  348. """
  349. global _global_coord
  350. _global_coord.register_model(model)
  351. def register_server(server):
  352. """Register federated learning server.
  353. Args:
  354. server (:obj:`BaseServer`): Customized federated learning server.
  355. """
  356. global _global_coord
  357. _global_coord.register_server(server)
  358. def register_client(client):
  359. """Register federated learning client.
  360. Args:
  361. client (:obj:`BaseClient`): Customized federated learning client.
  362. """
  363. global _global_coord
  364. _global_coord.register_client(client)