clientbabu.py

import torch
import torch.nn as nn
from flcore.clients.clientbase import Client


class clientBABU(Client):
    def __init__(self, args, id, train_samples, test_samples, **kwargs):
        super().__init__(args, id, train_samples, test_samples, **kwargs)

        self.criterion = nn.CrossEntropyLoss()
        self.fine_tuning_steps = args.fine_tuning_steps
        self.alpha = args.alpha  # fine-tuning learning rate

        # FedBABU: freeze the predictor (head) during federated training;
        # only the base (body) is trained and aggregated.
        for param in self.model.predictor.parameters():
            param.requires_grad = False

    def train_one_iter(self, x, y, optimizer):
        # One standard SGD step on a single mini-batch.
        optimizer.zero_grad()
        output = self.model(x)
        loss = self.criterion(output, y)
        loss.backward()
        optimizer.step()

    def get_training_optimizer(self, **kwargs):
        # Federated training optimizes the base parameters only.
        return torch.optim.SGD(self.model.base.parameters(), lr=self.learning_rate, momentum=0.9)

    def get_fine_tuning_optimizer(self, **kwargs):
        # Fine-tuning optimizes the whole model at its own learning rate.
        return torch.optim.SGD(self.model.parameters(), lr=self.alpha, momentum=0.9)

    def prepare_training(self, **kwargs):
        # Hook for subclasses; called after the training optimizer is created.
        pass

    def prepare_fine_tuning(self, **kwargs):
        # Hook for subclasses; called before the fine-tuning optimizer is created.
        pass

    def train(self):
        trainloader = self.load_train_data()
        # self.model.to(self.device)
        self.model.train()

        optimizer = self.get_training_optimizer()
        self.prepare_training()  # prepare_training after getting the optimizer

        max_local_steps = self.local_steps
        for step in range(max_local_steps):
            for i, (x, y) in enumerate(trainloader):
                if isinstance(x, list):
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)
                self.train_one_iter(x, y, optimizer)

        # self.model.cpu()

    def set_parameters(self, base):
        # Receive the aggregated base (feature extractor) from the server.
        for new_param, old_param in zip(base.parameters(), self.model.base.parameters()):
            old_param.data = new_param.data.clone()

    def set_fine_tune_parameters(self, model):
        # Receive a full model (base + predictor) before fine-tuning.
        for new_param, old_param in zip(model.parameters(), self.model.parameters()):
            old_param.data = new_param.data.clone()

    def fine_tune(self, which_module=['base', 'predictor']):
        trainloader = self.load_train_data()
        self.model.train()

        self.prepare_fine_tuning()  # prepare_fine_tuning before getting the optimizer
        optimizer = self.get_fine_tuning_optimizer()

        # Unfreeze the predictor if it should be fine-tuned.
        if 'predictor' in which_module:
            for param in self.model.predictor.parameters():
                param.requires_grad = True

        # Freeze the base if it should not be fine-tuned.
        if 'base' not in which_module:
            for param in self.model.base.parameters():
                param.requires_grad = False

        for step in range(self.fine_tuning_steps):
            for i, (x, y) in enumerate(trainloader):
                if isinstance(x, list):
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)
                self.train_one_iter(x, y, optimizer)
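# --- Illustrative usage sketch (not part of the original file) ---
# A minimal sketch of how a PFLlib-style server loop might drive this client,
# assuming FedBABU's usual flow: train the body during federated rounds, then
# fine-tune locally at the end. `global_base`, `args.global_rounds`, and the
# constructor arguments below are hypothetical stand-ins; the real
# orchestration lives in the server-side code.
#
#   client = clientBABU(args, id=0, train_samples=500, test_samples=100)
#   for r in range(args.global_rounds):
#       client.set_parameters(global_base)       # receive the aggregated base
#       client.train()                           # update body; predictor stays frozen
#       # ... server aggregates client.model.base across clients ...
#   client.fine_tune(['base', 'predictor'])      # personalize the full model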