import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.preprocessing import label_binarize
from sklearn import metrics

from utils.data_utils import read_client_data


class Client(object):
    """
    Base class for clients in federated learning.
    """

    def __init__(self, args, id, train_samples, test_samples, **kwargs):
        self.model = copy.deepcopy(args.model)
        self.dataset = args.dataset
        self.device = args.device
        self.id = id  # integer
        self.num_classes = args.num_classes
        self.train_samples = train_samples
        self.test_samples = test_samples
        self.batch_size = args.batch_size
        self.learning_rate = args.local_learning_rate
        self.local_steps = args.local_steps
        # classification loss used by standard_train() and test_metrics()
        self.criterion = nn.CrossEntropyLoss()

        # check BatchNorm
        self.has_BatchNorm = False
        for layer in self.model.children():
            if isinstance(layer, nn.BatchNorm2d):
                self.has_BatchNorm = True
                break

        self.sample_rate = self.batch_size / self.train_samples

    def load_train_data(self, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size
        train_data = read_client_data(self.dataset, self.id, is_train=True)
        batch_size = min(batch_size, len(train_data))
        return DataLoader(train_data, batch_size, drop_last=True, shuffle=True)

    def load_test_data(self, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size
        test_data = read_client_data(self.dataset, self.id, is_train=False)
        batch_size = min(batch_size, len(test_data))
        return DataLoader(test_data, batch_size, drop_last=False, shuffle=True)

    def set_parameters(self, model):
        # copy the global model's weights into the local model
        for new_param, old_param in zip(model.parameters(), self.model.parameters()):
            old_param.data = new_param.data.clone()

    def clone_model(self, model, target):
        for param, target_param in zip(model.parameters(), target.parameters()):
            target_param.data = param.data.clone()
            # target_param.grad = param.grad.clone()

    def update_parameters(self, model, new_params):
        for param, new_param in zip(model.parameters(), new_params):
            param.data = new_param.data.clone()

    def get_eval_model(self, temp_model=None):
        # evaluate the personalized model when one exists, otherwise the local model
        model = self.model_per if hasattr(self, "model_per") else self.model
        return model

    def standard_train(self):
        trainloader = self.load_train_data()
        self.model.train()
        for p in self.model.parameters():
            p.requires_grad = True
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
        # 1 epoch
        for i, (x, y) in enumerate(trainloader):
            if isinstance(x, list):
                x[0] = x[0].to(self.device)
            else:
                x = x.to(self.device)
            y = y.to(self.device)
            optimizer.zero_grad()
            output = self.model(x)
            loss = self.criterion(output, y)
            loss.backward()
            optimizer.step()

    def test_metrics(self, temp_model=None):
        testloaderfull = self.load_test_data()
        model = self.get_eval_model(temp_model)
        model.eval()

        test_correct = 0
        test_num = 0
        test_loss = 0.0
        y_prob = []
        y_true = []

        with torch.no_grad():
            for x, y in testloaderfull:
                if isinstance(x, list):
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)
                output = model(x)
                test_loss += (self.criterion(output, y.long()) * y.shape[0]).item()  # sum up batch loss
                test_correct += (torch.sum(torch.argmax(output, dim=1) == y)).item()
                test_num += y.shape[0]
                y_prob.append(output.detach().cpu().numpy())
                y_true.append(label_binarize(y.detach().cpu().numpy(), classes=np.arange(self.num_classes)))

        y_prob = np.concatenate(y_prob, axis=0)
        y_true = np.concatenate(y_true, axis=0)
        try:
            test_auc = metrics.roc_auc_score(y_true, y_prob, average='micro')
            test_loss /= test_num
        except ValueError:
            test_auc, test_loss = 0.0, 0.0
        test_acc = test_correct / test_num

        return test_acc, test_auc, test_loss, test_num
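

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original repository). It shows
# how a Client is typically driven: construct it from an args namespace, run
# one round of local training, then evaluate. The SimpleNamespace fields, the
# tiny linear model, and the InMemoryClient subclass (which serves in-memory
# tensors instead of going through read_client_data) are assumptions made
# for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace
    from torch.utils.data import TensorDataset

    class InMemoryClient(Client):
        """Hypothetical Client variant backed by an in-memory dataset."""

        def __init__(self, args, id, data, **kwargs):
            self._data = data
            super().__init__(args, id, train_samples=len(data), test_samples=len(data), **kwargs)

        def load_train_data(self, batch_size=None):
            return DataLoader(self._data, batch_size or self.batch_size, drop_last=True, shuffle=True)

        def load_test_data(self, batch_size=None):
            return DataLoader(self._data, batch_size or self.batch_size, drop_last=False, shuffle=False)

    # Assumed fields mirror exactly what Client.__init__ reads from `args`.
    args = SimpleNamespace(
        model=nn.Linear(10, 3),
        dataset="toy",
        device="cpu",
        num_classes=3,
        batch_size=8,
        local_learning_rate=0.01,
        local_steps=1,
    )
    toy_data = TensorDataset(torch.randn(64, 10), torch.randint(0, 3, (64,)))

    client = InMemoryClient(args, id=0, data=toy_data)
    client.standard_train()                      # one local epoch of SGD
    acc, auc, loss, num = client.test_metrics()  # accuracy, micro AUC, mean loss, sample count
    print(f"client 0: acc={acc:.3f} auc={auc:.3f} loss={loss:.3f} n={num}")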