clientrod.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import copy
  2. import torch
  3. import torch.nn as nn
  4. import numpy as np
  5. from flcore.clients.clientbase import Client
  6. import torch.nn.functional as F
  7. class RodEvalModel(nn.Module):
  8. def __init__(self, glob_m, pers_pred):
  9. super(RodEvalModel, self).__init__()
  10. self.glob_m = glob_m
  11. self.pers_pred = pers_pred
  12. def forward(self, x):
  13. rep = self.glob_m.base(x)
  14. out = self.glob_m.predictor(rep)
  15. out += self.pers_pred(rep)
  16. return out
  17. class clientRoD(Client):
  18. def __init__(self, args, id, train_samples, test_samples, **kwargs):
  19. super().__init__(args, id, train_samples, test_samples, **kwargs)
  20. self.criterion = nn.CrossEntropyLoss()
  21. self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)
  22. self.pred = copy.deepcopy(self.model.predictor)
  23. self.opt_pred = torch.optim.SGD(self.pred.parameters(), lr=self.learning_rate)
  24. self.sample_per_class = torch.zeros(self.num_classes)
  25. trainloader = self.load_train_data()
  26. for x, y in trainloader:
  27. for yy in y:
  28. self.sample_per_class[yy.item()] += 1
  29. self.sample_per_class = self.sample_per_class / torch.sum(self.sample_per_class)
  30. def train(self):
  31. trainloader = self.load_train_data()
  32. # self.model.to(self.device)
  33. self.model.train()
  34. max_local_steps = self.local_steps
  35. for step in range(max_local_steps):
  36. for i, (x, y) in enumerate(trainloader):
  37. if type(x) == type([]):
  38. x[0] = x[0].to(self.device)
  39. else:
  40. x = x.to(self.device)
  41. y = y.to(self.device)
  42. self.optimizer.zero_grad()
  43. rep = self.model.base(x)
  44. out_g = self.model.predictor(rep)
  45. loss_bsm = balanced_softmax_loss(y, out_g, self.sample_per_class)
  46. loss_bsm.backward()
  47. self.optimizer.step()
  48. self.opt_pred.zero_grad()
  49. out_p = self.pred(rep.detach())
  50. loss = self.criterion(out_g.detach() + out_p, y)
  51. loss.backward()
  52. self.opt_pred.step()
  53. # self.model.cpu()
  54. # comment for testing on new clients
  55. def get_eval_model(self, temp_model=None):
  56. # temp_model is the current round global model (after aggregation)
  57. return RodEvalModel(temp_model, self.pred)
  58. # https://github.com/jiawei-ren/BalancedMetaSoftmax-Classification
  59. def balanced_softmax_loss(labels, logits, sample_per_class, reduction="mean"):
  60. """Compute the Balanced Softmax Loss between `logits` and the ground truth `labels`.
  61. Args:
  62. labels: A int tensor of size [batch].
  63. logits: A float tensor of size [batch, no_of_classes].
  64. sample_per_class: A int tensor of size [no of classes].
  65. reduction: string. One of "none", "mean", "sum"
  66. Returns:
  67. loss: A float tensor. Balanced Softmax Loss.
  68. """
  69. spc = sample_per_class.type_as(logits)
  70. spc = spc.unsqueeze(0).expand(logits.shape[0], -1)
  71. logits = logits + spc.log()
  72. loss = F.cross_entropy(input=logits, target=labels, reduction=reduction)
  73. return loss