# distributed.py

import logging

import numpy as np
import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)

CPU = "cpu"

RANDOMIZE_GROUPING = "random"
GREEDY_GROUPING = "greedy"
SLOWEST_GROUPING = "slowest"


def reduce_models(model, sample_sum):
    """Aggregate models across devices and update the model with the aggregated parameters.

    Args:
        model (nn.Module): The model on a device to aggregate.
        sample_sum (torch.Tensor): Sum of the total dataset sizes of clients on a device.
    """
    dist.all_reduce(sample_sum, op=dist.ReduceOp.SUM)
    state = model.state_dict()
    for k in state.keys():
        dist.all_reduce(state[k], op=dist.ReduceOp.SUM)
        state[k] = torch.div(state[k], sample_sum)
    model.load_state_dict(state)
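

# A minimal usage sketch, not part of the original module: for the division by
# `sample_sum` in `reduce_models` to yield a FedAvg-style weighted average, each
# process is assumed to hold parameters already scaled by its local sample
# count. `local_model` and `num_local_samples` are hypothetical names.
def _example_reduce_models(local_model, num_local_samples, device):
    state = local_model.state_dict()
    for k in state.keys():
        # Pre-scale local parameters by the local sample count.
        state[k] = state[k] * num_local_samples
    local_model.load_state_dict(state)
    # Summing the scaled states and dividing by the global sample sum then
    # produces the sample-weighted average across processes.
    sample_sum = torch.tensor(num_local_samples).to(device)
    reduce_models(local_model, sample_sum)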


def reduce_models_only_params(model, sample_sum):
    """Aggregate models across devices and update the model with the aggregated parameters,
    excluding persistent buffers such as BatchNorm running statistics.

    Args:
        model (nn.Module): The model on a device to aggregate.
        sample_sum (torch.Tensor): Sum of the total dataset sizes of clients on a device.
    """
    dist.all_reduce(sample_sum, op=dist.ReduceOp.SUM)
    for param in model.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
        param.data = torch.div(param.data, sample_sum)


def reduce_value(value, device):
    """Calculate the sum of a value across devices.

    Args:
        value (float|int): Value to sum.
        device (str): The device the value is on, either CPU or a CUDA device.

    Returns:
        torch.Tensor: Sum of the values across devices.
    """
    v = torch.tensor(value).to(device)
    dist.all_reduce(v, op=dist.ReduceOp.SUM)
    return v


def reduce_values(values, device):
    """Calculate the average of values across devices.

    Args:
        values (list[float|int]): Values to average.
        device (str): The device the values are on, either CPU or a CUDA device.

    Returns:
        torch.Tensor: The average of the values across devices.
    """
    length = torch.tensor(len(values)).to(device)
    total = torch.tensor(sum(values)).to(device)
    dist.all_reduce(length, op=dist.ReduceOp.SUM)
    dist.all_reduce(total, op=dist.ReduceOp.SUM)
    return torch.div(total, length)
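

# A minimal sketch with hypothetical numbers: with one process holding
# accuracies [0.8, 0.9] and another holding [0.7], both the totals and the
# counts are all-reduced, so the result is (0.8 + 0.9 + 0.7) / 3 on every rank.
def _example_reduce_values(device):
    local_accuracies = [0.8, 0.9]  # hypothetical per-client metrics on this process
    return reduce_values(local_accuracies, device)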


def reduce_weighted_values(values, weights, device):
    """Calculate the weighted average of values across devices.

    Args:
        values (list[float|int]): Values to average.
        weights (list[float|int]): The weights used to calculate the weighted average.
        device (str): The device the values are on, either CPU or a CUDA device.

    Returns:
        torch.Tensor: The weighted average of the values across devices.
    """
    values = torch.tensor(values).to(device)
    weights = torch.tensor(weights).to(device)
    total_weights = torch.sum(weights).to(device)
    weighted_sum = torch.sum(values * weights).to(device)
    dist.all_reduce(total_weights, op=dist.ReduceOp.SUM)
    dist.all_reduce(weighted_sum, op=dist.ReduceOp.SUM)
    return torch.div(weighted_sum, total_weights)
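

# A worked sketch with hypothetical numbers: losses [0.5, 1.0] with weights
# [10, 30] contribute a weighted sum of 0.5*10 + 1.0*30 = 35 and a weight total
# of 40 on this process; globally the result is sum(weighted sums) / sum(weights).
def _example_reduce_weighted_values(device):
    losses = [0.5, 1.0]       # hypothetical per-client losses
    dataset_sizes = [10, 30]  # hypothetical per-client dataset sizes
    return reduce_weighted_values(losses, dataset_sizes, device)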


def gather_value(value, world_size, device):
    """Gather the values from all devices into a list.

    Args:
        value (float|int): The value to gather.
        world_size (int): The number of processes.
        device (str): The device the value is on, either CPU or a CUDA device.

    Returns:
        list[torch.Tensor]: A list of gathered values, one tensor per process.
    """
    v = torch.tensor(value).to(device)
    target = [v.clone() for _ in range(world_size)]
    dist.all_gather(target, v)
    return target
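

# A minimal sketch: collect each process's training time into a list, e.g., for
# logging on rank 0. `local_train_time` is a hypothetical name.
def _example_gather_value(local_train_time, device):
    world_size = dist.get_world_size()
    times = gather_value(local_train_time, world_size, device)
    # `times[r]` holds the value contributed by rank r.
    return [t.item() for t in times]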


def grouping(clients, world_size, default_time=10, strategy=RANDOMIZE_GROUPING, seed=1):
    """Divide clients into groups using different strategies.

    Args:
        clients (list[:obj:`BaseClient`]): A list of clients.
        world_size (int): The number of processes; it represents the number of groups here.
        default_time (float, optional): The default training time for clients that have not been profiled.
        strategy (str, optional): Grouping strategy, options: random, greedy, slowest.
            When no strategy is applied, each client forms its own group.
        seed (int, optional): Random seed.

    Returns:
        list[list[:obj:`BaseClient`]]: Groups of clients, each group a sub-list.
    """
    np.random.seed(seed)
    if strategy == RANDOMIZE_GROUPING:
        return randomize_grouping(clients, world_size)
    elif strategy == GREEDY_GROUPING:
        return greedy_grouping(clients, world_size, default_time)
    elif strategy == SLOWEST_GROUPING:
        return slowest_grouping(clients, world_size)
    else:
        # Default: no strategy applied, each client is its own group.
        return [[client] for client in clients]
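

# A minimal sketch, assuming `clients` are objects exposing the `round_time`
# and `cid` attributes the grouping functions below rely on: split the clients
# across 4 processes and pick this process's group by rank.
def _example_grouping(clients, rank):
    groups = grouping(clients, world_size=4, strategy=GREEDY_GROUPING)
    return groups[rank]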


def randomize_grouping(clients, world_size):
    """Randomly divide clients into groups.

    Args:
        clients (list[:obj:`BaseClient`]): A list of clients.
        world_size (int): The number of processes; it represents the number of groups here.

    Returns:
        list[list[:obj:`BaseClient`]]: Groups of clients, each group a sub-list.
    """
    num_of_clients = len(clients)
    np.random.shuffle(clients)
    data_per_client = num_of_clients // world_size
    # When the client count is not divisible by world_size, the remainder is
    # spread over `large_group_num` groups that take one extra client each.
    large_group_num = num_of_clients - world_size * data_per_client
    small_group_num = world_size - large_group_num
    grouped_clients = []
    for i in range(small_group_num):
        base_index = data_per_client * i
        grouped_clients.append(clients[base_index: base_index + data_per_client])
    small_size = data_per_client * small_group_num
    data_per_client += 1
    for i in range(large_group_num):
        base_index = small_size + data_per_client * i
        grouped_clients.append(clients[base_index: base_index + data_per_client])
    return grouped_clients


def greedy_grouping(clients, world_size, default_time):
    """Greedily allocate the clients with the longest training time to the most available device.

    Args:
        clients (list[:obj:`BaseClient`]): A list of clients.
        world_size (int): The number of processes; it represents the number of groups here.
        default_time (float): The default training time for clients that have not been profiled.

    Returns:
        list[list[:obj:`BaseClient`]]: Groups of clients, each group a sub-list.
    """
    round_time_estimation = [[i, c.round_time] if c.round_time != 0
                             else [i, default_time] for i, c in enumerate(clients)]
    # Sort by estimated round time, longest first.
    round_time_estimation = sorted(round_time_estimation, reverse=True, key=lambda tup: (tup[1], tup[0]))
    # Seed each group with one of the `world_size` slowest clients.
    top_world_size = round_time_estimation[:world_size]
    groups = [[clients[index]] for (index, time) in top_world_size]
    time_sum = [time for (index, time) in top_world_size]
    # Assign each remaining client to the group with the smallest accumulated time.
    for i in round_time_estimation[world_size:]:
        min_index = np.argmin(time_sum)
        groups[min_index].append(clients[i[0]])
        time_sum[min_index] += i[1]
    return groups
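

# A worked sketch with hypothetical stub clients: with round times
# [9, 7, 5, 3, 1] and world_size=2, the two slowest clients seed the groups;
# each remaining client then joins the group with the smallest accumulated
# time, yielding groups with times [9, 3, 1] (total 13) and [7, 5] (total 12).
def _example_greedy_grouping():
    class _StubClient:  # hypothetical stand-in for BaseClient
        def __init__(self, cid, round_time):
            self.cid = cid
            self.round_time = round_time

    clients = [_StubClient(i, t) for i, t in enumerate([9, 7, 5, 3, 1])]
    groups = greedy_grouping(clients, world_size=2, default_time=10)
    return [[c.round_time for c in g] for g in groups]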


def slowest_grouping(clients, world_size):
    """Allocate the clients with the longest training time to the busiest device.

    Only for experiments; not practical in real use.

    Args:
        clients (list[:obj:`BaseClient`]): A list of clients.
        world_size (int): The number of processes; it represents the number of groups here.

    Returns:
        list[list[:obj:`BaseClient`]]: Groups of clients, each group a sub-list.
    """
    num_of_clients = len(clients)
    # Sort by round time (ascending), breaking ties by client id, so the
    # slowest clients end up together in the last groups.
    clients = sorted(clients, key=lambda tup: (tup.round_time, tup.cid))
    data_per_client = num_of_clients // world_size
    large_group_num = num_of_clients - world_size * data_per_client
    small_group_num = world_size - large_group_num
    grouped_clients = []
    for i in range(small_group_num):
        base_index = data_per_client * i
        grouped_clients.append(clients[base_index: base_index + data_per_client])
    small_size = data_per_client * small_group_num
    data_per_client += 1
    for i in range(large_group_num):
        base_index = small_size + data_per_client * i
        grouped_clients.append(clients[base_index: base_index + data_per_client])
    return grouped_clients


def dist_init(backend, init_method, world_size, rank, local_rank):
    """Initialize the PyTorch distributed process group.

    Args:
        backend (str or Backend): Distributed backend to use, e.g., `nccl`, `gloo`.
        init_method (str): URL specifying how to initialize the process group.
        world_size (int): Number of processes participating in the job.
        rank (int): Rank of the current process.
        local_rank (int): Local rank of the current process (kept for interface
            consistency; not used in initialization).

    Returns:
        int: Rank of the current process.
        int: Total number of processes.
    """
    dist.init_process_group(backend, init_method=init_method, rank=rank, world_size=world_size)
    assert dist.is_initialized()
    return rank, world_size
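

# A minimal launch sketch, assuming environment-variable initialization as used
# with `torchrun` (RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR, and MASTER_PORT
# are expected to be set by the launcher).
def _example_dist_init():
    import os
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    return dist_init("gloo", "env://", world_size, rank, local_rank)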


def get_device(gpu, world_size, local_rank):
    """Obtain the device by checking the number of GPUs and the distributed settings.

    Args:
        gpu (int): The number of requested GPUs.
        world_size (int): The number of processes.
        local_rank (int): The local rank of the current process.

    Returns:
        int: CUDA device index to be used in PyTorch, e.g., `tensor.to(device)`.
    """
    if gpu > world_size:
        logger.error("Available gpu: {}, requested gpu: {}".format(world_size, gpu))
        raise ValueError("The available number of GPUs is less than requested")
    # TODO: think of a better way to handle this, maybe just use one config param instead of two.
    assert gpu == world_size
    n = torch.cuda.device_count()
    device_ids = list(range(n))
    return device_ids[local_rank]
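

# A minimal sketch tying the helpers together, assuming one process per GPU and
# an already-initialized process group; requires CUDA to be available.
def _example_get_device(local_rank):
    world_size = dist.get_world_size()
    device = get_device(gpu=world_size, world_size=world_size, local_rank=local_rank)
    return torch.zeros(1).to(device)  # `device` is a CUDA device index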