123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471 |
- import argparse
- import copy
- import logging
- import time
- import torch
- from easyfl.client.service import ClientService
- from easyfl.communication import grpc_wrapper
- from easyfl.distributed.distributed import CPU
- from easyfl.pb import common_pb2 as common_pb
- from easyfl.pb import server_service_pb2 as server_pb
- from easyfl.protocol import codec
- from easyfl.tracking import metric
- from easyfl.tracking.client import init_tracking
- from easyfl.tracking.evaluation import model_size
# Module-level logger shared by the client helpers and BaseClient below.
logger = logging.getLogger(name=__name__)
def _str_to_bool(value):
    """Parse a command-line string into a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``, so boolean options need an explicit converter.

    Args:
        value (str|bool): Raw command-line value.

    Returns:
        bool: The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean literal.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value, got {!r}".format(value))


def create_argument_parser():
    """Create argument parser with arguments/configurations for starting remote client service.

    ``--is-remote`` accepts an explicit boolean literal (``True``/``False``,
    ``yes``/``no``, ``1``/``0``, ...) or may be given bare, which means ``True``.

    Returns:
        argparse.ArgumentParser: Parser with client service arguments.
    """
    parser = argparse.ArgumentParser(description='Federated Client')
    parser.add_argument('--local-port',
                        type=int,
                        default=23000,
                        help='Listen port of the client')
    parser.add_argument('--server-addr',
                        type=str,
                        default="localhost:22999",
                        help='Address of server in [IP]:[PORT] format')
    parser.add_argument('--tracker-addr',
                        type=str,
                        default="localhost:12666",
                        help='Address of tracking service in [IP]:[PORT] format')
    # type=bool would turn "--is-remote False" into True; parse the literal instead.
    parser.add_argument('--is-remote',
                        type=_str_to_bool,
                        nargs='?',
                        const=True,
                        default=False,
                        help='Whether start as a remote client.')
    return parser
class BaseClient(object):
    """Default implementation of federated learning client.

    Args:
        cid (str): Client id.
        conf (omegaconf.dictconfig.DictConfig): Client configurations.
        train_data (:obj:`FederatedDataset`): Training dataset.
        test_data (:obj:`FederatedDataset`): Test dataset.
        device (str): Hardware device for training, cpu or cuda devices.
        sleep_time (float): Seconds to sleep at the start of each local training run
            (see ``simulate_straggler``) to simulate stragglers.
        is_remote (bool): Whether start remote training.
        local_port (int): Port of remote client service.
        server_addr (str): Remote server service grpc address.
        tracker_addr (str): Remote tracking service grpc address.

    Override the class and functions to implement customized client.

    Example:
        >>> from easyfl.client import BaseClient
        >>> class CustomizedClient(BaseClient):
        >>>     def __init__(self, cid, conf, train_data, test_data, device, **kwargs):
        >>>         super(CustomizedClient, self).__init__(cid, conf, train_data, test_data, device, **kwargs)
        >>>         pass  # more initialization of attributes.
        >>>
        >>>     def train(self, conf, device=CPU):
        >>>         # Implement customized client training method, which overwrites the default training method.
        >>>         pass
    """

    def __init__(self,
                 cid,
                 conf,
                 train_data,
                 test_data,
                 device,
                 sleep_time=0,
                 is_remote=False,
                 local_port=23000,
                 server_addr="localhost:22999",
                 tracker_addr="localhost:12666"):
        self.cid = cid
        self.conf = conf
        self.train_data = train_data
        self.train_loader = None  # created lazily by pretrain_setup()
        self.test_data = test_data
        self.test_loader = None  # created lazily by test()
        self.device = device

        self.round_time = 0
        self.train_time = 0
        self.test_time = 0

        self.train_accuracy = []
        self.train_loss = []  # one averaged loss per local epoch
        self.test_accuracy = 0
        self.test_loss = 0

        self.profiled = False
        self._sleep_time = sleep_time

        self.compressed_model = None
        self.model = None
        self._upload_holder = server_pb.UploadContent()

        self.is_remote = is_remote
        self.local_port = local_port
        self._server_addr = server_addr
        self._tracker_addr = tracker_addr
        self._server_stub = None  # gRPC stub, created on first connect_to_server()
        self._tracker = None
        self._is_train = True  # whether the current/most recent operation is training

        self._ensure_tracker(conf)

    def _ensure_tracker(self, conf):
        """Initialize the tracking client on demand when ``conf.track`` is enabled.

        Tracking may be switched on by a later configuration than the one given to
        ``__init__``; without this, ``set_client_context`` would be called on ``None``.
        """
        if conf.track and self._tracker is None:
            self._tracker = init_tracking(init_store=False)

    def run_train(self, model, conf):
        """Conduct training on clients.

        Args:
            model (nn.Module): Model to train.
            conf (omegaconf.dictconfig.DictConfig): Client configurations.

        Returns:
            :obj:`UploadRequest`: Training contents. Unify the interface for both local and remote operations.
            Remote clients send the request over gRPC instead and return ``None``.
        """
        self.conf = conf
        self._ensure_tracker(conf)
        if conf.track:
            self._tracker.set_client_context(conf.task_id, conf.round_id, self.cid)

        self._is_train = True

        self.download(model)
        self.track(metric.TRAIN_DOWNLOAD_SIZE, model_size(model))

        self.decompression()

        self.pre_train()
        self.train(conf, self.device)
        self.post_train()

        self.track(metric.TRAIN_ACCURACY, self.train_accuracy)
        self.track(metric.TRAIN_LOSS, self.train_loss)
        self.track(metric.TRAIN_TIME, self.train_time)

        if conf.local_test:
            self.test_local()

        self.compression()

        self.track(metric.TRAIN_UPLOAD_SIZE, model_size(self.compressed_model))

        self.encryption()

        return self.upload()

    def run_test(self, model, conf):
        """Conduct testing on clients.

        Args:
            model (nn.Module): Model to test.
            conf (omegaconf.dictconfig.DictConfig): Client configurations.

        Returns:
            :obj:`UploadRequest`: Testing contents. Unify the interface for both local and remote operations.
            Remote clients send the request over gRPC instead and return ``None``.
        """
        self.conf = conf
        self._ensure_tracker(conf)
        if conf.track:
            # Reset the client context only when the previous operation was also a test.
            reset = not self._is_train
            self._tracker.set_client_context(conf.task_id, conf.round_id, self.cid, reset_client=reset)

        self._is_train = False

        self.download(model)
        self.track(metric.TEST_DOWNLOAD_SIZE, model_size(model))

        self.decompression()

        self.pre_test()
        self.test(conf, self.device)
        self.post_test()

        self.track(metric.TEST_ACCURACY, float(self.test_accuracy))
        self.track(metric.TEST_LOSS, float(self.test_loss))
        self.track(metric.TEST_TIME, self.test_time)

        return self.upload()

    def download(self, model):
        """Download model from the server.

        The first call deep-copies ``model``; later calls load its weights into the
        existing local copy in place.

        Args:
            model (nn.Module): Global model distributed from the server.
        """
        # Compare against None: container modules (e.g. nn.Sequential) define __len__,
        # so truthiness is not a reliable "already downloaded" check.
        if self.compressed_model is not None:
            self.compressed_model.load_state_dict(model.state_dict())
        else:
            self.compressed_model = copy.deepcopy(model)

    def decompression(self):
        """Decompressed model. It can be further implemented when the model is compressed in the server."""
        self.model = self.compressed_model

    def pre_train(self):
        """Preprocessing before training."""
        pass

    def train(self, conf, device=CPU):
        """Execute client training.

        Records the mean batch loss of every local epoch in ``self.train_loss`` and the
        wall-clock duration in ``self.train_time``.

        Args:
            conf (omegaconf.dictconfig.DictConfig): Client configurations.
            device (str): Hardware device for training, cpu or cuda devices.
        """
        start_time = time.time()
        loss_fn, optimizer = self.pretrain_setup(conf, device)
        self.train_loss = []
        for i in range(conf.local_epoch):
            batch_loss = []
            for batched_x, batched_y in self.train_loader:
                x, y = batched_x.to(device), batched_y.to(device)
                optimizer.zero_grad()
                out = self.model(x)
                loss = loss_fn(out, y)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            current_epoch_loss = sum(batch_loss) / len(batch_loss)
            self.train_loss.append(float(current_epoch_loss))
            logger.debug("Client {}, local epoch: {}, loss: {}".format(self.cid, i, current_epoch_loss))
        self.train_time = time.time() - start_time
        logger.debug("Client {}, Train Time: {}".format(self.cid, self.train_time))

    def post_train(self):
        """Postprocessing after training."""
        pass

    def pretrain_setup(self, conf, device):
        """Setup loss function and optimizer before training.

        Also applies the straggler delay, switches the model to train mode on ``device``
        and builds the training loader on first use.

        Returns:
            tuple: ``(loss_fn, optimizer)``.
        """
        self.simulate_straggler()
        self.model.train()
        self.model.to(device)
        loss_fn = self.load_loss_fn(conf)
        optimizer = self.load_optimizer(conf)
        if self.train_loader is None:
            self.train_loader = self.load_loader(conf)
        return loss_fn, optimizer

    def load_loss_fn(self, conf):
        """Load the loss function used for training and testing (cross entropy by default)."""
        return torch.nn.CrossEntropyLoss()

    def load_optimizer(self, conf):
        """Load training optimizer. Implemented Adam and SGD.

        Any ``conf.optimizer.type`` other than ``"Adam"`` falls back to SGD.
        """
        if conf.optimizer.type == "Adam":
            optimizer = torch.optim.Adam(self.model.parameters(), lr=conf.optimizer.lr)
        else:
            optimizer = torch.optim.SGD(self.model.parameters(),
                                        lr=conf.optimizer.lr,
                                        momentum=conf.optimizer.momentum,
                                        weight_decay=conf.optimizer.weight_decay)
        return optimizer

    def load_loader(self, conf):
        """Load the training data loader.

        Args:
            conf (omegaconf.dictconfig.DictConfig): Client configurations.

        Returns:
            torch.utils.data.DataLoader: Data loader.
        """
        return self.train_data.loader(conf.batch_size, self.cid, shuffle=True, seed=conf.seed)

    def test_local(self):
        """Test client local model after training."""
        pass

    def pre_test(self):
        """Preprocessing before testing."""
        pass

    def test(self, conf, device=CPU):
        """Execute client testing.

        Sets ``self.test_loss`` to the per-sample average loss, ``self.test_accuracy`` to
        the accuracy in percent and ``self.test_time`` to the duration, then moves the
        model back to CPU.

        Args:
            conf (omegaconf.dictconfig.DictConfig): Client configurations.
            device (str): Hardware device for training, cpu or cuda devices.
        """
        begin_test_time = time.time()
        self.model.eval()
        self.model.to(device)
        loss_fn = self.load_loss_fn(conf)
        if self.test_loader is None:
            self.test_loader = self.test_data.loader(conf.test_batch_size, self.cid, shuffle=False, seed=conf.seed)
        self.test_loss = 0
        correct = 0
        with torch.no_grad():
            for batched_x, batched_y in self.test_loader:
                x = batched_x.to(device)
                y = batched_y.to(device)
                log_probs = self.model(x)
                loss = loss_fn(log_probs, y)
                _, y_pred = torch.max(log_probs, -1)
                correct += int(y_pred.eq(y.view_as(y_pred)).long().sum().item())
                # The default loss (CrossEntropyLoss, reduction='mean') returns the batch mean;
                # weight it by the batch size so dividing by test_size below gives a
                # per-sample average. Assumes load_loss_fn returns a mean-reduced loss.
                self.test_loss += loss.item() * y.size(0)
        test_size = self.test_data.size(self.cid)
        self.test_loss /= test_size
        self.test_accuracy = 100.0 * float(correct) / test_size

        logger.debug('Client {}, testing -- Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
            self.cid, self.test_loss, correct, test_size, self.test_accuracy))

        self.test_time = time.time() - begin_test_time
        self.model = self.model.cpu()

    def post_test(self):
        """Postprocessing after testing."""
        pass

    def encryption(self):
        """Encrypt the client local model."""
        pass

    def compression(self):
        """Compress the client local model after training and before uploading to the server."""
        self.compressed_model = self.model

    def upload(self):
        """Upload the messages from client to the server.

        Returns:
            :obj:`UploadRequest`: The upload request defined in protobuf to unify local and remote operations.
                Only applicable for local training; remote clients upload through a gRPC request
                and return ``None``.
        """
        request = self.construct_upload_request()
        if not self.is_remote:
            self.post_upload()
            return request

        self.upload_remotely(request)
        self.post_upload()
        return None

    def post_upload(self):
        """Postprocessing after uploading training/testing results."""
        pass

    def construct_upload_request(self):
        """Construct client upload request for training updates and testing results.

        After training the payload is a deep copy of the compressed model; after testing
        it is the accuracy/loss performance record. ``data_size`` falls back to 1 when
        the dataset has no entry for this client (``KeyError``) or, when testing, when
        there is no test dataset.

        Returns:
            :obj:`UploadRequest`: The upload request defined in protobuf to unify local and remote operations.
        """
        if self._is_train:
            data = codec.marshal(copy.deepcopy(self.compressed_model))
            typ = common_pb.DATA_TYPE_PARAMS
            try:
                data_size = self.train_data.size(self.cid)
            except KeyError:
                data_size = 1
        else:
            data = codec.marshal(server_pb.Performance(accuracy=self.test_accuracy, loss=self.test_loss))
            typ = common_pb.DATA_TYPE_PERFORMANCE
            try:
                data_size = 1 if not self.test_data else self.test_data.size(self.cid)
            except KeyError:
                data_size = 1

        m = self._tracker.get_client_metric().to_proto() if self._tracker else common_pb.ClientMetric()

        return server_pb.UploadRequest(
            task_id=self.conf.task_id,
            round_id=self.conf.round_id,
            client_id=self.cid,
            content=server_pb.UploadContent(
                data=data,
                type=typ,
                data_size=data_size,
                metric=m,
            ),
        )

    def upload_remotely(self, request):
        """Send upload request to remote server via gRPC.

        Args:
            request (:obj:`UploadRequest`): Upload request.
        """
        start_time = time.time()

        self.connect_to_server()
        resp = self._server_stub.Upload(request)

        upload_time = time.time() - start_time
        m = metric.TRAIN_UPLOAD_TIME if self._is_train else metric.TEST_UPLOAD_TIME
        self.track(m, upload_time)

        logger.info("client upload time: {}s".format(upload_time))
        if resp.status.code == common_pb.SC_OK:
            logger.info("Uploaded remotely to the server successfully\n")
        else:
            logger.error("Failed to upload, code: {}, message: {}\n".format(resp.status.code, resp.status.message))

    def start_service(self):
        """Start client service (only when running as a remote client)."""
        if self.is_remote:
            grpc_wrapper.start_service(grpc_wrapper.TYPE_CLIENT, ClientService(self), self.local_port)

    def connect_to_server(self):
        """Establish connection between the client and the server (once, remote clients only)."""
        if self.is_remote and self._server_stub is None:
            self._server_stub = grpc_wrapper.init_stub(grpc_wrapper.TYPE_SERVER, self._server_addr)
            logger.info("Successfully connected to gRPC server {}".format(self._server_addr))

    def operate(self, model, conf, index, is_train=True):
        """A wrapper over operations (training/testing) on clients.

        Args:
            model (nn.Module): Model for operations.
            conf (omegaconf.dictconfig.DictConfig): Client configurations.
            index (int): Client index in the client list, for retrieving data. TODO: improvement.
            is_train (bool): The flag to indicate whether the operation is training, otherwise testing.

        Returns:
            The result of ``run_train``/``run_test``, or ``None`` when ``index`` is out of range.
        """
        try:
            # Switch this client's identity to the user at `index` in the training data.
            self.cid = self.train_data.users[index]
        except IndexError:
            logger.error("Data index exceed the available data, abort training")
            return None

        # Tracker initialization is driven by the incoming `conf` inside run_train/run_test
        # (checking the stale self.conf here could leave the tracker uninitialized).
        if is_train:
            logger.info("Train on data index {}, client: {}".format(index, self.cid))
            return self.run_train(model, conf)
        logger.info("Test on data index {}, client: {}".format(index, self.cid))
        return self.run_test(model, conf)

    def track(self, metric_name, value):
        """Track a metric.

        Does nothing (beyond a debug log) when tracking is disabled or no tracker exists.

        Args:
            metric_name (str): The name of the metric.
            value (str|int|float|bool|dict|list): The value of the metric.
        """
        if not self.conf.track or self._tracker is None:
            logger.debug("Tracker not available, Tracking not supported")
            return
        self._tracker.track_client(metric_name, value)

    def save_metrics(self):
        """Save client metrics to database."""
        if self._tracker is None:
            logger.debug("Tracker not available, no saving")
            return
        self._tracker.save_client()

    def simulate_straggler(self):
        """Simulate straggler effect of system heterogeneity by sleeping ``sleep_time`` seconds."""
        if self._sleep_time > 0:
            time.sleep(self._sleep_time)