123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- import torch
- import torch.nn as nn
- import numpy as np
- import copy
- from flcore.clients.clientbase import Client
- from torch.utils.data import DataLoader
- from utils.data_utils import read_client_data
class clientFomo(Client):
    """Client that mixes received peer models by validation-loss improvement.

    Each round the client holds out the tail of its local training data as a
    validation slice, scores every received model by how much it lowers the
    validation loss relative to ``old_model`` (divided by the parameter
    distance to ``old_model``), and replaces its own parameters with the
    weighted sum of the received models before training locally.

    Attributes such as ``model``, ``device``, ``learning_rate``,
    ``batch_size``, ``local_steps``, ``dataset``, ``id``, ``has_BatchNorm``
    and the ``clone_model`` method are not defined here; they are expected to
    be provided by ``Client``.
    """

    def __init__(self, args, id, train_samples, test_samples, **kwargs):
        """Set up per-client state on top of the base ``Client``.

        Args:
            args: Run configuration; ``args.num_clients`` is read here.
            id: Identifier of this client (forwarded to ``Client``).
            train_samples: Number of local training samples.
            test_samples: Number of local test samples.
            **kwargs: Forwarded unchanged to ``Client.__init__``.
        """
        super().__init__(args, id, train_samples, test_samples, **kwargs)

        self.num_clients = args.num_clients
        # Reference copy of the model used as the baseline when scoring peers.
        self.old_model = copy.deepcopy(self.model)
        self.received_ids = []
        self.received_models = []
        # One weight slot per client; filled in by weight_vector_update().
        self.weight_vector = torch.zeros(self.num_clients, device=self.device)

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=self.learning_rate
        )

        # Fraction of the local training data held out for validation; the
        # advertised training-set size is scaled down by the same fraction.
        self.val_ratio = 0.2
        self.train_samples = self.train_samples * (1 - self.val_ratio)

    def _batch_to_device(self, x, y):
        """Move one ``(x, y)`` batch onto ``self.device`` and return it.

        When ``x`` is a list only its first element is moved (in place); the
        list itself is what gets passed on to the model.
        """
        if isinstance(x, list):
            x[0] = x[0].to(self.device)
        else:
            x = x.to(self.device)
        return x, y.to(self.device)

    def train(self):
        """Aggregate received models, then run ``local_steps`` local epochs."""
        trainloader, val_loader = self.load_train_data()

        self.aggregate_parameters(val_loader)
        # clone_model is inherited; presumably it copies the freshly
        # aggregated model into old_model -- confirm against Client.
        self.clone_model(self.model, self.old_model)

        self.model.train()

        for _ in range(self.local_steps):
            for x, y in trainloader:
                x, y = self._batch_to_device(x, y)
                self.optimizer.zero_grad()
                loss = self.criterion(self.model(x), y)
                loss.backward()
                self.optimizer.step()

    def standard_train(self):
        """Train for a single epoch with a fresh momentum-SGD optimizer.

        No aggregation is performed and ``self.optimizer`` is not used.
        """
        trainloader, _ = self.load_train_data()
        self.model.train()
        optimizer = torch.optim.SGD(
            self.model.parameters(), lr=self.learning_rate, momentum=0.9
        )
        for x, y in trainloader:
            x, y = self._batch_to_device(x, y)
            optimizer.zero_grad()
            loss = self.criterion(self.model(x), y)
            loss.backward()
            optimizer.step()

    def load_train_data(self, batch_size=None):
        """Build the training and validation loaders for this client.

        The last ``val_ratio`` fraction of the client's training data becomes
        the validation slice; the rest is used for training.

        Args:
            batch_size: Batch size for both loaders. Defaults to
                ``self.batch_size`` when ``None``.

        Returns:
            A ``(trainloader, val_loader)`` tuple of ``DataLoader`` objects.
        """
        if batch_size is None:
            batch_size = self.batch_size

        train_data = read_client_data(self.dataset, self.id, is_train=True)

        # Split with a non-negative index. A negated count would be ``-0``
        # when the validation size rounds to zero, and ``data[-0:]`` is the
        # whole data set rather than an empty slice.
        n_val = int(self.val_ratio * len(train_data))
        split = len(train_data) - n_val
        val_data = train_data[split:]
        train_data = train_data[:split]

        trainloader = DataLoader(
            train_data, batch_size, drop_last=True, shuffle=True
        )
        # Shuffling an empty data set can raise inside the random sampler on
        # some PyTorch versions, so only shuffle when there is data.
        val_loader = DataLoader(
            val_data,
            batch_size,
            drop_last=self.has_BatchNorm,
            shuffle=len(val_data) > 0,
        )
        return trainloader, val_loader

    def receive_models(self, ids, models):
        """Store the peer models (and their client ids) for the next round.

        Args:
            ids: Client ids, aligned index-by-index with ``models``.
            models: Models received from those clients.
        """
        self.received_ids = ids
        self.received_models = models

    def weight_cal(self, val_loader):
        """Compute normalized aggregation weights for the received models.

        A model's raw weight is ``(L_old - L_model) / (||theta_model -
        theta_old|| + 1e-5)`` when it lowers the validation loss, else 0.
        Raw weights are normalized to sum to (about) one and also written
        into ``self.weight_vector`` via ``weight_vector_update``.

        Args:
            val_loader: Loader over the validation slice.

        Returns:
            A 1-D tensor with one weight per received model (empty when no
            models were received).
        """
        weight_list = []
        base_loss = self.recalculate_loss(self.old_model, val_loader)

        for received_model in self.received_models:
            gain = base_loss - self.recalculate_loss(received_model, val_loader)
            if gain > 0:
                # Only helpful models need the (comparatively expensive)
                # parameter distance; autograd is not needed for it.
                with torch.no_grad():
                    params_dif = torch.cat([
                        (param_n - param_i).view(-1)
                        for param_n, param_i in zip(
                            received_model.parameters(),
                            self.old_model.parameters(),
                        )
                    ])
                    weight = gain / (torch.norm(params_dif) + 1e-5)
                weight_list.append(weight.item())
            else:
                weight_list.append(0.0)

        if len(weight_list) != 0:
            weight_list = np.array(weight_list)
            # Epsilon keeps the all-zero case from dividing by zero.
            weight_list /= (np.sum(weight_list) + 1e-10)
            self.weight_vector_update(weight_list)

        return torch.tensor(weight_list)

    def weight_vector_update(self, weight_list):
        """Scatter per-model weights into the per-client ``weight_vector``.

        Args:
            weight_list: Weights aligned with ``self.received_ids``; entries
                may be Python floats, NumPy scalars or 0-dim tensors.
        """
        vector = np.zeros(self.num_clients)
        for w, client_id in zip(weight_list, self.received_ids):
            vector[client_id] += float(w)
        self.weight_vector = torch.tensor(vector).to(self.device)

    def recalculate_loss(self, new_model, val_loader):
        """Return the mean per-sample loss of ``new_model`` on ``val_loader``.

        The mean is taken over the samples actually yielded by the loader, so
        it stays correct when the loader drops a trailing partial batch. The
        model's train/eval mode is left as the caller set it.

        Args:
            new_model: Model to evaluate.
            val_loader: Loader over the validation slice.

        Returns:
            The mean loss as a float, or ``0.0`` when the loader yields no
            samples (so that no model can register a gain).
        """
        total_loss = 0.0
        n_seen = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = self._batch_to_device(x, y)
                loss = self.criterion(new_model(x), y)
                batch_n = y.shape[0]
                # criterion returns a batch mean; re-weight by batch size.
                total_loss += (loss * batch_n).item()
                n_seen += batch_n

        if n_seen == 0:
            return 0.0
        return total_loss / n_seen

    def add_parameters(self, w, received_model):
        """Add ``w * received_model`` parameter-wise into ``self.model``.

        Args:
            w: Scalar weight for this model.
            received_model: Model whose parameters are accumulated.
        """
        for param, received_param in zip(
            self.model.parameters(), received_model.parameters()
        ):
            param.data += received_param.data * w

    def aggregate_parameters(self, val_loader):
        """Replace ``self.model`` with the weighted sum of received models.

        Nothing is changed when no models were received or when no received
        model gets a positive weight.

        Args:
            val_loader: Loader over the validation slice used for scoring.
        """
        weights = self.weight_cal(val_loader)

        if len(weights) > 0 and weights.sum() > 0.0:
            for param in self.model.parameters():
                param.data.zero_()

            for w, received_model in zip(weights, self.received_models):
                self.add_parameters(w, received_model)

    def weight_scale(self, weights):
        """Clamp negative weights to zero and normalize them to sum to one.

        Args:
            weights: 1-D tensor of raw weights.

        Returns:
            The normalized weights as a floating-point tensor on the same
            device as ``weights``, or an empty tensor when no weight is
            positive.
        """
        weights = torch.clamp(weights, min=0)
        w_sum = torch.sum(weights)
        if w_sum > 0:
            return weights / w_sum
        return torch.tensor([])
|