@@ -0,0 +1,405 @@
+"""Mechanisms for image reconstruction from parameter gradients."""
|
|
|
+
|
|
|
+import torch
|
|
|
+from collections import defaultdict, OrderedDict
|
|
|
+from .modules import MetaMonkey
|
|
|
+from .metrics import total_variation as TV
|
|
|
+from .metrics import InceptionScore
|
|
|
+from .medianfilt import MedianPool2d
|
|
|
+from copy import deepcopy
|
|
|
+import time
|
|
|
+
|
|
|
+DEFAULT_CONFIG = dict(signed=True,
|
|
|
+ boxed=True,
|
|
|
+ cost_fn='sim',
|
|
|
+ indices='topk-1',
|
|
|
+ norm='none',
|
|
|
+ weights='equal',
|
|
|
+ lr=0.01,
|
|
|
+ optim='adam',
|
|
|
+ restarts=128,
|
|
|
+ max_iterations=8_000,
|
|
|
+ total_variation=0,
|
|
|
+ init='randn',
|
|
|
+ filter='none',
|
|
|
+ lr_decay=False,
|
|
|
+ scoring_choice='loss')
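+# A minimal usage sketch (illustrative only; `model`, `dm`, `ds`, `ground_truth` and `labels`
+# are assumed to exist, and `input_gradient` is assumed to have been captured via
+#     input_gradient = torch.autograd.grad(loss_fn(model(ground_truth), labels), model.parameters())
+# ):
+#
+#     rec = GradientReconstructor(model, (dm, ds), DEFAULT_CONFIG, num_images=labels.shape[0])
+#     output, stats = rec.reconstruct([g.detach() for g in input_gradient], labels, ground_truth,
+#                                     img_shape=(3, 32, 32))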
+
+
+def _label_to_onehot(target, num_classes=100):
+    target = torch.unsqueeze(target, 1)
+    onehot_target = torch.zeros(target.size(0), num_classes, device=target.device)
+    onehot_target.scatter_(1, target, 1)
+    return onehot_target
+
+
+def _validate_config(config):
+    for key in DEFAULT_CONFIG.keys():
+        if config.get(key) is None:
+            config[key] = DEFAULT_CONFIG[key]
+    for key in config.keys():
+        if DEFAULT_CONFIG.get(key) is None:
+            raise ValueError(f'Deprecated key in config dict: {key}!')
+    return config
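+# Note: a partial config is fine; missing keys fall back to DEFAULT_CONFIG, while unknown
+# keys raise. For example (illustrative), `_validate_config(dict(restarts=1, max_iterations=4_000))`
+# returns a full config with only those two values overridden.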
+
+
+class GradientReconstructor():
+    """Instantiate a reconstruction algorithm."""
+
+    def __init__(self, model, mean_std=(0.0, 1.0), config=DEFAULT_CONFIG, num_images=1):
+        """Initialize with algorithm setup."""
+        self.config = _validate_config(config)
+        self.model = model
+        self.setup = dict(device=next(model.parameters()).device, dtype=next(model.parameters()).dtype)
+
+        self.mean_std = mean_std
+        self.num_images = num_images
+
+        if self.config['scoring_choice'] == 'inception':
+            self.inception = InceptionScore(batch_size=1, setup=self.setup)
+
+        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
+#        self.iDLG = True
+#        self.grad_log = defaultdict(dict)
+
+    def reconstruct(self, input_data, labels, input_img, img_shape=(3, 32, 32), dryrun=False, eval=True):
+        """Reconstruct image from gradient."""
+        start_time = time.time()
+        if eval:
+            self.model.eval()
+
+        stats = defaultdict(list)
+#        grad_log = self.grad_log
+        x = self._init_images(img_shape)
+        scores = torch.zeros(self.config['restarts'])
+
+#        if labels is None:
+#            if self.num_images == 1 and self.iDLG:
+#                # iDLG trick:
+#                last_weight_min = torch.argmin(torch.sum(input_data[-2], dim=-1), dim=-1)
+#                labels = last_weight_min.detach().reshape((1,)).requires_grad_(False)
+#                self.reconstruct_label = False
+#            else:
+#                # DLG label recovery
+#                # However this also improves conditioning for some LBFGS cases
+#                self.reconstruct_label = True
+
+#                def loss_fn(pred, labels):
+#                    labels = torch.nn.functional.softmax(labels, dim=-1)
+#                    return torch.mean(torch.sum(- labels * torch.nn.functional.log_softmax(pred, dim=-1), 1))
+#                self.loss_fn = loss_fn
+
+        assert labels.shape[0] == self.num_images
+        self.reconstruct_label = False
+
+        try:
+            for trial in range(self.config['restarts']):
+                x_trial, labels = self._run_trial(x[trial], input_data, labels, input_img, dryrun=dryrun)
+                # Finalize
+                scores[trial] = self._score_trial(x_trial, input_data, labels)
+                x[trial] = x_trial
+                if dryrun:
+                    break
+        except KeyboardInterrupt:
+            print('Trial procedure manually interrupted.')
+            pass
+
+        # Choose optimal result:
+        if self.config['scoring_choice'] in ['pixelmean', 'pixelmedian']:
+            x_optimal, stats = self._average_trials(x, labels, input_data, stats)
+        else:
+            print('Choosing optimal result ...')
+            scores = scores[torch.isfinite(scores)]  # guard against NaN/-Inf scores?
+            optimal_index = torch.argmin(scores)
+            print(f'Optimal result score: {scores[optimal_index]:2.4f}')
+            stats['opt'] = scores[optimal_index].item()
+            x_optimal = x[optimal_index]
+
+        print(f'Total time: {time.time()-start_time}.')
+        return x_optimal.detach(), stats
+#        return x_optimal.detach(), stats, grad_log
+
+    def _init_images(self, img_shape):
+        if self.config['init'] == 'randn':
+            return torch.randn((self.config['restarts'], self.num_images, *img_shape), **self.setup)
+        elif self.config['init'] == 'rand':
+            return (torch.rand((self.config['restarts'], self.num_images, *img_shape), **self.setup) - 0.5) * 2
+        elif self.config['init'] == 'zeros':
+            return torch.zeros((self.config['restarts'], self.num_images, *img_shape), **self.setup)
+        else:
+            raise ValueError()
+
+    def _run_trial(self, x_trial, input_data, labels, x_input, dryrun=False):  # x_input: ground-truth image, only used by the commented-out logging
+        x_trial.requires_grad = True
+#        grad_log = self.grad_log
+        if self.reconstruct_label:
+            output_test = self.model(x_trial)
+            labels = torch.randn(output_test.shape[1]).to(**self.setup).requires_grad_(True)
+
+            if self.config['optim'] == 'adam':
+                optimizer = torch.optim.Adam([x_trial, labels], lr=self.config['lr'])
+            elif self.config['optim'] == 'sgd':  # actually gradient descent with momentum
+                optimizer = torch.optim.SGD([x_trial, labels], lr=0.01, momentum=0.9, nesterov=True)
+            elif self.config['optim'] == 'LBFGS':
+                optimizer = torch.optim.LBFGS([x_trial, labels])
+            else:
+                raise ValueError()
+        else:
+            if self.config['optim'] == 'adam':
+                optimizer = torch.optim.Adam([x_trial], lr=self.config['lr'])
+            elif self.config['optim'] == 'sgd':  # actually gradient descent with momentum
+                optimizer = torch.optim.SGD([x_trial], lr=0.01, momentum=0.9, nesterov=True)
+            elif self.config['optim'] == 'LBFGS':
+                optimizer = torch.optim.LBFGS([x_trial])
+            else:
+                raise ValueError()
+
+        max_iterations = self.config['max_iterations']
+        dm, ds = self.mean_std
+        if self.config['lr_decay']:
+            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
+                                                             milestones=[max_iterations // 2.667, max_iterations // 1.6,
+                                                                         max_iterations // 1.142], gamma=0.1)  # 3/8 5/8 7/8
+        try:
+            for iteration in range(max_iterations):
+                closure = self._gradient_closure(optimizer, x_trial, input_data, labels)
+                rec_loss = optimizer.step(closure)
+
+                if self.config['lr_decay']:
+                    scheduler.step()
+
+                with torch.no_grad():
+                    # Project into image space
+                    if self.config['boxed']:
+                        x_trial.data = torch.max(torch.min(x_trial, (1 - dm) / ds), -dm / ds)
+                    if iteration + 1 == max_iterations:  # or iteration % 100 == 0:
+                        print(f'It: {iteration + 1}. Rec. loss: {rec_loss.item():2.4f}.')
+
+                    if (iteration + 1) % 500 == 0:
+                        if self.config['filter'] == 'none':
+                            pass
+                        elif self.config['filter'] == 'median':
+                            x_trial.data = MedianPool2d(kernel_size=3, stride=1, padding=1, same=False)(x_trial)
+                        else:
+                            raise ValueError()
+
+#                grad_log['rec_loss'][iteration] = rec_loss.item()
+#                grad_log['input_mse'][iteration] = (x_trial.detach() - x_input).pow(2).mean().item()
+
+                if dryrun:
+                    break
+        except KeyboardInterrupt:
+            print(f'Recovery interrupted manually in iteration {iteration}!')
+            pass
+        return x_trial.detach(), labels
+
+    def _gradient_closure(self, optimizer, x_trial, input_gradient, label):
+
+        def closure():
+            optimizer.zero_grad()
+            self.model.zero_grad()
+            loss = self.loss_fn(self.model(x_trial), label)
+            gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
+            rec_loss = reconstruction_costs([gradient], input_gradient,
+                                            cost_fn=self.config['cost_fn'], indices=self.config['indices'],
+                                            norm=self.config['norm'], weights=self.config['weights'])
+
+            if self.config['total_variation'] > 0:
+                rec_loss += self.config['total_variation'] * TV(x_trial)
+            rec_loss.backward()  # differentiate the gradient-matching loss w.r.t. x_trial (a second-order derivative)
+            if self.config['signed']:
+                x_trial.grad.sign_()
+            return rec_loss
+        return closure
+
+    def _score_trial(self, x_trial, input_gradient, label):
+        if self.config['scoring_choice'] == 'loss':
+            self.model.zero_grad()
+            x_trial.grad = None
+            loss = self.loss_fn(self.model(x_trial), label)
+            gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False)
+            return reconstruction_costs([gradient], input_gradient,
+                                        cost_fn=self.config['cost_fn'], indices=self.config['indices'],
+                                        norm=self.config['norm'], weights=self.config['weights'])
+        elif self.config['scoring_choice'] == 'tv':
+            return TV(x_trial)
+        elif self.config['scoring_choice'] == 'inception':
+            # We do not care about diversity here!
+            return self.inception(x_trial)
+        elif self.config['scoring_choice'] in ['pixelmean', 'pixelmedian']:
+            return 0.0
+        else:
+            raise ValueError()
+
+    def _average_trials(self, x, labels, input_data, stats):
+        print(f'Computing a combined result via {self.config["scoring_choice"]} ...')
+        if self.config['scoring_choice'] == 'pixelmedian':
+            x_optimal, _ = x.median(dim=0, keepdim=False)
+        elif self.config['scoring_choice'] == 'pixelmean':
+            x_optimal = x.mean(dim=0, keepdim=False)
+
+        self.model.zero_grad()
+        if self.reconstruct_label:
+            labels = self.model(x_optimal).softmax(dim=1)
+        loss = self.loss_fn(self.model(x_optimal), labels)
+        gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False)
+        stats['opt'] = reconstruction_costs([gradient], input_data,
+                                            cost_fn=self.config['cost_fn'],
+                                            indices=self.config['indices'],
+                                            norm=self.config['norm'],
+                                            weights=self.config['weights'])
+        print(f'Optimal result score: {stats["opt"]:2.4f}')
+        return x_optimal, stats
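+    # Note: `_average_trials` combines all restarts pixel-wise (mean or median) and then reports
+    # the gradient-matching cost of that combined image in `stats['opt']`, using the same
+    # cost_fn/indices/weights configuration that drives the optimization itself.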
+
+
+
+class FedAvgReconstructor(GradientReconstructor):
+    """Reconstruct an image from weights after n gradient descent steps."""
+
+    def __init__(self, model, mean_std=(0.0, 1.0), local_steps=2, local_lr=1e-4,
+                 config=DEFAULT_CONFIG, num_images=1, use_updates=True):
+        """Initialize with model, (mean, std) and config."""
+        super().__init__(model, mean_std, config, num_images)
+        self.local_steps = local_steps
+        self.local_lr = local_lr
+        self.use_updates = use_updates
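+    # A hedged usage sketch (names are illustrative, not part of this module): given a FedAvg-style
+    # payload of parameter updates after `local_steps` of SGD at `local_lr`, e.g.
+    #     payload = [w_new - w_old for w_new, w_old in zip(local_model.parameters(), model.parameters())]
+    # reconstruction could look like
+    #     rec = FedAvgReconstructor(model, (dm, ds), local_steps=2, local_lr=1e-4, use_updates=True)
+    #     output, stats = rec.reconstruct([p.detach() for p in payload], labels, ground_truth)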
+
+    def _gradient_closure(self, optimizer, x_trial, input_parameters, labels):
+        def closure():
+            optimizer.zero_grad()
+            self.model.zero_grad()
+            parameters = loss_steps(self.model, x_trial, labels, loss_fn=self.loss_fn,
+                                    local_steps=self.local_steps, lr=self.local_lr, use_updates=self.use_updates)
+            rec_loss = reconstruction_costs([parameters], input_parameters,
+                                            cost_fn=self.config['cost_fn'], indices=self.config['indices'],
+                                            norm=self.config['norm'], weights=self.config['weights'])
+
+            if self.config['total_variation'] > 0:
+                rec_loss += self.config['total_variation'] * TV(x_trial)
+            rec_loss.backward()
+            if self.config['signed']:
+                x_trial.grad.sign_()
+            return rec_loss
+        return closure
+
+    def _score_trial(self, x_trial, input_parameters, labels):
+        if self.config['scoring_choice'] == 'loss':
+            self.model.zero_grad()
+            parameters = loss_steps(self.model, x_trial, labels, loss_fn=self.loss_fn,
+                                    local_steps=self.local_steps, lr=self.local_lr, use_updates=self.use_updates)
+            return reconstruction_costs([parameters], input_parameters,
+                                        cost_fn=self.config['cost_fn'], indices=self.config['indices'],
+                                        norm=self.config['norm'], weights=self.config['weights'])
+        elif self.config['scoring_choice'] == 'tv':
+            return TV(x_trial)
+        elif self.config['scoring_choice'] == 'inception':
+            # We do not care about diversity here!
+            return self.inception(x_trial)
+
+
+def loss_steps(model, inputs, labels, loss_fn=torch.nn.CrossEntropyLoss(), lr=1e-4, local_steps=4, use_updates=True, batch_size=0):
+    """Take a few gradient descent steps to fit the model to the given input."""
+    patched_model = MetaMonkey(model)
+    if use_updates:
+        patched_model_origin = deepcopy(patched_model)
+    for i in range(local_steps):
+        if batch_size == 0:
+            outputs = patched_model(inputs, patched_model.parameters)
+            labels_ = labels
+        else:
+            idx = i % (inputs.shape[0] // batch_size)
+            outputs = patched_model(inputs[idx * batch_size:(idx + 1) * batch_size], patched_model.parameters)
+            labels_ = labels[idx * batch_size:(idx + 1) * batch_size]
+        loss = loss_fn(outputs, labels_).sum()
+        grad = torch.autograd.grad(loss, patched_model.parameters.values(),
+                                   retain_graph=True, create_graph=True, only_inputs=True)
+
+        patched_model.parameters = OrderedDict((name, param - lr * grad_part)
+                                               for ((name, param), grad_part)
+                                               in zip(patched_model.parameters.items(), grad))
+
+    if use_updates:
+        patched_model.parameters = OrderedDict((name, param - param_origin)
+                                               for ((name, param), (name_origin, param_origin))
+                                               in zip(patched_model.parameters.items(), patched_model_origin.parameters.items()))
+    return list(patched_model.parameters.values())
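+# Note: `loss_steps` differentiably unrolls `local_steps` SGD updates through the functional
+# `MetaMonkey` wrapper, so gradients can flow back to `inputs`. With `use_updates=True` it
+# returns the parameter deltas (new minus original), matching what a FedAvg client would
+# transmit; otherwise it returns the updated parameters themselves.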
+
+def reconstruction_costs(gradients, input_gradient, cost_fn='l2', indices='def', norm='none', weights='equal'):
+    """Compute the reconstruction cost between candidate gradients and the given input gradient."""
+
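+    # `indices` selects which parameter tensors enter the cost: an explicit list, all tensors
+    # ('def'), a random subset of 8 ('batch'), the k tensors with the largest or smallest gradient
+    # norm ('topk-1' keeps the top 4; 'custom_top'/'custom_bottom' read k from `norm`), or fixed
+    # slices such as 'first10' / 'last50'.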
+    if isinstance(indices, list):
+        pass
+    elif indices == 'def':
+        indices = torch.arange(len(input_gradient))
+    elif indices == 'batch':
+        indices = torch.randperm(len(input_gradient))[:8]
+    elif indices == 'topk-1':
+        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 4)
+    elif indices == 'top10':
+        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 10)
+    elif indices == 'top50':
+        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 50)
+    elif indices in ['first', 'first4']:
+        indices = torch.arange(0, 4)
+    elif indices == 'first5':
+        indices = torch.arange(0, 5)
+    elif indices == 'first10':
+        indices = torch.arange(0, 10)
+    elif indices == 'first50':
+        indices = torch.arange(0, 50)
+    elif indices == 'last5':
+        indices = torch.arange(len(input_gradient))[-5:]
+    elif indices == 'last10':
+        indices = torch.arange(len(input_gradient))[-10:]
+    elif indices == 'last50':
+        indices = torch.arange(len(input_gradient))[-50:]
+    # custom selection of the k largest/smallest tensors by gradient norm; k is passed via `norm`
+    elif indices == 'custom_bottom':
+        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), norm, largest=False)
+    elif indices == 'custom_top':
+        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), norm, largest=True)
+    else:
+        raise ValueError()
+
+    ex = input_gradient[0]
+    if weights == 'linear':
+        weights = torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device) / len(input_gradient)
+    elif weights == 'exp':
+        weights = torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device)
+        weights = weights.softmax(dim=0)
+        weights = weights / weights[0]
+    else:
+        # equal weights: [1., 1., ..., 1.]
+        weights = input_gradient[0].new_ones(len(input_gradient))
+
+    total_costs = 0
+    for trial_gradient in gradients:
+        pnorm = [0, 0]
+        costs = 0
+
+        for i in indices:
+            if cost_fn == 'l2':
+                costs += ((trial_gradient[i] - input_gradient[i]).pow(2)).sum() * weights[i]
+            elif cost_fn == 'l1':
+                costs += ((trial_gradient[i] - input_gradient[i]).abs()).sum() * weights[i]
+            elif cost_fn == 'max':
+                costs += ((trial_gradient[i] - input_gradient[i]).abs()).max() * weights[i]
+            elif cost_fn == 'sim':
+                costs -= (trial_gradient[i] * input_gradient[i]).sum() * weights[i]
+                pnorm[0] += trial_gradient[i].pow(2).sum() * weights[i]
+                pnorm[1] += input_gradient[i].pow(2).sum() * weights[i]
+            elif cost_fn == 'simlocal':
+                costs += 1 - torch.nn.functional.cosine_similarity(trial_gradient[i].flatten(),
+                                                                   input_gradient[i].flatten(),
+                                                                   0, 1e-10) * weights[i]
+
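+        # For 'sim', the accumulated negative inner products are normalized by both gradient norms,
+        # so `costs` becomes one minus the (weighted) cosine similarity over the selected tensors.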
+        if cost_fn == 'sim':
+            costs = 1 + costs / pnorm[0].sqrt() / pnorm[1].sqrt()
+
+        # Accumulate final costs
+        total_costs += costs
+
+    return total_costs / len(gradients)