123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- import copy
- import torch
- import torch.nn as nn
- import numpy as np
- from flcore.clients.clientbase import Client
- import torch.nn.functional as F
class RodEvalModel(nn.Module):
    """Evaluation wrapper that adds a personal predictor on top of a shared model.

    The shared model must expose ``base`` (feature extractor) and ``predictor``
    (classification head) sub-modules; both heads are applied to the same
    representation and their logits are summed.
    """

    def __init__(self, glob_m, pers_pred):
        """Store the shared model and the personal predictor.

        Args:
            glob_m: Module with ``base`` and ``predictor`` attributes.
            pers_pred: Module mapping the output of ``glob_m.base`` to logits.
        """
        super().__init__()
        self.glob_m = glob_m
        self.pers_pred = pers_pred

    def forward(self, x):
        """Return ``glob_m.predictor(rep) + pers_pred(rep)`` for ``rep = glob_m.base(x)``."""
        rep = self.glob_m.base(x)
        shared_logits = self.glob_m.predictor(rep)
        personal_logits = self.pers_pred(rep)
        # Out-of-place sum: an in-place ``+=`` would overwrite the tensor returned
        # by the shared predictor, which breaks autograd whenever that predictor's
        # last op needs its own output for the backward pass.
        return shared_logits + personal_logits
class clientRoD(Client):
    """Federated client that trains a shared model plus a personal predictor head.

    Every batch drives two updates: the shared ``self.model`` (``base`` followed
    by ``predictor``) is optimized with the balanced-softmax loss, and the private
    copy ``self.pred`` is optimized with plain cross-entropy on the sum of the
    detached shared logits and its own logits.

    ``self.model``, ``self.learning_rate``, ``self.num_classes``, ``self.device``,
    ``self.local_steps`` and ``load_train_data`` are not defined in this class;
    presumably they are provided by ``Client`` -- confirm against the base class.
    """

    def __init__(self, args, id, train_samples, test_samples, **kwargs):
        """Build the optimizers, the personal head and the local class frequencies.

        All arguments are forwarded unchanged to ``Client.__init__``.
        """
        super().__init__(args, id, train_samples, test_samples, **kwargs)

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate)

        # The personal head starts as a copy of the shared predictor and is
        # trained by its own optimizer.
        self.pred = copy.deepcopy(self.model.predictor)
        self.opt_pred = torch.optim.SGD(self.pred.parameters(), lr=self.learning_rate)

        self.sample_per_class = self._class_frequencies()

    def _class_frequencies(self):
        """Return the fraction of local training labels that fall in each class."""
        counts = torch.zeros(self.num_classes)
        for _, y in self.load_train_data():
            labels = y.detach().reshape(-1).long().cpu()
            # One vectorized count per batch instead of a Python loop over labels.
            counts += torch.bincount(labels, minlength=self.num_classes).to(counts.dtype)

        total = counts.sum()
        # Without this guard an empty training set would yield 0/0 = NaN entries.
        if total > 0:
            counts = counts / total
        return counts

    def train(self):
        """Run ``self.local_steps`` passes over the local training data."""
        trainloader = self.load_train_data()
        self.model.train()

        for _ in range(self.local_steps):
            for x, y in trainloader:
                if isinstance(x, list):
                    # Only the first element is moved to the device; the list
                    # itself is still what gets passed to the model.
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)

                # Shared model: balanced-softmax loss on the shared logits.
                self.optimizer.zero_grad()
                rep = self.model.base(x)
                out_g = self.model.predictor(rep)
                loss_bsm = balanced_softmax_loss(y, out_g, self.sample_per_class)
                loss_bsm.backward()
                self.optimizer.step()

                # Personal head: cross-entropy on shared + personal logits.
                # ``rep`` and ``out_g`` are detached, so no gradient reaches the
                # shared model; they still hold the values computed before the
                # optimizer step above.
                self.opt_pred.zero_grad()
                out_p = self.pred(rep.detach())
                loss = self.criterion(out_g.detach() + out_p, y)
                loss.backward()
                self.opt_pred.step()

    def get_eval_model(self, temp_model=None):
        """Wrap ``temp_model`` and this client's personal head for evaluation.

        Args:
            temp_model: The current round's global model (after aggregation). It
                is passed through unchanged, so with the default ``None`` the
                returned wrapper cannot run a forward pass.

        Returns:
            A ``RodEvalModel`` combining ``temp_model`` with ``self.pred``.
        """
        return RodEvalModel(temp_model, self.pred)
- # https://github.com/jiawei-ren/BalancedMetaSoftmax-Classification
def balanced_softmax_loss(labels, logits, sample_per_class, reduction="mean"):
    """Cross-entropy on logits shifted by the log of the per-class sample counts.

    Args:
        labels: Integer class targets, one per row of ``logits``.
        logits: Float tensor of size [batch, no_of_classes].
        sample_per_class: Tensor of size [no_of_classes] holding the per-class
            counts (or frequencies); it is cast to the type of ``logits``.
        reduction: One of "none", "mean", "sum"; forwarded to
            ``F.cross_entropy``.

    Returns:
        The balanced softmax loss as a float tensor.
    """
    batch_size = logits.size(0)
    # Repeat the class prior across the batch so it lines up with the logits.
    prior = sample_per_class.type_as(logits)[None, :].expand(batch_size, -1)
    shifted_logits = logits + prior.log()
    return F.cross_entropy(shifted_logits, labels, reduction=reduction)
|