12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- import gc
- import logging
- import torch
- import torch._utils
- from losses import get_losses
- from trainer import Trainer, LR_POLY
- from easyfl.client.base import BaseClient
- from easyfl.distributed.distributed import CPU
- logger = logging.getLogger(__name__)
- class MASClient(BaseClient):
- def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
- super(MASClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
- self._local_model = None
- criteria = self.load_loss_fn(conf)
- train_loader = self.load_loader(conf)
- self.trainer = Trainer(self.cid, conf, train_loader, self.model, optimizer=None, criteria=criteria, device=device)
- def decompression(self):
- if self.model is None:
- # Initialization at beginning of the task
- self.model = self.compressed_model
- def train(self, conf, device=CPU):
- self.model.to(device)
- optimizer = self.load_optimizer(conf)
- self.trainer.update(self.model, optimizer, device)
- transference = self.trainer.train()
- if conf.lookahead == 'y':
- logger.info(f"Round {conf.round_id} - Client {self.cid} transference: {transference}")
- def load_loss_fn(self, conf):
- criteria = get_losses(conf.task_str, conf.rotate_loss, conf.task_weights)
- return criteria
- def load_loader(self, conf):
- train_loader = self.train_data.loader(conf.batch_size,
- self.cid,
- shuffle=True,
- num_workers=conf.num_workers,
- seed=conf.seed)
- return train_loader
- def load_optimizer(self, conf, lr=None):
- if conf.optimizer.lr_type == LR_POLY:
- lr = conf.optimizer.lr * pow(1 - (conf.round_id / conf.rounds), 0.9)
- else:
- if self.trainer.lr:
- lr = self.trainer.lr
- else:
- lr = conf.optimizer.lr
- optimizer = torch.optim.SGD(self.model.parameters(),
- lr=lr,
- momentum=conf.optimizer.momentum,
- weight_decay=conf.optimizer.weight_decay)
- return optimizer
- def post_upload(self):
- del self.model
- del self.compressed_model
- self.model = None
- self.compressed_model = None
- assert self.model is None
- assert self.compressed_model is None
- gc.collect()
|