import copy

import numpy as np
import torch
import torch.nn as nn
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from torch.utils.data import DataLoader

from utils.data_utils import read_client_data


class Client(object):
    """
    Base class for clients in federated learning: each client keeps a local
    copy of the global model, loads its own data shard, and runs local
    training and evaluation.
    """
    def __init__(self, args, id, train_samples, test_samples, **kwargs):
        self.model = copy.deepcopy(args.model)
        self.dataset = args.dataset
        self.device = args.device
        self.id = id  # integer client index
        self.num_classes = args.num_classes
        self.train_samples = train_samples
        self.test_samples = test_samples
        self.batch_size = args.batch_size
        self.learning_rate = args.local_learning_rate
        self.local_steps = args.local_steps
        # Assumed default: `self.criterion` is used by standard_train() and
        # test_metrics() below but is not defined anywhere else in this
        # class; cross-entropy is the standard choice for classification.
        self.criterion = nn.CrossEntropyLoss()

        # Check for BatchNorm layers. modules() recurses into nested
        # containers, so BatchNorm inside e.g. nn.Sequential blocks is
        # detected as well.
        self.has_BatchNorm = False
        for layer in self.model.modules():
            if isinstance(layer, nn.BatchNorm2d):
                self.has_BatchNorm = True
                break

        # Fraction of the local training set drawn per batch.
        self.sample_rate = self.batch_size / self.train_samples
    def load_train_data(self, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size
        train_data = read_client_data(self.dataset, self.id, is_train=True)
        batch_size = min(batch_size, len(train_data))
        return DataLoader(train_data, batch_size, drop_last=True, shuffle=True)

    def load_test_data(self, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size
        test_data = read_client_data(self.dataset, self.id, is_train=False)
        batch_size = min(batch_size, len(test_data))
        return DataLoader(test_data, batch_size, drop_last=False, shuffle=False)
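
    # Note on the loaders above: drop_last=True keeps training batches
    # full-sized (which keeps sample_rate meaningful and avoids tiny batches
    # skewing BatchNorm statistics), while the test loader keeps every
    # example so the reported metrics cover the whole local test shard.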

    def set_parameters(self, model):
        # Overwrite the local weights with the (global) model's weights.
        for new_param, old_param in zip(model.parameters(), self.model.parameters()):
            old_param.data = new_param.data.clone()

    def clone_model(self, model, target):
        # Copy model's weights into target.
        for param, target_param in zip(model.parameters(), target.parameters()):
            target_param.data = param.data.clone()

    def update_parameters(self, model, new_params):
        for param, new_param in zip(model.parameters(), new_params):
            param.data = new_param.data.clone()

    def get_eval_model(self, temp_model=None):
        # Personalized-FL subclasses may set self.model_per; fall back to the
        # shared model otherwise. temp_model is accepted for subclass
        # overrides and is intentionally unused here.
        return self.model_per if hasattr(self, "model_per") else self.model
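
    # All three copy helpers above assign via .data.clone(), which transfers
    # values without autograd history: the destination parameters remain leaf
    # tensors and no gradient graph links the two models.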

    def standard_train(self):
        # One local epoch of plain SGD on this client's training shard.
        trainloader = self.load_train_data()
        self.model.train()
        for p in self.model.parameters():
            p.requires_grad = True
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
        for x, y in trainloader:
            # Some datasets return the input as a list whose first element
            # is the tensor batch.
            if isinstance(x, list):
                x[0] = x[0].to(self.device)
            else:
                x = x.to(self.device)
            y = y.to(self.device)
            optimizer.zero_grad()
            output = self.model(x)
            loss = self.criterion(output, y)
            loss.backward()
            optimizer.step()
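
    # In a typical synchronous round, the server side of the framework would
    # call set_parameters(global_model) on each selected client, then
    # standard_train() (local_steps presumably controls how many such passes
    # subclasses run), and finally aggregate the clients' self.model weights;
    # the aggregation logic lives outside this file.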

    def test_metrics(self, temp_model=None):
        testloaderfull = self.load_test_data()
        model = self.get_eval_model(temp_model)
        model.eval()

        test_correct = 0
        test_num = 0
        test_loss = 0.0
        y_prob = []
        y_true = []

        with torch.no_grad():
            for x, y in testloaderfull:
                if isinstance(x, list):
                    x[0] = x[0].to(self.device)
                else:
                    x = x.to(self.device)
                y = y.to(self.device)
                output = model(x)
                # Accumulate the summed per-sample loss (criterion averages
                # over the batch, so scale by the batch size).
                test_loss += (self.criterion(output, y.long()) * y.shape[0]).item()
                test_correct += (torch.sum(torch.argmax(output, dim=1) == y)).item()
                test_num += y.shape[0]
                y_prob.append(output.detach().cpu().numpy())
                y_true.append(label_binarize(y.detach().cpu().numpy(), classes=np.arange(self.num_classes)))

        y_prob = np.concatenate(y_prob, axis=0)
        y_true = np.concatenate(y_true, axis=0)
        test_loss /= test_num
        try:
            test_auc = metrics.roc_auc_score(y_true, y_prob, average='micro')
        except ValueError:
            # roc_auc_score raises ValueError when a class is entirely
            # absent from this client's test shard; report 0 in that case
            # but keep the valid loss computed above.
            test_auc = 0.0
        test_acc = test_correct / test_num
        return test_acc, test_auc, test_loss, test_num
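

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the framework). It only
# exercises the parameter-sync API, since load_train_data/load_test_data need
# the per-client dataset files that utils.data_utils.read_client_data reads
# from disk. SimpleNamespace stands in for the experiment config object the
# repo builds elsewhere; all values below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    global_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    args = SimpleNamespace(
        model=global_model,
        dataset="mnist",  # placeholder dataset name
        device="cpu",
        num_classes=10,
        batch_size=32,
        local_learning_rate=0.01,
        local_steps=1,
    )
    client = Client(args, id=0, train_samples=600, test_samples=100)

    # Server -> client: push the global weights into the local model.
    client.set_parameters(global_model)
    synced = all(
        torch.equal(p, q)
        for p, q in zip(global_model.parameters(), client.model.parameters())
    )
    print(f"weights synced: {synced}")  # expected: True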