123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405 |
- """Mechanisms for image reconstruction from parameter gradients."""
- import torch
- from collections import defaultdict, OrderedDict
- from .modules import MetaMonkey
- from .metrics import total_variation as TV
- from .metrics import InceptionScore
- from .medianfilt import MedianPool2d
- from copy import deepcopy
- import time
# Baseline reconstruction options. `_validate_config` fills any key missing
# from a user-supplied config with the value listed here and rejects keys
# that are not listed here.
DEFAULT_CONFIG = {
    'signed': True,           # replace the image gradient by its sign before stepping
    'boxed': True,            # clamp candidates using the (mean, std) box after each step
    'cost_fn': 'sim',         # cost function name passed to reconstruction_costs
    'indices': 'topk-1',      # gradient-tensor selection rule passed to reconstruction_costs
    'norm': 'none',
    'weights': 'equal',       # per-tensor weighting scheme passed to reconstruction_costs
    'lr': 0.01,
    'optim': 'adam',
    'restarts': 128,          # number of independent trials
    'max_iterations': 8_000,  # optimizer steps per trial
    'total_variation': 0,     # weight of the total-variation term (0 disables it)
    'init': 'randn',
    'filter': 'none',
    'lr_decay': False,
    'scoring_choice': 'loss',
}
def _label_to_onehot(target, num_classes=100):
    """Convert a 1-D tensor of class indices into a one-hot matrix.

    Args:
        target: Integer tensor of class indices, one entry per sample.
        num_classes: Width of the one-hot encoding.

    Returns:
        Tensor of shape ``(len(target), num_classes)`` on the device of
        ``target`` holding 1 at each sample's class column and 0 elsewhere.
    """
    column_index = target.unsqueeze(1)
    encoded = torch.zeros(column_index.size(0), num_classes, device=column_index.device)
    # scatter_ writes in place and returns the same tensor.
    return encoded.scatter_(1, column_index, 1)
def _validate_config(config):
    """Fill missing options from DEFAULT_CONFIG and reject unknown keys.

    The passed dict is updated in place and also returned.

    Args:
        config: Options dict; keys that are absent or mapped to ``None``
            receive the corresponding DEFAULT_CONFIG value.

    Returns:
        The same ``config`` object, completed with defaults.

    Raises:
        ValueError: If ``config`` holds a key that has no DEFAULT_CONFIG entry.
    """
    for name, default in DEFAULT_CONFIG.items():
        if config.get(name) is None:
            config[name] = default

    # Report the first offending key in the dict's iteration order.
    unknown = [name for name in config if DEFAULT_CONFIG.get(name) is None]
    if unknown:
        raise ValueError(f'Deprecated key in config dict: {unknown[0]}!')
    return config
class GradientReconstructor():
    """Instantiate a reconstruction algorithm.

    Candidate images are optimized so that the parameter gradients they
    induce in ``model`` match a given set of gradients. Several restarts are
    run and either the best-scoring candidate or a pixel-wise average of all
    candidates is returned.
    """

    def __init__(self, model, mean_std=(0.0, 1.0), config=DEFAULT_CONFIG, num_images=1):
        """Initialize with algorithm setup.

        Args:
            model: Network whose parameter gradients are matched. The device
                and dtype of its first parameter are used for new tensors.
            mean_std: ``(mean, std)`` pair used to derive the box constraint
                applied in ``_run_trial``; floats or tensors.
            config: Options dict; missing keys are filled from DEFAULT_CONFIG.
            num_images: Number of images reconstructed jointly per trial.
        """
        # Validate a shallow copy so the instance never aliases (or mutates)
        # the shared DEFAULT_CONFIG default or the caller's dict.
        self.config = _validate_config(dict(config))
        self.model = model
        first_param = next(model.parameters())
        self.setup = dict(device=first_param.device, dtype=first_param.dtype)
        self.mean_std = mean_std
        self.num_images = num_images
        if self.config['scoring_choice'] == 'inception':
            self.inception = InceptionScore(batch_size=1, setup=self.setup)
        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

    def reconstruct(self, input_data, labels, input_img, img_shape=(3, 32, 32), dryrun=False, eval=True):
        """Reconstruct image from gradient.

        Args:
            input_data: Target quantities handed to ``reconstruction_costs``
                (one tensor per model parameter).
            labels: Label tensor whose first dimension equals ``num_images``.
            input_img: Ground-truth input, forwarded to ``_run_trial``.
            img_shape: Shape of a single image.
            dryrun: If true, stop after one iteration of the first trial.
            eval: If true, switch the model to eval mode first.

        Returns:
            Tuple ``(x_optimal, stats)`` with the detached reconstruction and
            a dict whose ``'opt'`` entry is the score of that reconstruction.
        """
        start_time = time.time()
        if eval:
            self.model.eval()

        stats = defaultdict(list)
        x = self._init_images(img_shape)
        # Start at +inf so trials that never ran (dryrun / manual interrupt)
        # can never be picked as the optimum.
        scores = torch.full((self.config['restarts'],), float('inf'))

        assert labels.shape[0] == self.num_images
        self.reconstruct_label = False

        try:
            for trial in range(self.config['restarts']):
                x_trial, labels = self._run_trial(x[trial], input_data, labels, input_img, dryrun=dryrun)
                # Finalize
                scores[trial] = self._score_trial(x_trial, input_data, labels)
                x[trial] = x_trial
                if dryrun:
                    break
        except KeyboardInterrupt:
            print('Trial procedure manually interruped.')

        # Choose optimal result:
        if self.config['scoring_choice'] in ['pixelmean', 'pixelmedian']:
            x_optimal, stats = self._average_trials(x, labels, input_data, stats)
        else:
            print('Choosing optimal result ...')
            # Neutralize NaN/-Inf scores in place instead of filtering them
            # out: filtering would shift positions so that the argmin no
            # longer indexes the matching trial in `x`.
            scores[~torch.isfinite(scores)] = float('inf')
            optimal_index = torch.argmin(scores)
            print(f'Optimal result score: {scores[optimal_index]:2.4f}')
            stats['opt'] = scores[optimal_index].item()
            x_optimal = x[optimal_index]

        print(f'Total time: {time.time()-start_time}.')
        return x_optimal.detach(), stats

    def _init_images(self, img_shape):
        """Create the initial candidates, shape (restarts, num_images, *img_shape)."""
        shape = (self.config['restarts'], self.num_images, *img_shape)
        init = self.config['init']
        if init == 'randn':
            return torch.randn(shape, **self.setup)
        elif init == 'rand':
            # Uniform samples rescaled from [0, 1) to [-1, 1).
            return (torch.rand(shape, **self.setup) - 0.5) * 2
        elif init == 'zeros':
            return torch.zeros(shape, **self.setup)
        else:
            raise ValueError(f'Unknown init option: {init}')

    def _run_trial(self, x_trial, input_data, labels, x_input, dryrun=False):
        """Optimize one candidate and return it detached together with the labels.

        ``x_input`` (the ground truth) is accepted for interface
        compatibility but is not used by the optimization.
        """
        x_trial.requires_grad = True
        if self.reconstruct_label:
            output_test = self.model(x_trial)
            labels = torch.randn(output_test.shape[1]).to(**self.setup).requires_grad_(True)
            params = [x_trial, labels]
        else:
            params = [x_trial]

        if self.config['optim'] == 'adam':
            optimizer = torch.optim.Adam(params, lr=self.config['lr'])
        elif self.config['optim'] == 'sgd':  # actually gd
            # NOTE(review): the step size is fixed here and ignores config['lr'].
            optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, nesterov=True)
        elif self.config['optim'] == 'LBFGS':
            optimizer = torch.optim.LBFGS(params)
        else:
            raise ValueError(f"Unknown optim option: {self.config['optim']}")

        max_iterations = self.config['max_iterations']
        # torch.min/torch.max need tensor bounds; converting here makes the
        # float default mean_std=(0.0, 1.0) work as well as tensor statistics.
        dm, ds = self.mean_std
        dm = torch.as_tensor(dm, **self.setup)
        ds = torch.as_tensor(ds, **self.setup)
        upper_bound = (1 - dm) / ds
        lower_bound = -dm / ds

        if self.config['lr_decay']:
            # Decay at roughly 3/8, 5/8 and 7/8 of the run; milestones are
            # iteration counts, so make them integers.
            milestones = [int(max_iterations // 2.667),
                          int(max_iterations // 1.6),
                          int(max_iterations // 1.142)]
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

        # The closure captures the same objects on every iteration, so it is
        # built once rather than per step.
        closure = self._gradient_closure(optimizer, x_trial, input_data, labels)
        try:
            for iteration in range(max_iterations):
                rec_loss = optimizer.step(closure)
                if self.config['lr_decay']:
                    scheduler.step()

                with torch.no_grad():
                    # Project into image space
                    if self.config['boxed']:
                        x_trial.data = torch.max(torch.min(x_trial, upper_bound), lower_bound)

                    if iteration + 1 == max_iterations:
                        print(f'It: {iteration + 1}. Rec. loss: {rec_loss.item():2.4f}.')

                    if (iteration + 1) % 500 == 0:
                        if self.config['filter'] == 'none':
                            pass
                        elif self.config['filter'] == 'median':
                            x_trial.data = MedianPool2d(kernel_size=3, stride=1, padding=1, same=False)(x_trial)
                        else:
                            raise ValueError(f"Unknown filter option: {self.config['filter']}")

                if dryrun:
                    break
        except KeyboardInterrupt:
            print(f'Recovery interrupted manually in iteration {iteration}!')
        return x_trial.detach(), labels

    def _gradient_closure(self, optimizer, x_trial, input_gradient, label):
        """Build the optimizer closure that evaluates the gradient-matching loss."""

        def closure():
            optimizer.zero_grad()
            self.model.zero_grad()
            loss = self.loss_fn(self.model(x_trial), label)
            # create_graph=True keeps the graph so the matching loss can be
            # differentiated with respect to x_trial.
            gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
            rec_loss = reconstruction_costs([gradient], input_gradient,
                                            cost_fn=self.config['cost_fn'], indices=self.config['indices'],
                                            norm=self.config['norm'], weights=self.config['weights'])
            if self.config['total_variation'] > 0:
                rec_loss += self.config['total_variation'] * TV(x_trial)
            rec_loss.backward()  # second derivative --> difference between gradients
            if self.config['signed']:
                x_trial.grad.sign_()
            return rec_loss

        return closure

    def _score_trial(self, x_trial, input_gradient, label):
        """Score a finished trial according to ``config['scoring_choice']``."""
        choice = self.config['scoring_choice']
        if choice == 'loss':
            self.model.zero_grad()
            x_trial.grad = None
            loss = self.loss_fn(self.model(x_trial), label)
            gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False)
            return reconstruction_costs([gradient], input_gradient,
                                        cost_fn=self.config['cost_fn'], indices=self.config['indices'],
                                        norm=self.config['norm'], weights=self.config['weights'])
        elif choice == 'tv':
            return TV(x_trial)
        elif choice == 'inception':
            # We do not care about diversity here!
            return self.inception(x_trial)
        elif choice in ['pixelmean', 'pixelmedian']:
            # Averaging modes combine all trials; per-trial scores are unused.
            return 0.0
        else:
            raise ValueError(f'Unknown scoring_choice option: {choice}')

    def _average_trials(self, x, labels, input_data, stats):
        """Combine all trials pixel-wise (mean or median) and score the result."""
        print(f'Computing a combined result via {self.config["scoring_choice"]} ...')

        if self.config['scoring_choice'] == 'pixelmedian':
            x_optimal, _ = x.median(dim=0, keepdim=False)
        elif self.config['scoring_choice'] == 'pixelmean':
            x_optimal = x.mean(dim=0, keepdim=False)

        self.model.zero_grad()
        if self.reconstruct_label:
            labels = self.model(x_optimal).softmax(dim=1)
        loss = self.loss_fn(self.model(x_optimal), labels)
        gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False)
        stats['opt'] = reconstruction_costs([gradient], input_data,
                                            cost_fn=self.config['cost_fn'],
                                            indices=self.config['indices'],
                                            norm=self.config['norm'],
                                            weights=self.config['weights'])
        print(f'Optimal result score: {stats["opt"]:2.4f}')
        return x_optimal, stats
class FedAvgReconstructor(GradientReconstructor):
    """Reconstruct an image from weights after n gradient descent steps."""

    def __init__(self, model, mean_std=(0.0, 1.0), local_steps=2, local_lr=1e-4,
                 config=DEFAULT_CONFIG, num_images=1, use_updates=True):
        """Initialize with model, (mean, std) and config.

        Args:
            model: Network whose local training steps are simulated.
            mean_std: ``(mean, std)`` pair forwarded to the base class.
            local_steps: Number of simulated local descent steps.
            local_lr: Step size of the simulated local steps.
            config: Options dict forwarded to the base class.
            num_images: Number of images reconstructed jointly per trial.
            use_updates: Forwarded to ``loss_steps``; if true, parameter
                differences rather than final parameters are compared.
        """
        super().__init__(model, mean_std, config, num_images)
        self.local_steps = local_steps
        self.local_lr = local_lr
        self.use_updates = use_updates

    def _simulate_local_steps(self, x_trial, labels):
        """Run ``loss_steps`` on the candidate with this instance's settings."""
        return loss_steps(self.model, x_trial, labels, loss_fn=self.loss_fn,
                          local_steps=self.local_steps, lr=self.local_lr, use_updates=self.use_updates)

    def _gradient_closure(self, optimizer, x_trial, input_parameters, labels):
        """Build the optimizer closure that matches simulated parameters to the target."""

        def closure():
            optimizer.zero_grad()
            self.model.zero_grad()
            parameters = self._simulate_local_steps(x_trial, labels)
            # Compare the simulated parameters against the observed ones
            # (the previous code referenced undefined names here).
            rec_loss = reconstruction_costs([parameters], input_parameters,
                                            cost_fn=self.config['cost_fn'], indices=self.config['indices'],
                                            norm=self.config['norm'], weights=self.config['weights'])
            if self.config['total_variation'] > 0:
                rec_loss += self.config['total_variation'] * TV(x_trial)
            rec_loss.backward()
            if self.config['signed']:
                x_trial.grad.sign_()
            return rec_loss

        return closure

    def _score_trial(self, x_trial, input_parameters, labels):
        """Score a finished trial according to ``config['scoring_choice']``."""
        choice = self.config['scoring_choice']
        if choice == 'loss':
            self.model.zero_grad()
            parameters = self._simulate_local_steps(x_trial, labels)
            return reconstruction_costs([parameters], input_parameters,
                                        cost_fn=self.config['cost_fn'], indices=self.config['indices'],
                                        norm=self.config['norm'], weights=self.config['weights'])
        elif choice == 'tv':
            return TV(x_trial)
        elif choice == 'inception':
            # We do not care about diversity here!
            return self.inception(x_trial)
        elif choice in ['pixelmean', 'pixelmedian']:
            # Same as the base class: averaging modes ignore per-trial scores.
            return 0.0
        else:
            raise ValueError(f'Unknown scoring_choice option: {choice}')
def loss_steps(model, inputs, labels, loss_fn=torch.nn.CrossEntropyLoss(), lr=1e-4, local_steps=4, use_updates=True, batch_size=0):
    """Take a few gradient descent steps to fit the model to the given input.

    The steps are applied to a ``MetaMonkey`` wrapper's parameter dict and are
    built with ``create_graph=True`` so the result stays differentiable with
    respect to ``inputs``.

    Args:
        model: Network to wrap with ``MetaMonkey``.
        inputs: Input batch used for the local steps.
        labels: Labels matching ``inputs``.
        loss_fn: Loss applied to the outputs; its result is summed.
        lr: Step size of each descent step.
        local_steps: Number of descent steps.
        use_updates: If true, return the difference between final and
            original parameters instead of the final parameters.
        batch_size: 0 uses all inputs in every step; otherwise step ``i``
            uses the ``i``-th consecutive slice of this size, wrapping around.

    Returns:
        List of tensors, one per entry of the wrapper's parameter dict.
    """
    patched_model = MetaMonkey(model)
    if use_updates:
        patched_model_origin = deepcopy(patched_model)

    for i in range(local_steps):
        if batch_size == 0:
            outputs = patched_model(inputs, patched_model.parameters)
            labels_ = labels
        else:
            # At least one batch, so a batch_size larger than the number of
            # samples uses all of them instead of dividing by zero.
            num_batches = max(1, inputs.shape[0] // batch_size)
            idx = i % num_batches
            batch = slice(idx * batch_size, (idx + 1) * batch_size)
            outputs = patched_model(inputs[batch], patched_model.parameters)
            labels_ = labels[batch]

        loss = loss_fn(outputs, labels_).sum()
        grad = torch.autograd.grad(loss, list(patched_model.parameters.values()),
                                   retain_graph=True, create_graph=True, only_inputs=True)

        patched_model.parameters = OrderedDict(
            (name, param - lr * grad_part)
            for (name, param), grad_part in zip(patched_model.parameters.items(), grad)
        )

    if use_updates:
        patched_model.parameters = OrderedDict(
            (name, param - param_origin)
            for (name, param), (name_origin, param_origin)
            in zip(patched_model.parameters.items(), patched_model_origin.parameters.items())
        )
    return list(patched_model.parameters.values())
def _select_indices(indices, input_gradient, norm):
    """Resolve a named selection rule into indices of gradient tensors."""
    count = len(input_gradient)
    top_k = {'topk-1': 4, 'top10': 10, 'top50': 50}
    first_k = {'first': 4, 'first4': 4, 'first5': 5, 'first10': 10, 'first50': 50}
    last_k = {'last5': 5, 'last10': 10, 'last50': 50}

    def by_norm(k, largest):
        # Rank tensors by their norm; k is capped so that models with fewer
        # than k gradient tensors use all of them instead of failing.
        magnitudes = torch.stack([p.norm() for p in input_gradient], dim=0)
        return torch.topk(magnitudes, min(k, count), largest=largest)[1]

    if not isinstance(indices, str):
        raise ValueError(f'Unknown indices option: {indices!r}')
    if indices == 'def':
        return torch.arange(count)
    if indices == 'batch':
        return torch.randperm(count)[:8]
    if indices in top_k:
        return by_norm(top_k[indices], True)
    if indices in first_k:
        return torch.arange(0, min(first_k[indices], count))
    if indices in last_k:
        return torch.arange(count)[-last_k[indices]:]
    if indices in ('custom_bottom', 'custom_top'):
        # For the custom rules, `norm` carries the number of tensors to pick.
        if not isinstance(norm, int):
            raise ValueError(f'indices={indices!r} requires an integer norm, got {norm!r}')
        return by_norm(norm, indices == 'custom_top')
    raise ValueError(f'Unknown indices option: {indices!r}')


def reconstruction_costs(gradients, input_gradient, cost_fn='l2', indices='def', norm='none', weights='equal'):
    """Input gradient is given data.

    Compare each candidate in ``gradients`` with ``input_gradient`` on a
    subset of tensors and return the cost averaged over the candidates.

    Args:
        gradients: Sequence of candidates, each a sequence of tensors aligned
            with ``input_gradient``.
        input_gradient: Sequence of target tensors.
        cost_fn: One of 'l2', 'l1', 'max', 'sim' (one minus the cosine
            similarity over all selected tensors) or 'simlocal' (per tensor).
        indices: Explicit list of tensor indices, or a rule name: 'def',
            'batch', 'topk-1', 'top10', 'top50', 'first', 'first4', 'first5',
            'first10', 'first50', 'last5', 'last10', 'last50',
            'custom_bottom', 'custom_top'.
        norm: Only read by the 'custom_*' rules, as the number of tensors.
        weights: 'linear', 'exp', or anything else for equal weights.

    Returns:
        Mean cost over the candidates in ``gradients``.

    Raises:
        ValueError: For an unknown ``cost_fn`` or ``indices`` rule.
    """
    if cost_fn not in ('l2', 'l1', 'max', 'sim', 'simlocal'):
        # Previously an unknown name silently produced a zero cost.
        raise ValueError(f'Unknown cost_fn option: {cost_fn!r}')

    if not isinstance(indices, list):
        indices = _select_indices(indices, input_gradient, norm)

    count = len(input_gradient)
    ex = input_gradient[0]
    if weights == 'linear':
        weights = torch.arange(count, 0, -1, dtype=ex.dtype, device=ex.device) / count
    elif weights == 'exp':
        weights = torch.arange(count, 0, -1, dtype=ex.dtype, device=ex.device)
        weights = weights.softmax(dim=0)
        weights = weights / weights[0]
    else:
        # Equal weighting: [1., 1., 1., ...]
        weights = ex.new_ones(count)

    total_costs = 0
    for trial_gradient in gradients:
        # pnorm accumulates the weighted squared norms of the candidate and
        # the target; only used by the 'sim' cost.
        pnorm = [0, 0]
        costs = 0
        for i in indices:
            if cost_fn == 'l2':
                costs += ((trial_gradient[i] - input_gradient[i]).pow(2)).sum() * weights[i]
            elif cost_fn == 'l1':
                costs += ((trial_gradient[i] - input_gradient[i]).abs()).sum() * weights[i]
            elif cost_fn == 'max':
                costs += ((trial_gradient[i] - input_gradient[i]).abs()).max() * weights[i]
            elif cost_fn == 'sim':
                costs -= (trial_gradient[i] * input_gradient[i]).sum() * weights[i]
                pnorm[0] += trial_gradient[i].pow(2).sum() * weights[i]
                pnorm[1] += input_gradient[i].pow(2).sum() * weights[i]
            elif cost_fn == 'simlocal':
                costs += 1 - torch.nn.functional.cosine_similarity(trial_gradient[i].flatten(),
                                                                   input_gradient[i].flatten(),
                                                                   0, 1e-10) * weights[i]
        if cost_fn == 'sim':
            costs = 1 + costs / pnorm[0].sqrt() / pnorm[1].sqrt()

        # Accumulate final costs
        total_costs += costs

    return total_costs / len(gradients)
|