# main.py — adversarial attack from gradient leakage (scraped copy; editor line-number gutter removed)
  1. import argparse
  2. import os
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. from torch.utils.data import DataLoader
  7. from torch.autograd import grad
  8. import torchvision
  9. from torchvision import datasets, transforms
  10. from torchvision.utils import save_image
  11. import torchvision.models as models
  12. import inversefed
  13. from utils.dataloader import DataLoader
  14. from utils.stackeddata import StackedData
  15. # inverting gradients algorithm from https://github.com/JonasGeiping/invertinggradients
  16. def str2bool(v):
  17. if isinstance(v, bool):
  18. return v
  19. if v.lower() in ('yes', 'true', 't', 'y', '1'):
  20. return True
  21. elif v.lower() in ('no', 'false', 'f', 'n', '0'):
  22. return False
  23. else:
  24. raise argparse.ArgumentTypeError('Boolean value expected.')
  25. parser = argparse.ArgumentParser(description='Adversarial attack from gradient leakage')
  26. parser.add_argument('--model', type=str, help='model to perform adversarial attack')
  27. parser.add_argument('--data', type=str, help='dataset used')
  28. parser.add_argument('--stack_size', default=4, type=int, help='size use to stack images')
  29. parser.add_argument('-l','--target_idx', type=str, help='comma separated list of data index to recontruct')
  30. parser.add_argument('--save', type=str2bool, nargs='?', const=False, default=True, help='save')
  31. parser.add_argument('--gpu', type=str2bool, nargs='?', const=False, default=True, help='use gpu')
  32. args = parser.parse_args()
  33. model_name = args.model
  34. data = args.data
  35. stack_size = args.stack_size
  36. save_output = args.save
  37. if args.target_idx is not None:
  38. target_idx = [int(i) for i in args.target_idx.split(',')]
  39. else:
  40. target_idx = args.target_idx
  41. device = 'cpu'
  42. if args.gpu:
  43. device = 'cuda'
  44. print("Running on %s" % device)
  45. def val_model(dataset, model, criterion):
  46. # evaluate trained model, record wrongly predicted index
  47. model.eval()
  48. # record wrong pred index
  49. index_ls = []
  50. with torch.no_grad():
  51. val_loss, val_corrects = 0, 0
  52. for batch_idx, (inputs, labels) in enumerate(dataset):
  53. inputs = inputs.unsqueeze(dim=0).to(device)
  54. labels = torch.as_tensor([labels]).to(device)
  55. outputs = model(inputs)
  56. loss = criterion(outputs, labels)
  57. _, preds = torch.max(outputs, 1)
  58. val_loss += loss.item() * inputs.size(0) # mutiply by number of batches
  59. val_corrects += torch.sum(preds == labels.data)
  60. if (preds != labels.data):
  61. index_ls.append(batch_idx)
  62. total_loss = val_loss / len(dataset)
  63. total_acc = val_corrects.double() / len(dataset)
  64. print('{} Loss: {:.4f} Acc: {:.4f}'.format('val', total_loss, total_acc))
  65. return index_ls
  66. dataloader = DataLoader(data, device)
  67. dataset, data_shape, classes, (dm, ds) = dataloader.get_data_info()
  68. model = models.resnet18(pretrained=True) # use pretrained model from torchvision
  69. model.fc = nn.Linear(512, len(classes)) # reinitialize model output: https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
  70. model = model.to(device)
  71. model.eval()
  72. criterion = nn.CrossEntropyLoss()
  73. stack_data = StackedData(stack_size=4, model_name=model_name, dataset_name=data, dataset=dataset, save_output=save_output, device=device)
  74. if target_idx is None:
  75. wrong_pred_idx = val_model(dataset, model, criterion)
  76. else:
  77. if isinstance(target_idx, (list))==False:
  78. wrong_pred_idx = [target_idx]
  79. else:
  80. wrong_pred_idx = target_idx
  81. stacked_data_d = stack_data.create_stacked_data(wrong_pred_idx)
  82. for i in range(len(stacked_data_d['gt_img'])):
  83. gt_img, gt_label, img_idx = stacked_data_d['gt_img'][i], stacked_data_d['gt_label'][i], stacked_data_d['img_index'][i]
  84. stack_pred = model(gt_img)
  85. target_loss = criterion(stack_pred, gt_label)
  86. input_grad = grad(target_loss, model.parameters())
  87. input_grad =[grad.detach() for grad in input_grad]
  88. # default configuration from inversefed
  89. config = dict(signed=True,
  90. boxed=False,
  91. cost_fn='sim',
  92. indices='def',
  93. norm='none',
  94. weights='equal',
  95. lr=0.1,
  96. optim='adam',
  97. restarts=1,
  98. max_iterations=1200,
  99. total_variation=0.1,
  100. init='randn',
  101. filter='none',
  102. lr_decay=True,
  103. scoring_choice='loss')
  104. rec_machine = inversefed.GradientReconstructor(model, (dm, ds), config, num_images=gt_img.shape[0])
  105. results = rec_machine.reconstruct(input_grad, gt_label, gt_img ,img_shape=data_shape)
  106. output_img, stats = results
  107. rec_pred = model(output_img)
  108. print('Predictions for recontructed images: ', [classes[l] for l in torch.max(rec_pred, axis=1)[1]])
  109. stack_data.grid_plot(img_idx, output_img, rec_pred, dm, ds)