reconstruction_algorithms.py

  1. """Mechanisms for image reconstruction from parameter gradients."""
  2. import torch
  3. from collections import defaultdict, OrderedDict
  4. from .modules import MetaMonkey
  5. from .metrics import total_variation as TV
  6. from .metrics import InceptionScore
  7. from .medianfilt import MedianPool2d
  8. from copy import deepcopy
  9. import time

DEFAULT_CONFIG = dict(signed=True,
                      boxed=True,
                      cost_fn='sim',
                      indices='topk-1',
                      norm='none',
                      weights='equal',
                      lr=0.01,
                      optim='adam',
                      restarts=128,
                      max_iterations=8_000,
                      total_variation=0,
                      init='randn',
                      filter='none',
                      lr_decay=False,
                      scoring_choice='loss')
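
# Note: these defaults correspond to signed gradient updates under Adam on a
# cosine-similarity ('sim') matching cost over the four highest-norm parameter
# tensors ('topk-1', see reconstruction_costs below), with 128 restarts of
# 8,000 iterations each. Callers may pass a partial config; _validate_config
# fills in any missing keys from DEFAULT_CONFIG.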

def _label_to_onehot(target, num_classes=100):
    target = torch.unsqueeze(target, 1)
    onehot_target = torch.zeros(target.size(0), num_classes, device=target.device)
    onehot_target.scatter_(1, target, 1)
    return onehot_target
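
# Shape example (helper is not called elsewhere in this file):
# _label_to_onehot(torch.tensor([3]), num_classes=10) returns a (1, 10) tensor
# that is zero everywhere except for a 1.0 at column 3.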

def _validate_config(config):
    for key in DEFAULT_CONFIG.keys():
        if config.get(key) is None:
            config[key] = DEFAULT_CONFIG[key]
    for key in config.keys():
        if DEFAULT_CONFIG.get(key) is None:
            raise ValueError(f'Deprecated key in config dict: {key}!')
    return config
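
# A minimal sketch of how a partial config is completed (keys not listed fall
# back to DEFAULT_CONFIG; unknown keys raise ValueError):
# _validate_config(dict(restarts=1, max_iterations=500, cost_fn='l2'))
#   -> full config with restarts=1, max_iterations=500, cost_fn='l2', lr=0.01, ...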

class GradientReconstructor():
    """Instantiate a reconstruction algorithm."""

    def __init__(self, model, mean_std=(0.0, 1.0), config=DEFAULT_CONFIG, num_images=1):
        """Initialize with algorithm setup."""
        self.config = _validate_config(config)
        self.model = model
        self.setup = dict(device=next(model.parameters()).device, dtype=next(model.parameters()).dtype)
        self.mean_std = mean_std
        self.num_images = num_images

        if self.config['scoring_choice'] == 'inception':
            self.inception = InceptionScore(batch_size=1, setup=self.setup)

        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
        # self.iDLG = True
        # self.grad_log = defaultdict(dict)

    def reconstruct(self, input_data, labels, input_img, img_shape=(3, 32, 32), dryrun=False, eval=True):
        """Reconstruct image from gradient."""
        start_time = time.time()
        if eval:
            self.model.eval()

        stats = defaultdict(list)
        # grad_log = self.grad_log
        x = self._init_images(img_shape)
        scores = torch.zeros(self.config['restarts'])

        # if labels is None:
        #     if self.num_images == 1 and self.iDLG:
        #         # iDLG trick:
        #         last_weight_min = torch.argmin(torch.sum(input_data[-2], dim=-1), dim=-1)
        #         labels = last_weight_min.detach().reshape((1,)).requires_grad_(False)
        #         self.reconstruct_label = False
        #     else:
        #         # DLG label recovery
        #         # However this also improves conditioning for some LBFGS cases
        #         self.reconstruct_label = True
        #         def loss_fn(pred, labels):
        #             labels = torch.nn.functional.softmax(labels, dim=-1)
        #             return torch.mean(torch.sum(- labels * torch.nn.functional.log_softmax(pred, dim=-1), 1))
        #         self.loss_fn = loss_fn
        assert labels.shape[0] == self.num_images
        self.reconstruct_label = False

        try:
            for trial in range(self.config['restarts']):
                x_trial, labels = self._run_trial(x[trial], input_data, labels, input_img, dryrun=dryrun)
                # Finalize
                scores[trial] = self._score_trial(x_trial, input_data, labels)
                x[trial] = x_trial
                if dryrun:
                    break
        except KeyboardInterrupt:
            print('Trial procedure manually interrupted.')
            pass

        # Choose optimal result:
        if self.config['scoring_choice'] in ['pixelmean', 'pixelmedian']:
            x_optimal, stats = self._average_trials(x, labels, input_data, stats)
        else:
            print('Choosing optimal result ...')
            scores = scores[torch.isfinite(scores)]  # guard against NaN/-Inf scores?
            optimal_index = torch.argmin(scores)
            print(f'Optimal result score: {scores[optimal_index]:2.4f}')
            stats['opt'] = scores[optimal_index].item()
            x_optimal = x[optimal_index]

        print(f'Total time: {time.time()-start_time}.')
        return x_optimal.detach(), stats
        # return x_optimal.detach(), stats, grad_log

    def _init_images(self, img_shape):
        if self.config['init'] == 'randn':
            return torch.randn((self.config['restarts'], self.num_images, *img_shape), **self.setup)
        elif self.config['init'] == 'rand':
            return (torch.rand((self.config['restarts'], self.num_images, *img_shape), **self.setup) - 0.5) * 2
        elif self.config['init'] == 'zeros':
            return torch.zeros((self.config['restarts'], self.num_images, *img_shape), **self.setup)
        else:
            raise ValueError()

    def _run_trial(self, x_trial, input_data, labels, x_input, dryrun=False):  # x_input - ground truth
        x_trial.requires_grad = True
        # grad_log = self.grad_log
        if self.reconstruct_label:
            output_test = self.model(x_trial)
            labels = torch.randn(output_test.shape[1]).to(**self.setup).requires_grad_(True)

            if self.config['optim'] == 'adam':
                optimizer = torch.optim.Adam([x_trial, labels], lr=self.config['lr'])
            elif self.config['optim'] == 'sgd':  # actually gd
                optimizer = torch.optim.SGD([x_trial, labels], lr=0.01, momentum=0.9, nesterov=True)
            elif self.config['optim'] == 'LBFGS':
                optimizer = torch.optim.LBFGS([x_trial, labels])
            else:
                raise ValueError()
        else:
            if self.config['optim'] == 'adam':
                optimizer = torch.optim.Adam([x_trial], lr=self.config['lr'])
            elif self.config['optim'] == 'sgd':  # actually gd
                optimizer = torch.optim.SGD([x_trial], lr=0.01, momentum=0.9, nesterov=True)
            elif self.config['optim'] == 'LBFGS':
                optimizer = torch.optim.LBFGS([x_trial])
            else:
                raise ValueError()

        max_iterations = self.config['max_iterations']
        dm, ds = self.mean_std
        if self.config['lr_decay']:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                             milestones=[max_iterations // 2.667, max_iterations // 1.6,
                                                                         max_iterations // 1.142], gamma=0.1)  # 3/8 5/8 7/8
        try:
            for iteration in range(max_iterations):
                closure = self._gradient_closure(optimizer, x_trial, input_data, labels)
                rec_loss = optimizer.step(closure)
                if self.config['lr_decay']:
                    scheduler.step()

                with torch.no_grad():
                    # Project into image space
                    if self.config['boxed']:
                        x_trial.data = torch.max(torch.min(x_trial, (1 - dm) / ds), -dm / ds)

                    if iteration + 1 == max_iterations:  # or iteration % 100 == 0:
                        print(f'It: {iteration + 1}. Rec. loss: {rec_loss.item():2.4f}.')

                    if (iteration + 1) % 500 == 0:
                        if self.config['filter'] == 'none':
                            pass
                        elif self.config['filter'] == 'median':
                            x_trial.data = MedianPool2d(kernel_size=3, stride=1, padding=1, same=False)(x_trial)
                        else:
                            raise ValueError()

                # grad_log['rec_loss'][iteration] = rec_loss.item()
                # grad_log['input_mse'][iteration] = (x_trial.detach() - x_input).pow(2).mean().item()
                if dryrun:
                    break
        except KeyboardInterrupt:
            print(f'Recovery interrupted manually in iteration {iteration}!')
            pass
        return x_trial.detach(), labels

    def _gradient_closure(self, optimizer, x_trial, input_gradient, label):

        def closure():
            optimizer.zero_grad()
            self.model.zero_grad()
            loss = self.loss_fn(self.model(x_trial), label)
            gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
            rec_loss = reconstruction_costs([gradient], input_gradient,
                                            cost_fn=self.config['cost_fn'], indices=self.config['indices'],
                                            norm=self.config['norm'], weights=self.config['weights'])

            if self.config['total_variation'] > 0:
                rec_loss += self.config['total_variation'] * TV(x_trial)
            rec_loss.backward()  # second derivative --> difference between gradients
            if self.config['signed']:
                x_trial.grad.sign_()
            return rec_loss
        return closure
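
    # Note on the closure above: rec_loss.backward() differentiates the gradient-matching
    # cost w.r.t. x_trial (a second-order derivative through the model), and with
    # config['signed'] the image gradient is replaced by its elementwise sign before
    # optimizer.step applies the update.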

    def _score_trial(self, x_trial, input_gradient, label):
        if self.config['scoring_choice'] == 'loss':
            self.model.zero_grad()
            x_trial.grad = None
            loss = self.loss_fn(self.model(x_trial), label)
            gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False)
            return reconstruction_costs([gradient], input_gradient,
                                        cost_fn=self.config['cost_fn'], indices=self.config['indices'],
                                        norm=self.config['norm'], weights=self.config['weights'])
        elif self.config['scoring_choice'] == 'tv':
            return TV(x_trial)
        elif self.config['scoring_choice'] == 'inception':
            # We do not care about diversity here!
            return self.inception(x_trial)
        elif self.config['scoring_choice'] in ['pixelmean', 'pixelmedian']:
            return 0.0
        else:
            raise ValueError()

    def _average_trials(self, x, labels, input_data, stats):
        print(f'Computing a combined result via {self.config["scoring_choice"]} ...')
        if self.config['scoring_choice'] == 'pixelmedian':
            x_optimal, _ = x.median(dim=0, keepdims=False)
        elif self.config['scoring_choice'] == 'pixelmean':
            x_optimal = x.mean(dim=0, keepdims=False)

        self.model.zero_grad()
        if self.reconstruct_label:
            labels = self.model(x_optimal).softmax(dim=1)
        loss = self.loss_fn(self.model(x_optimal), labels)
        gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False)
        stats['opt'] = reconstruction_costs([gradient], input_data,
                                            cost_fn=self.config['cost_fn'],
                                            indices=self.config['indices'],
                                            norm=self.config['norm'],
                                            weights=self.config['weights'])
        print(f'Optimal result score: {stats["opt"]:2.4f}')
        return x_optimal, stats
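
# A minimal usage sketch (hypothetical names; the surrounding attack script is not part
# of this file). Given an observed victim gradient, reconstruction proceeds roughly as:
#
#     model.zero_grad()
#     loss = torch.nn.CrossEntropyLoss()(model(ground_truth), labels)
#     input_gradient = torch.autograd.grad(loss, model.parameters())
#     input_gradient = [g.detach() for g in input_gradient]
#     rec = GradientReconstructor(model, mean_std=(dm, ds), config=dict(restarts=1), num_images=1)
#     output, stats = rec.reconstruct(input_gradient, labels, ground_truth, img_shape=(3, 32, 32))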

class FedAvgReconstructor(GradientReconstructor):
    """Reconstruct an image from weights after n gradient descent steps."""

    def __init__(self, model, mean_std=(0.0, 1.0), local_steps=2, local_lr=1e-4,
                 config=DEFAULT_CONFIG, num_images=1, use_updates=True):
        """Initialize with model, (mean, std) and config."""
        super().__init__(model, mean_std, config, num_images)
        self.local_steps = local_steps
        self.local_lr = local_lr
        self.use_updates = use_updates

    def _gradient_closure(self, optimizer, x_trial, input_parameters, labels):

        def closure():
            optimizer.zero_grad()
            self.model.zero_grad()
            parameters = loss_steps(self.model, x_trial, labels, loss_fn=self.loss_fn,
                                    local_steps=self.local_steps, lr=self.local_lr,
                                    use_updates=self.use_updates)
            rec_loss = reconstruction_costs([parameters], input_parameters,
                                            cost_fn=self.config['cost_fn'], indices=self.config['indices'],
                                            norm=self.config['norm'], weights=self.config['weights'])

            if self.config['total_variation'] > 0:
                rec_loss += self.config['total_variation'] * TV(x_trial)
            rec_loss.backward()
            if self.config['signed']:
                x_trial.grad.sign_()
            return rec_loss
        return closure

    def _score_trial(self, x_trial, input_parameters, labels):
        if self.config['scoring_choice'] == 'loss':
            self.model.zero_grad()
            parameters = loss_steps(self.model, x_trial, labels, loss_fn=self.loss_fn,
                                    local_steps=self.local_steps, lr=self.local_lr,
                                    use_updates=self.use_updates)
            return reconstruction_costs([parameters], input_parameters,
                                        cost_fn=self.config['cost_fn'], indices=self.config['indices'],
                                        norm=self.config['norm'], weights=self.config['weights'])
        elif self.config['scoring_choice'] == 'tv':
            return TV(x_trial)
        elif self.config['scoring_choice'] == 'inception':
            # We do not care about diversity here!
            return self.inception(x_trial)
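
# As above, a hypothetical sketch for the federated-averaging case: the observed quantity
# is a (simulated) local parameter update rather than a single gradient.
#
#     input_parameters = loss_steps(model, ground_truth, labels, lr=1e-4, local_steps=5,
#                                   use_updates=True)
#     input_parameters = [p.detach() for p in input_parameters]
#     rec = FedAvgReconstructor(model, mean_std=(dm, ds), local_steps=5, local_lr=1e-4,
#                               config=dict(restarts=1), use_updates=True)
#     output, stats = rec.reconstruct(input_parameters, labels, ground_truth, img_shape=(3, 32, 32))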

def loss_steps(model, inputs, labels, loss_fn=torch.nn.CrossEntropyLoss(), lr=1e-4, local_steps=4, use_updates=True, batch_size=0):
    """Take a few gradient descent steps to fit the model to the given input."""
    patched_model = MetaMonkey(model)
    if use_updates:
        patched_model_origin = deepcopy(patched_model)
    for i in range(local_steps):
        if batch_size == 0:
            outputs = patched_model(inputs, patched_model.parameters)
            labels_ = labels
        else:
            idx = i % (inputs.shape[0] // batch_size)
            outputs = patched_model(inputs[idx * batch_size:(idx + 1) * batch_size], patched_model.parameters)
            labels_ = labels[idx * batch_size:(idx + 1) * batch_size]
        loss = loss_fn(outputs, labels_).sum()
        grad = torch.autograd.grad(loss, patched_model.parameters.values(),
                                   retain_graph=True, create_graph=True, only_inputs=True)

        patched_model.parameters = OrderedDict((name, param - lr * grad_part)
                                               for ((name, param), grad_part)
                                               in zip(patched_model.parameters.items(), grad))

    if use_updates:
        patched_model.parameters = OrderedDict((name, param - param_origin)
                                               for ((name, param), (name_origin, param_origin))
                                               in zip(patched_model.parameters.items(), patched_model_origin.parameters.items()))
    return list(patched_model.parameters.values())
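
# loss_steps simulates the client's local SGD differentiably (MetaMonkey swaps the updated
# parameters into each forward pass), so the FedAvg closure above can backpropagate through
# all local steps. With use_updates=True it returns the accumulated update (final minus
# initial parameters) instead of the final parameters themselves.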

def reconstruction_costs(gradients, input_gradient, cost_fn='l2', indices='def', norm='none', weights='equal'):
    """Compute the matching cost between trial gradients and the observed input gradient (the given data)."""
    if isinstance(indices, list):
        pass
    elif indices == 'def':
        indices = torch.arange(len(input_gradient))
    elif indices == 'batch':
        indices = torch.randperm(len(input_gradient))[:8]
    elif indices == 'topk-1':
        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 4)
    elif indices == 'top10':
        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 10)
    elif indices == 'top50':
        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 50)
    elif indices in ['first', 'first4']:
        indices = torch.arange(0, 4)
    elif indices == 'first5':
        indices = torch.arange(0, 5)
    elif indices == 'first10':
        indices = torch.arange(0, 10)
    elif indices == 'first50':
        indices = torch.arange(0, 50)
    elif indices == 'last5':
        indices = torch.arange(len(input_gradient))[-5:]
    elif indices == 'last10':
        indices = torch.arange(len(input_gradient))[-10:]
    elif indices == 'last50':
        indices = torch.arange(len(input_gradient))[-50:]
    # Custom selection of layers by parameter-gradient norm (here `norm` is the number of tensors to pick):
    elif indices == 'custom_bottom':
        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), norm, largest=False)
    elif indices == 'custom_top':
        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), norm, largest=True)
    else:
        raise ValueError()

    ex = input_gradient[0]
    if weights == 'linear':
        weights = torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device) / len(input_gradient)
    elif weights == 'exp':
        weights = torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device)
        weights = weights.softmax(dim=0)
        weights = weights / weights[0]
    else:
        # equal weights [1., 1., 1., ...]
        weights = input_gradient[0].new_ones(len(input_gradient))

    total_costs = 0
    for trial_gradient in gradients:
        pnorm = [0, 0]
        costs = 0
        for i in indices:
            if cost_fn == 'l2':
                costs += ((trial_gradient[i] - input_gradient[i]).pow(2)).sum() * weights[i]
            elif cost_fn == 'l1':
                costs += ((trial_gradient[i] - input_gradient[i]).abs()).sum() * weights[i]
            elif cost_fn == 'max':
                costs += ((trial_gradient[i] - input_gradient[i]).abs()).max() * weights[i]
            elif cost_fn == 'sim':
                costs -= (trial_gradient[i] * input_gradient[i]).sum() * weights[i]
                pnorm[0] += trial_gradient[i].pow(2).sum() * weights[i]
                pnorm[1] += input_gradient[i].pow(2).sum() * weights[i]
            elif cost_fn == 'simlocal':
                costs += 1 - torch.nn.functional.cosine_similarity(trial_gradient[i].flatten(),
                                                                   input_gradient[i].flatten(),
                                                                   dim=0, eps=1e-10) * weights[i]
        if cost_fn == 'sim':
            costs = 1 + costs / pnorm[0].sqrt() / pnorm[1].sqrt()

        # Accumulate final costs
        total_costs += costs
    return total_costs / len(gradients)
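
if __name__ == '__main__':
    # Minimal sanity-check sketch, not part of the original pipeline; the tensor shapes are
    # arbitrary assumptions. Because of the relative imports above, run this via
    # `python -m <package>.reconstruction_algorithms` from within the package.
    # With cost_fn='sim' the cost is ~0 for identical gradients and grows as the
    # trial gradient is perturbed.
    _reference = [torch.randn(8, 4), torch.randn(8)]
    _identical = [g.clone() for g in _reference]
    _perturbed = [g + 0.5 * torch.randn_like(g) for g in _reference]
    print(reconstruction_costs([_identical], _reference, cost_fn='sim', indices='def'))
    print(reconstruction_costs([_perturbed], _reference, cost_fn='sim', indices='def'))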