client.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import gc
  2. import logging
  3. import torch
  4. import torch._utils
  5. from losses import get_losses
  6. from trainer import Trainer, LR_POLY
  7. from easyfl.client.base import BaseClient
  8. from easyfl.distributed.distributed import CPU
  # Module-level logger keyed on this module's import name.
  logger = logging.getLogger(name=__name__)
  class MASClient(BaseClient):
      """EasyFL client that delegates local training to a project ``Trainer``.

      The trainer is built once in ``__init__`` with the loss criteria (from
      ``get_losses`` using ``conf.task_str`` / ``conf.rotate_loss`` /
      ``conf.task_weights``) and this client's training loader. On every
      ``train`` call it is re-pointed at the current model, a freshly built
      optimizer and the target device.
      """

      def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
          super().__init__(cid, conf, train_data, test_data, device, sleep_time)
          self._local_model = None
          criteria = self.load_loss_fn(conf)
          train_loader = self.load_loader(conf)
          # NOTE(review): self.model may still be None here (decompression()
          # assigns it later); train() re-syncs the trainer through
          # trainer.update() before each run.
          self.trainer = Trainer(self.cid, conf, train_loader, self.model,
                                 optimizer=None, criteria=criteria, device=device)

      def decompression(self):
          """Adopt the received model as the working model when none is set.

          Only acts while ``self.model`` is ``None``: at the beginning of the
          task, or after ``post_upload`` has released it. An existing working
          model is left unchanged.
          """
          if self.model is None:
              self.model = self.compressed_model

      def train(self, conf, device=CPU):
          """Run one round of local training on ``device``.

          Moves the model to ``device``, builds this round's optimizer, hands
          model/optimizer/device to the trainer and runs it. When
          ``conf.lookahead == 'y'`` the value returned by ``trainer.train()``
          is logged as the client's transference.
          """
          self.model.to(device)
          optimizer = self.load_optimizer(conf)
          self.trainer.update(self.model, optimizer, device)
          transference = self.trainer.train()
          if conf.lookahead == 'y':
              logger.info(f"Round {conf.round_id} - Client {self.cid} transference: {transference}")

      def load_loss_fn(self, conf):
          """Return the loss criteria built from the task settings in ``conf``."""
          return get_losses(conf.task_str, conf.rotate_loss, conf.task_weights)

      def load_loader(self, conf):
          """Return a shuffled training loader for this client (``self.cid``).

          Batch size, worker count and seed come from ``conf``.
          """
          return self.train_data.loader(conf.batch_size,
                                        self.cid,
                                        shuffle=True,
                                        num_workers=conf.num_workers,
                                        seed=conf.seed)

      def load_optimizer(self, conf, lr=None):
          """Build the SGD optimizer over ``self.model``'s parameters.

          The learning rate is chosen in this order:

          1. ``lr`` when the caller passes one explicitly.
          2. Polynomial decay when ``conf.optimizer.lr_type == LR_POLY``:
             ``conf.optimizer.lr * max(0, 1 - round_id / rounds) ** 0.9``.
          3. ``self.trainer.lr`` when it is truthy.
          4. ``conf.optimizer.lr``.

          Args:
              conf: configuration; reads ``optimizer.lr_type``,
                  ``optimizer.lr``, ``optimizer.momentum``,
                  ``optimizer.weight_decay``, ``round_id`` and ``rounds``.
              lr: optional explicit learning rate overriding the schedule.

          Returns:
              A ``torch.optim.SGD`` instance.
          """
          if lr is None:
              if conf.optimizer.lr_type == LR_POLY:
                  # Clamp at 0: once round_id exceeds rounds the base turns
                  # negative, and a negative float raised to a fractional power
                  # yields a complex number in Python 3, which SGD rejects.
                  remaining = max(0.0, 1 - conf.round_id / conf.rounds)
                  lr = conf.optimizer.lr * pow(remaining, 0.9)
              elif self.trainer.lr:
                  lr = self.trainer.lr
              else:
                  lr = conf.optimizer.lr
          return torch.optim.SGD(self.model.parameters(),
                                 lr=lr,
                                 momentum=conf.optimizer.momentum,
                                 weight_decay=conf.optimizer.weight_decay)

      def post_upload(self):
          """Drop this client's model references after uploading and run GC.

          Both the working model and the compressed copy are reset to ``None``,
          so the next ``decompression`` call adopts the newly received model.
          """
          # NOTE(review): the trainer was given self.model via trainer.update()
          # and may still reference it, in which case this does not free the
          # model's memory. Confirm against Trainer's implementation.
          self.model = None
          self.compressed_model = None
          gc.collect()