12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918 |
- import argparse
- import concurrent.futures
- import copy
- import logging
- import os
- import threading
- import time
- import numpy as np
- import torch
- import torch.distributed as dist
- from omegaconf import OmegaConf
- from easyfl.communication import grpc_wrapper
- from easyfl.datasets import TEST_IN_SERVER
- from easyfl.distributed import grouping, reduce_models, reduce_models_only_params, \
- reduce_value, reduce_values, reduce_weighted_values, gather_value
- from easyfl.distributed.distributed import CPU, GREEDY_GROUPING
- from easyfl.pb import client_service_pb2 as client_pb
- from easyfl.pb import common_pb2 as common_pb
- from easyfl.protocol import codec
- from easyfl.registry.etcd_client import EtcdClient
- from easyfl.server import strategies
- from easyfl.server.service import ServerService
- from easyfl.tracking import metric
- from easyfl.tracking.client import init_tracking
- from easyfl.utils.float import rounding
# Module-level logger shared by the functions and the server class in this file.
logger = logging.getLogger(__name__)

# Keys of the dictionary that stores the content uploaded by clients
# after training and testing (see ``BaseServer._client_uploads``).
MODEL = "model"
DATA_SIZE = "data_size"
ACCURACY = "accuracy"
LOSS = "loss"
CLIENT_METRICS = "client_metrics"

# Values compared against ``conf.server.aggregation_strategy``.
FEDERATED_AVERAGE = "FedAvg"
EQUAL_AVERAGE = "equal"

# Values compared against ``conf.server.aggregation_content``.
AGGREGATION_CONTENT_ALL = "all"
AGGREGATION_CONTENT_PARAMS = "parameters"
def _str_to_bool(value):
    """Convert a command-line string such as "true" or "0" into a boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including "False") as True,
    so an explicit converter is required for boolean options that take a value.

    Args:
        value (str|bool): Raw value from the command line.

    Returns:
        bool: The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognized boolean literal.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, matching the behavior of the previous ``type=bool``.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got {!r}".format(value))


def create_argument_parser():
    """Create argument parser with arguments/configurations for starting server service.

    Returns:
        argparse.ArgumentParser: The parser with server service arguments.
    """
    parser = argparse.ArgumentParser(description='Federated Server')
    parser.add_argument('--local-port',
                        type=int,
                        default=22999,
                        help='Listen port of the server')
    parser.add_argument('--tracker-addr',
                        type=str,
                        default="localhost:12666",
                        help='Address of tracking service in [IP]:[PORT] format')
    # ``nargs='?'`` keeps ``--is-remote True`` working and additionally accepts a bare ``--is-remote``.
    parser.add_argument('--is-remote',
                        type=_str_to_bool,
                        nargs='?',
                        const=True,
                        default=False,
                        help='Whether start as a remote server.')
    return parser
class BaseServer(object):
    """Default implementation of federated learning server.

    Args:
        conf (omegaconf.dictconfig.DictConfig): Configurations of EasyFL.
        test_data (:obj:`FederatedDataset`): Test dataset for centralized testing in server, optional.
        val_data (:obj:`FederatedDataset`): Validation dataset for centralized validation in server, optional.
        is_remote (bool): A flag to indicate whether start remote training.
        local_port (int): The port of remote server service.

    Override the class and functions to implement customized server.

    Example:
        >>> from easyfl.server import BaseServer
        >>> class CustomizedServer(BaseServer):
        >>>     def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
        >>>         super(CustomizedServer, self).__init__(conf, test_data, val_data, is_remote, local_port)
        >>>         pass  # more initialization of attributes.
        >>>
        >>>     def aggregation(self):
        >>>         # Implement customized aggregation method, which overwrites the default aggregation method.
        >>>         pass
    """
def __init__(self,
             conf,
             test_data=None,
             val_data=None,
             is_remote=False,
             local_port=22999):
    """Store the configuration and initialize the server state.

    Args:
        conf (omegaconf.dictconfig.DictConfig): Configurations of EasyFL.
        test_data (:obj:`FederatedDataset`): Test dataset for testing in server, optional.
        val_data (:obj:`FederatedDataset`): Validation dataset for validation in server, optional.
        is_remote (bool): A flag to indicate whether start remote training.
        local_port (int): The port of remote server service.
    """
    # Inputs provided by the caller.
    self.conf = conf
    self.test_data, self.val_data = test_data, val_data
    self.is_remote, self.local_port = is_remote, local_port

    # Progress flags and round counter; the counter starts at -1 and is incremented per round.
    self._is_training = False
    self._should_stop = False
    self._current_round = -1

    # Model, clients, and the content clients upload.
    self._model = None
    self._compressed_model = None
    self._clients = None
    self._etcd = None
    self._client_uploads = {}
    self.selected_clients = []
    self.grouped_clients = []
    self.client_stubs = {}

    # Timing and metric bookkeeping.
    self._server_metric = None
    self._round_time = None
    self._begin_train_time = None  # training begin time for a round
    self._start_time = None  # training start time for a task
    self._cumulative_times = []  # cumulative training after each test
    self._accuracies = []

    # The default client training time is only defined for distributed training.
    if self.conf.is_distributed:
        self.default_time = self.conf.resource_heterogeneous.initial_default_time

    # Condition used by the remote distribution methods to wait for a notification.
    self._condition = threading.Condition()

    self._tracker = None
    self.init_tracker()
def start(self, model, clients):
    """Run the federated learning task: setup, training rounds, testing, and model saving.

    Args:
        model (nn.Module): The model to train.
        clients (list[:obj:`BaseClient`]|list[str]): Available clients.
            Clients are actually client grpc addresses when in remote training.
    """
    # Setup: record the task start time and reset round state before installing model and clients.
    self._start_time = time.time()
    self._reset()
    self.set_model(model)
    self.set_clients(clients)

    if self._should_track():
        task_conf = OmegaConf.to_container(self.conf)
        self._tracker.create_task(self.conf.task_id, task_conf)

    # When testing on all clients, measure the initial accuracies before any training round.
    if self.conf.server.test_all:
        if self._should_track():
            self._tracker.set_round(self._current_round)
        self.test()
        self.save_tracker()

    while not self.should_stop():
        self._round_time = time.time()
        self._current_round += 1
        self.print_("\n-------- round {} --------".format(self._current_round))

        # Training phase with its pre/post hooks.
        self.pre_train()
        self.train()
        self.post_train()

        # Testing phase, run only on the configured rounds.
        test_now = self._do_every(self.conf.server.test_every,
                                  self._current_round,
                                  self.conf.server.rounds)
        if test_now:
            self.pre_test()
            self.test()
            self.post_test()

        # save_model() decides by itself whether this round should be saved.
        self.save_model()

        round_duration = time.time() - self._round_time
        self.track(metric.ROUND_TIME, round_duration)
        self.save_tracker()

    self.print_("Accuracies: {}".format(rounding(self._accuracies, 4)))
    self.print_("Cumulative training time: {}".format(rounding(self._cumulative_times, 2)))
def stop(self):
    """Request the training loop to stop.

    Only raises the flag; ``should_stop`` reads it on its next call.
    """
    self._should_stop = True
def pre_train(self):
    """Hook called before training in each round; does nothing by default."""
def train(self):
    """Run one round of federated training.

    Selects and groups clients and compresses the model, then distributes the model for
    training and aggregates the results. Only distribution and aggregation are timed;
    the elapsed time is logged and tracked as the training time.
    """
    self.print_("--- start training ---")

    # Preparation steps run before the timer starts.
    self.selection(self._clients, self.conf.server.clients_per_round)
    self.grouping_for_distributed()
    self.compression()

    timer_start = time.time()
    self.distribution_to_train()
    self.aggregation()
    elapsed = time.time() - timer_start

    self.print_("Server train time: {}".format(elapsed))
    self.track(metric.TRAIN_TIME, elapsed)
def post_train(self):
    """Hook called after training in each round; does nothing by default."""
def pre_test(self):
    """Hook called before testing; does nothing by default."""
def test(self):
    """Run testing in the server or in clients, depending on ``conf.test_mode``.

    The elapsed time is stored in the results under ``metric.TEST_TIME`` and the results
    are passed to ``track_test_results``.
    """
    self.print_("--- start testing ---")
    started_at = time.time()

    # Placeholder results; kept when testing happens in the server but this process is not primary.
    results = {metric.TEST_ACCURACY: 0, metric.TEST_LOSS: 0, metric.TEST_TIME: 0}
    if self.conf.test_mode != TEST_IN_SERVER:
        results = self.test_in_client()
    elif self.is_primary_server():
        results = self.test_in_server(self.conf.device)

    results[metric.TEST_TIME] = time.time() - started_at
    self.track_test_results(results)
def post_test(self):
    """Hook called after testing; does nothing by default."""
def should_stop(self):
    """Decide whether training should stop, and clear the training flag if so.

    Training stops when:
        1. ``stop()`` has been called, or
        2. the configured number of rounds (if any) has been reached.
        TODO: Accuracy higher than certain amount.

    Returns:
        bool: A flag to indicate whether should stop training.
    """
    if not self._should_stop:
        max_rounds = self.conf.server.rounds
        # A falsy ``rounds`` means no round limit is applied here.
        if not (max_rounds and self._current_round + 1 >= max_rounds):
            return False
    self._is_training = False
    return True
- def test_in_client(self):
- """Conduct testing in clients.
- Currently, it supports testing on the selected clients for training.
- TODO: Add optionals to select clients for testing.
- Returns:
- dict: Test metrics, {"test_loss": value, "test_accuracy": value, "test_time": value}.
- """
- self.compression()
- self.distribution_to_test()
- return self.aggregation_test()
def test_in_server(self, device=CPU):
    """Evaluate the server model on the server-side test dataset.

    Args:
        device (str): The hardware device to conduct testing, either cpu or cuda devices.

    Returns:
        dict: Test metrics, {"test_loss": value, "test_accuracy": value}.
    """
    model = self._model
    model.eval()
    model.to(device)

    criterion = torch.nn.CrossEntropyLoss().to(device)
    loss_sum = 0
    num_correct = 0
    with torch.no_grad():
        batches = self.test_data.loader(self.conf.server.batch_size, seed=self.conf.seed)
        for inputs, targets in batches:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)
            # Index of the largest output along the last dimension is the predicted class.
            predictions = torch.max(outputs, -1)[1]
            matches = predictions.eq(targets.data.view_as(predictions))
            num_correct += matches.long().cpu().sum()
            loss_sum += batch_loss.item()

        total = self.test_data.size()
        # NOTE(review): the accumulated per-batch loss is divided by the dataset size,
        # not by the number of batches - confirm this is the intended normalization.
        loss_sum /= total
        accuracy = 100.00 * num_correct / total
    return {
        metric.TEST_ACCURACY: float(accuracy),
        metric.TEST_LOSS: float(loss_sum),
    }
- # Client selection
def selection(self, clients, clients_per_round):
    """Choose the clients that participate in the current round.

    Either samples randomly without replacement or takes the first K clients,
    depending on ``conf.server.random_selection``. The result is also stored in
    ``self.selected_clients``.

    Args:
        clients (list[:obj:`BaseClient`]|list[str]): Available clients.
        clients_per_round (int): Number of clients to participate in training each round.

    Returns:
        (list[:obj:`BaseClient`]|list[str]): The selected clients.
    """
    available = len(clients)
    if clients_per_round > available:
        logger.warning("Available clients for selection are smaller than required clients for each round")
    quota = min(clients_per_round, available)

    if not self.conf.server.random_selection:
        self.selected_clients = clients[:quota]
    else:
        # The global numpy seed is set to the round index, so the draw depends only on the round.
        np.random.seed(self._current_round)
        self.selected_clients = np.random.choice(clients, quota, replace=False)
    return self.selected_clients
def grouping_for_distributed(self):
    """Divide the selected clients into groups for distributed training.

    Each group of clients is assigned to conduct training in one GPU.
    The number of groups = the number of gpus.
    Not in distributed training, selected clients are in the same group.
    In distributed training, selected clients are grouped by ``grouping`` with the configured
    strategy, and the group at this process's rank is stored in ``self.grouped_clients``.
    """
    if not self.conf.is_distributed:
        self.grouped_clients = self.selected_clients
        return

    groups = grouping(self.selected_clients,
                      self.conf.distributed.world_size,
                      self.default_time,
                      self.conf.resource_heterogeneous.grouping_strategy,
                      self._current_round)
    # Assign a group for each rank to train with current device.
    rank = self.conf.distributed.rank
    self.grouped_clients = groups[rank]
    grouping_info = [(c.cid, c.round_time) for c in self.grouped_clients]
    logger.info("Grouping Result for rank {}: {}".format(rank, grouping_info))
def compression(self):
    """Prepare the model that is sent to clients.

    The default implementation applies no compression: the compressed model is the
    server model itself (same object).
    """
    uncompressed = self._model
    self._compressed_model = uncompressed
def distribution_to_train(self):
    """Send the model and configurations to the selected clients for training.

    Remote training dispatches requests to remote clients; otherwise training runs locally,
    followed by profiling of client speed when distributed training uses greedy grouping.
    """
    if self.is_remote:
        self.distribution_to_train_remotely()
        return

    self.distribution_to_train_locally()

    # Adaptively update the training time of clients for greedy grouping.
    uses_greedy_grouping = (
        self.conf.is_distributed
        and self.conf.resource_heterogeneous.grouping_strategy == GREEDY_GROUPING
    )
    if uses_greedy_grouping:
        self.profile_training_speed()
        self.update_default_time()
def distribution_to_train_locally(self):
    """Train the clients of this group one after another and collect their uploads.

    For each client, the trained model, its data size, and its metrics are gathered and
    finally stored through ``set_client_uploads_train``.
    """
    models_by_client = {}
    sizes_by_client = {}
    client_metrics = []
    for client in self.grouped_clients:
        # Update client config before training.
        self.conf.client.task_id = self.conf.task_id
        self.conf.client.round_id = self._current_round

        content = client.run_train(self._compressed_model, self.conf.client).content
        trained_model = self.decompression(codec.unmarshal(content.data))

        models_by_client[client.cid] = trained_model
        sizes_by_client[client.cid] = content.data_size
        client_metrics.append(metric.ClientMetric.from_proto(content.metric))
    self.set_client_uploads_train(models_by_client, sizes_by_client, client_metrics)
def distribution_to_train_remotely(self):
    """Distribute training requests to remote clients through multiple threads.

    The main thread waits for signal to proceed. The signal can be triggered via notification,
    as below example.

    Example to trigger signal:
        >>> with self.condition():
        >>>     self.notify_all()
    """
    start_time = time.time()
    should_track = self._tracker is not None and self.conf.client.track
    # Every request carries the same model, so serialize it once rather than once per client.
    marshalled_model = codec.marshal(self._compressed_model)
    pending = {}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        for client in self.grouped_clients:
            request = client_pb.OperateRequest(
                type=client_pb.OP_TYPE_TRAIN,
                model=marshalled_model,
                data_index=client.index,
                config=client_pb.OperateConfig(
                    batch_size=self.conf.client.batch_size,
                    local_epoch=self.conf.client.local_epoch,
                    seed=self.conf.seed,
                    local_test=self.conf.client.local_test,
                    optimizer=client_pb.Optimizer(
                        type=self.conf.client.optimizer.type,
                        lr=self.conf.client.optimizer.lr,
                        momentum=self.conf.client.optimizer.momentum,
                    ),
                    task_id=self.conf.task_id,
                    round_id=self._current_round,
                    track=should_track,
                ),
            )
            future = executor.submit(self._distribution_remotely, client.client_id, request)
            pending[future] = client.client_id

        distribute_time = time.time() - start_time
        self.track(metric.TRAIN_DISTRIBUTE_TIME, distribute_time)
        logger.info("Distribute to clients, time: {}".format(distribute_time))

    # Leaving the executor context waits for all workers, so every future is finished here.
    # Exceptions raised inside a worker are otherwise dropped silently; report them.
    for future, client_id in pending.items():
        error = future.exception()
        if error is not None:
            logger.error("Failed to distribute training request to client {}: {}".format(client_id, error))

    # NOTE(review): wait() has no predicate or timeout; presumably the notifier only fires
    # after this thread is waiting - verify, otherwise a notification could be missed.
    with self._condition:
        self._condition.wait()
def distribution_to_test(self):
    """Dispatch testing to clients, remotely or locally depending on ``self.is_remote``."""
    if not self.is_remote:
        self.distribution_to_test_locally()
        return
    self.distribution_to_test_remotely()
def distribution_to_test_locally(self):
    """Test the clients returned by ``get_test_clients`` one after another.

    Accuracy, loss, data size, and metrics of every client are collected and stored
    through ``set_client_uploads_test``.
    """
    accuracies = []
    losses = []
    data_sizes = []
    client_metrics = []
    for client in self.get_test_clients():
        # Update client config before testing.
        self.conf.client.task_id = self.conf.task_id
        self.conf.client.round_id = self._current_round

        content = client.run_test(self._compressed_model, self.conf.client).content
        performance = codec.unmarshal(content.data)

        accuracies.append(performance.accuracy)
        losses.append(performance.loss)
        data_sizes.append(content.data_size)
        client_metrics.append(metric.ClientMetric.from_proto(content.metric))
    self.set_client_uploads_test(accuracies, losses, data_sizes, client_metrics)
def distribution_to_test_remotely(self):
    """Distribute testing requests to remote clients through multiple threads.

    The main thread waits for signal to proceed. The signal can be triggered via notification,
    as below example.

    Example to trigger signal:
        >>> with self.condition():
        >>>     self.notify_all()
    """
    start_time = time.time()
    should_track = self._tracker is not None and self.conf.client.track
    test_clients = self.get_test_clients()
    # Every request carries the same model, so serialize it once rather than once per client.
    marshalled_model = codec.marshal(self._compressed_model)
    pending = {}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        for client in test_clients:
            request = client_pb.OperateRequest(
                type=client_pb.OP_TYPE_TEST,
                model=marshalled_model,
                data_index=client.index,
                config=client_pb.OperateConfig(
                    batch_size=self.conf.client.batch_size,
                    test_batch_size=self.conf.client.test_batch_size,
                    seed=self.conf.seed,
                    task_id=self.conf.task_id,
                    round_id=self._current_round,
                    track=should_track,
                )
            )
            future = executor.submit(self._distribution_remotely, client.client_id, request)
            pending[future] = client.client_id

        distribute_time = time.time() - start_time
        self.track(metric.TEST_DISTRIBUTE_TIME, distribute_time)
        logger.info("Distribute to test clients, time: {}".format(distribute_time))

    # Leaving the executor context waits for all workers, so every future is finished here.
    # Exceptions raised inside a worker are otherwise dropped silently; report them.
    for future, client_id in pending.items():
        error = future.exception()
        if error is not None:
            logger.error("Failed to distribute testing request to client {}: {}".format(client_id, error))

    # NOTE(review): wait() has no predicate or timeout; presumably the notifier only fires
    # after this thread is waiting - verify, otherwise a notification could be missed.
    with self._condition:
        self._condition.wait()
def get_test_clients(self):
    """Get clients to run testing.

    When ``conf.server.test_all`` is set, all clients are tested (in distributed training,
    only the group assigned to this rank). Otherwise the clients grouped for training are
    tested, falling back to all clients when no client has been grouped yet.

    Returns:
        (list[:obj:`BaseClient`]|list[str]): Clients to test.
    """
    if self.conf.server.test_all:
        if not self.conf.is_distributed:
            return self._clients
        # Group and assign clients to different hardware devices to test.
        groups = grouping(self._clients,
                          self.conf.distributed.world_size,
                          default_time=self.default_time,
                          strategy=self.conf.resource_heterogeneous.grouping_strategy)
        return groups[self.conf.distributed.rank]

    # ``grouped_clients`` starts as an empty list, so a None check alone never triggers the
    # fallback. ``len`` is used instead of truthiness because the selection may be a numpy array.
    grouped = self.grouped_clients
    if grouped is not None and len(grouped) > 0:
        return grouped
    return self._clients
def _distribution_remotely(self, cid, request):
    """Send an operation request to one client through its gRPC stub and log the outcome.

    Args:
        cid (str): Client id.
        request (:obj:`OperateRequest`): gRPC request of specific operations.
    """
    stub = self.client_stubs[cid]
    response = stub.Operate(request)
    if response.status.code == common_pb.SC_OK:
        logger.info("Distribute to train/test remotely successfully, client: {}".format(cid))
        return
    logger.error("Failed to train/test in client {}, error: {}".format(cid, response.status.message))
def aggregation_test(self):
    """Combine the testing results uploaded by clients into a single loss and accuracy.

    ``conf.test_method`` selects a plain mean ("average") or a mean weighted by the
    clients' test data sizes ("weighted").

    Returns:
        dict: Test metrics, format in {"test_loss": value, "test_accuracy": value}

    Raises:
        ValueError: If ``conf.test_method`` is neither "average" nor "weighted".
    """
    uploads = self._client_uploads
    accuracies, losses, test_sizes = uploads[ACCURACY], uploads[LOSS], uploads[DATA_SIZE]

    method = self.conf.test_method
    if method == "average":
        loss = self._mean_value(losses)
        accuracy = self._mean_value(accuracies)
    elif method == "weighted":
        loss = self._weighted_value(losses, test_sizes)
        accuracy = self._weighted_value(accuracies, test_sizes)
    else:
        raise ValueError("test_method not supported, please use average or weighted")

    return {
        metric.TEST_ACCURACY: float(accuracy),
        metric.TEST_LOSS: float(loss),
    }
def _mean_value(self, values):
    """Return the mean of ``values``; in distributed training it is computed by ``reduce_values``."""
    if not self.conf.is_distributed:
        return np.mean(values)
    return reduce_values(values, self.conf.device)
def _weighted_value(self, values, weights):
    """Return the weighted average of ``values``; in distributed training it is computed by ``reduce_weighted_values``."""
    if not self.conf.is_distributed:
        return np.average(values, weights=weights)
    return reduce_weighted_values(values, weights, self.conf.device)
def decompression(self, model):
    """Decompress a model uploaded by a client.

    The default implementation is a pass-through and returns the given model unchanged.
    """
    decompressed = model
    return decompressed
def aggregation(self):
    """Aggregate the training updates uploaded by clients into the server model.

    The uploaded models and data sizes are passed to ``aggregate`` and the result is
    loaded into the server model through its state dict.
    """
    uploads = self.get_client_uploads()
    client_models = [*uploads[MODEL].values()]
    client_weights = [*uploads[DATA_SIZE].values()]
    aggregated = self.aggregate(client_models, client_weights)
    self.set_model(aggregated, load_dict=True)
def aggregate(self, models, weights):
    """Aggregate the models uploaded from clients.

    Args:
        models (list[nn.Module]): List of models.
        weights (list[float]): List of weights, corresponding to each model.
            Weights are dataset size of clients by default.

    Returns:
        nn.Module: Aggregated model.
    """
    # Equal averaging ignores the provided weights and gives every model weight 1.
    if self.conf.server.aggregation_strategy == EQUAL_AVERAGE:
        weights = [1] * len(models)

    # Choose the functions that handle either the whole model or only its parameters.
    if self.conf.server.aggregation_content == AGGREGATION_CONTENT_PARAMS:
        fn_average = strategies.federated_averaging_only_params
        fn_sum = strategies.weighted_sum_only_params
        fn_reduce = reduce_models_only_params
    else:
        fn_average = strategies.federated_averaging
        fn_sum = strategies.weighted_sum
        fn_reduce = reduce_models

    if not self.conf.is_distributed:
        return fn_average(models, weights)

    # Distributed: sum locally after a barrier, then reduce the model and the sample sum.
    dist.barrier()
    model, sample_sum = fn_sum(models, weights)
    fn_reduce(model, torch.tensor(sample_sum).to(self.conf.device))
    return model
def _reset(self):
    """Reset round counter and flags to the state expected at the start of a task."""
    self._is_training = True
    self._should_stop = False
    self._current_round = -1
def is_training(self):
    """Report the training flag of the server.

    Returns:
        bool: A flag to indicate whether server is in training.
    """
    training_flag = self._is_training
    return training_flag
def set_model(self, model, load_dict=False):
    """Replace or update the server model.

    Args:
        model (nn.Module): New model.
        load_dict (bool): When True, load the state dict of ``model`` into the existing
            server model; otherwise store a deep copy of ``model``.
    """
    if not load_dict:
        self._model = copy.deepcopy(model)
        return
    self._model.load_state_dict(model.state_dict())
def set_clients(self, clients):
    """Store the available clients of the server.

    Args:
        clients (list): Available clients; stored as given, without copying.
    """
    self._clients = clients
def num_of_clients(self):
    """Return the number of available clients.

    Returns:
        int: Length of the stored client collection.
    """
    total = len(self._clients)
    return total
def save_model(self):
    """Save the state dict of the server model on the configured rounds.

    Only the primary server saves. An empty ``conf.server.save_model_path`` falls back to
    ``<current working directory>/saved_models``.
    """
    due = self._do_every(self.conf.server.save_model_every, self._current_round, self.conf.server.rounds)
    if not (due and self.is_primary_server()):
        return

    directory = self.conf.server.save_model_path
    if directory == "":
        directory = os.path.join(os.getcwd(), "saved_models")
    os.makedirs(directory, exist_ok=True)

    file_name = "{}_global_model_r_{}.pth".format(self.conf.task_id, self._current_round)
    save_path = os.path.join(directory, file_name)
    # cpu() is called on the model before its state dict is taken.
    torch.save(self._model.cpu().state_dict(), save_path)
    self.print_("Model saved at {}".format(save_path))
def set_client_uploads_train(self, models, weights, metrics=None):
    """Store the training updates uploaded from clients.

    Args:
        models (dict): A collection of models.
        weights (dict): A collection of weights.
        metrics (dict): Client training metrics.
    """
    for key, value in ((MODEL, models), (DATA_SIZE, weights)):
        self.set_client_uploads(key, value)

    # NOTE(review): when gathering is enabled, the ``metrics`` argument is replaced by the
    # result of gather_client_train_metrics(), which reads the previously stored metrics -
    # confirm this is intended (that method carries a TODO about gathering).
    if self._should_gather_metrics():
        client_metrics = self.gather_client_train_metrics()
    else:
        client_metrics = metrics
    self.set_client_uploads(CLIENT_METRICS, client_metrics)
def set_client_uploads_test(self, accuracies, losses, test_sizes, metrics=None):
    """Store the testing results uploaded from clients.

    When metrics are gathered and client metrics were stored before, the stored (training)
    metrics are merged into the given testing metrics before they are stored.

    Args:
        accuracies (list[float]): Testing accuracies of clients.
        losses (list[float]): Testing losses of clients.
        test_sizes (list[float]): Test dataset sizes of clients.
        metrics (dict): Client testing metrics.
    """
    for key, value in ((ACCURACY, accuracies), (LOSS, losses), (DATA_SIZE, test_sizes)):
        self.set_client_uploads(key, value)

    client_metrics = metrics
    if self._should_gather_metrics() and CLIENT_METRICS in self._client_uploads:
        stored_train_metrics = self.get_client_uploads()[CLIENT_METRICS]
        client_metrics = metric.ClientMetric.merge_train_to_test_metrics(stored_train_metrics, metrics)
    self.set_client_uploads(CLIENT_METRICS, client_metrics)
def set_client_uploads(self, key, value):
    """Store one entry of client-uploaded content, replacing any previous value of the key.

    Args:
        key (str): Dictionary key.
        value (*): Uploaded content.
    """
    self._client_uploads.update({key: value})
def get_client_uploads(self):
    """Return the content uploaded by clients.

    Returns:
        dict: The internal dictionary of uploaded contents (not a copy).
    """
    uploads = self._client_uploads
    return uploads
def _do_every(self, every, current_round, rounds):
    """Check whether a periodic action is due in the current round.

    Args:
        every (int): Period of the action, in rounds.
        current_round (int): Zero-based index of the current round.
        rounds (int): Total number of rounds.

    Returns:
        bool: True on every ``every``-th round (counting from 1) and on the last round.
    """
    completed = current_round + 1
    if completed % every == 0:
        return True
    return completed == rounds
def print_(self, content):
    """Log ``content`` at info level, but only on the primary server.

    Args:
        content (str): The content to log.
    """
    if not self.is_primary_server():
        return
    logger.info(content)
def is_primary_server(self):
    """Tell whether this process is the primary server.

    In standalone or remote training, the server is primary.
    In distributed training, the server on rank0 is primary.

    Returns:
        bool: A flag to indicate whether current process is the primary server.
    """
    if self.conf.is_distributed:
        return self.conf.distributed.rank == 0
    return True
- # Functions for remote training
def start_service(self):
    """Start the gRPC service of the server; a no-op unless the server is remote."""
    if not self.is_remote:
        return
    port = self.local_port
    grpc_wrapper.start_service(grpc_wrapper.TYPE_SERVER, ServerService(self), port)
    logger.info("GRPC server started at :{}".format(self.local_port))
def connect_remote_clients(self, clients):
    """Create gRPC stubs for the clients that do not have one yet.

    Args:
        clients (list): Clients exposing ``client_id`` and ``address``.
    """
    # TODO: This client should be consistent with client started separately.
    for client in clients:
        if client.client_id in self.client_stubs:
            continue
        stub = grpc_wrapper.init_stub(grpc_wrapper.TYPE_CLIENT, client.address)
        self.client_stubs[client.client_id] = stub
        logger.info("Successfully connected to gRPC client {}".format(client.address))
def init_etcd(self, addresses):
    """Create the etcd client used as the registry for client registration.

    Args:
        addresses (str): The etcd addresses split by ","
    """
    etcd_client = EtcdClient("server", addresses, "backends")
    self._etcd = etcd_client
def start_remote_training(self, model, clients):
    """Start federated learning in the remote training mode.

    The server first establishes gRPC connections with the clients that are not connected
    yet, then starts the training process.

    Args:
        model (nn.Module): The model to train.
        clients (list[str]): Client addresses.
    """
    self.connect_remote_clients(clients)
    # All clients are connected at this point; hand over to the regular training flow.
    self.start(model, clients)
- # Functions for tracking
def init_tracker(self):
    """Create the tracker when tracking is enabled in ``conf.server.track``."""
    if not self.conf.server.track:
        return
    self._tracker = init_tracking(self.conf.tracking.database, self.conf.tracker_addr)
def track(self, metric_name, value):
    """Record a metric of the current round, if this server should track.

    Args:
        metric_name (str): Name of the metric of a round.
        value (str|int|float|bool|dict|list): Value of the metric.
    """
    if self._should_track():
        self._tracker.track_round(metric_name, value)
def track_test_results(self, results):
    """Record test results: append to the history, track every metric, and log a summary.

    Args:
        results (dict): Test metrics, format in
            {"test_loss": value, "test_accuracy": value, "test_time": value}
    """
    elapsed_since_start = time.time() - self._start_time
    self._cumulative_times.append(elapsed_since_start)
    self._accuracies.append(results[metric.TEST_ACCURACY])

    for name, value in results.items():
        self.track(name, value)

    summary = 'Test time {:.2f}s, Test loss: {:.2f}, Test accuracy: {:.2f}%'.format(
        results[metric.TEST_TIME], results[metric.TEST_LOSS], results[metric.TEST_ACCURACY])
    self.print_(summary)
def save_tracker(self):
    """Save metrics in the tracker to database.

    Does nothing without a tracker. Otherwise the communication cost is tracked, the
    primary server saves the round metrics, and client metrics are saved when present.
    """
    if not self._tracker:
        return

    self.track_communication_cost()
    if self.is_primary_server():
        self._tracker.save_round()

    # In distributed training, each server saves their clients separately.
    # Client metrics may be absent (e.g. nothing uploaded yet); skip them instead of
    # failing with a KeyError.
    client_metrics = self._client_uploads.get(CLIENT_METRICS)
    if client_metrics is not None:
        self._tracker.save_clients(client_metrics)
def track_communication_cost(self):
    """Track communication cost among server and clients.

    Communication cost occurs in `training` and `testing` with downlink and uplink costs.
    Only metrics of the current round and task are summed. In distributed training, each
    sum is passed through ``reduce_value`` before it is tracked.
    """
    train_upload_size = 0
    train_download_size = 0
    test_upload_size = 0
    test_download_size = 0

    # Client metrics may be missing or None (e.g. nothing uploaded yet); treat both as empty
    # instead of raising KeyError/TypeError. The sizes are still tracked (as zeros).
    client_metrics = self._client_uploads.get(CLIENT_METRICS) or []
    for client_metric in client_metrics:
        if client_metric.round_id == self._current_round and client_metric.task_id == self.conf.task_id:
            train_upload_size += client_metric.train_upload_size
            train_download_size += client_metric.train_download_size
            test_upload_size += client_metric.test_upload_size
            test_download_size += client_metric.test_download_size

    if self.conf.is_distributed:
        train_upload_size = reduce_value(train_upload_size, self.conf.device).item()
        train_download_size = reduce_value(train_download_size, self.conf.device).item()
        test_upload_size = reduce_value(test_upload_size, self.conf.device).item()
        test_download_size = reduce_value(test_download_size, self.conf.device).item()

    self._tracker.track_round(metric.TRAIN_UPLOAD_SIZE, train_upload_size)
    self._tracker.track_round(metric.TRAIN_DOWNLOAD_SIZE, train_download_size)
    self._tracker.track_round(metric.TEST_UPLOAD_SIZE, test_upload_size)
    self._tracker.track_round(metric.TEST_DOWNLOAD_SIZE, test_download_size)
def _should_track(self):
    """Tell whether this server should track metrics.

    Metrics are tracked only when a tracker exists and this is the primary server.

    Returns:
        bool: A flag indicate whether server should track metrics.
    """
    if self._tracker is None:
        return False
    return self.is_primary_server()
def _should_gather_metrics(self):
    """Check whether the server should gather metrics from GPUs.

    Gather metrics only when testing all in `distributed` training with a tracker.
    Testing all resets clients' training metrics, thus,
    server needs to gather train metrics to construct full client metrics.

    Returns:
        bool: A flag indicate whether server should gather metrics.
    """
    # Wrap in bool() so the documented boolean is returned rather than the tracker object.
    return bool(self.conf.is_distributed and self.conf.server.test_all and self._tracker)
def gather_client_train_metrics(self):
    """Gather client train metrics from other ranks for distributed training, when testing all clients (test_all).

    When testing all clients, the trained metrics may be overridden by the test metrics
    because clients may be placed in different GPUs in training and testing, leading to losses of train metrics.
    So we gather train metrics and set them in test metrics.

    The metrics are read from the currently stored client uploads, gathered value by value
    with ``gather_value``, and rebuilt as new ``ClientMetric`` objects paired by position
    with the ids of the selected clients.

    TODO: gather is not progressing. Need fix.

    Returns:
        list: Rebuilt client metrics, one per selected client.
    """
    world_size = self.conf.distributed.world_size
    device = self.conf.device
    uploads = self.get_client_uploads()
    client_id_list = []
    train_accuracy_list = []
    train_loss_list = []
    train_time_list = []
    train_upload_time_list = []
    train_upload_size_list = []
    train_download_size_list = []

    # One gather_value call per metric field and per stored client metric.
    # NOTE(review): presumably these are collective calls that every rank must issue in
    # the same order - confirm against gather_value before reordering anything here.
    for m in uploads[CLIENT_METRICS]:
        # client_id_list += gather_value(m.client_id, world_size, device).tolist()
        train_accuracy_list += gather_value(m.train_accuracy, world_size, device)
        train_loss_list += gather_value(m.train_loss, world_size, device)
        train_time_list += gather_value(m.train_time, world_size, device)
        train_upload_time_list += gather_value(m.train_upload_time, world_size, device)
        train_upload_size_list += gather_value(m.train_upload_size, world_size, device)
        train_download_size_list += gather_value(m.train_download_size, world_size, device)

    metrics = []
    # Note: Client id may not match with its training stats because all_gather string is not supported.
    client_id_list = [c.cid for c in self.selected_clients]
    # The gathered lists are indexed by position in the selected clients.
    # NOTE(review): nothing here checks that the gathered lists are at least that long.
    for i, client_id in enumerate(client_id_list):
        m = metric.ClientMetric(self.conf.task_id, self._current_round, client_id)
        m.add(metric.TRAIN_ACCURACY, train_accuracy_list[i])
        m.add(metric.TRAIN_LOSS, train_loss_list[i])
        m.add(metric.TRAIN_TIME, train_time_list[i])
        m.add(metric.TRAIN_UPLOAD_TIME, train_upload_time_list[i])
        m.add(metric.TRAIN_UPLOAD_SIZE, train_upload_size_list[i])
        m.add(metric.TRAIN_DOWNLOAD_SIZE, train_download_size_list[i])
        metrics.append(m)
    return metrics
- # Functions for remote training.
def condition(self):
    """Return the condition object the remote distribution methods wait on.

    Returns:
        threading.Condition: The condition created in ``__init__``.
    """
    shared_condition = self._condition
    return shared_condition
def notify_all(self):
    """Wake up all threads waiting on the server condition.

    Must be called while holding the condition, e.g. ``with self.condition(): self.notify_all()``.
    """
    waiting = self._condition
    waiting.notify_all()
- # Functions for distributed training optimization.
def profile_training_speed(self):
    """Update the estimated round time of selected clients that are not profiled yet.

    The train times of those clients are summed across processes with ``all_reduce``.
    A client whose estimate is still 0, or whose new time differs enough from the old
    estimate, gets the new time; otherwise it is marked as profiled.
    """
    pending = [client for client in self.selected_clients if not client.profiled]
    if not pending:
        return

    device = self.conf.device
    previous_times = torch.FloatTensor([c.round_time for c in pending]).to(device)
    measured_times = torch.FloatTensor([c.train_time for c in pending]).to(device)

    dist.barrier()
    dist.all_reduce(measured_times)

    for client, old_round_time, current_round_time in zip(pending, previous_times, measured_times):
        needs_update = old_round_time == 0 or self._should_update_round_time(old_round_time,
                                                                             current_round_time)
        if needs_update:
            client.round_time = float(current_round_time)
            client.train_time = 0
        else:
            client.profiled = True
def update_default_time(self):
    """Blend the mean round time of the selected clients into the default training time.

    The new default is ``momentum * mean + old_default * (1 - momentum)`` with the momentum
    from ``conf.resource_heterogeneous.default_time_momentum``.
    """
    default_momentum = self.conf.resource_heterogeneous.default_time_momentum
    round_times = [float(client.round_time) for client in self.selected_clients]
    current_round_average = np.mean(round_times)
    self.default_time = default_momentum * current_round_average + self.default_time * (1 - default_momentum)
def _should_update_round_time(self, old_round_time, new_round_time, threshold=0.3):
    """Check whether assign a new estimated round time to client or set it to profiled.

    The difference between both times is measured relative to the smaller one.

    Args:
        old_round_time (float): previous estimated round time.
        new_round_time (float): Currently profiled round time.
        threshold (float): Tolerance threshold of difference between old and new times.

    Returns:
        bool: A flag to indicate whether should update round time.
    """
    if new_round_time < old_round_time:
        larger, smaller = old_round_time, new_round_time
    else:
        larger, smaller = new_round_time, old_round_time

    # Dividing by a zero time raises ZeroDivisionError for plain floats. A change away from
    # zero is treated as an update, while two zero times are treated as no change.
    if smaller == 0:
        return bool(larger != 0)
    return bool((larger - smaller) / smaller >= threshold)
|