浏览代码

fix: fix file not found or corrupted

Shellmiao 1 年之前
父节点
当前提交
f005f52ddb
共有 2 个文件被更改,包括 7 次插入6 次删除
  1. 5 4
      main.py
  2. 2 2
      utils/dataloader.py

+ 5 - 4
main.py

@@ -29,7 +29,7 @@ parser = argparse.ArgumentParser(description='Adversarial attack from gradient l
 parser.add_argument('--model', type=str, help='model to perform adversarial attack')
 parser.add_argument('--data', type=str, help='dataset used')
 parser.add_argument('--stack_size', default=4, type=int, help='size use to stack images')
-parser.add_argument('-l','--target_idx', nargs='+', help='list of data index to recontruct')
+parser.add_argument('-l','--target_idx', type=str, help='comma separated list of data index to recontruct')
 parser.add_argument('--save', type=str2bool, nargs='?', const=False, default=True, help='save')
 parser.add_argument('--gpu', type=str2bool, nargs='?', const=False, default=True, help='use gpu')
 
@@ -39,11 +39,12 @@ model_name = args.model
 data = args.data
 stack_size = args.stack_size
 save_output = args.save 
-if args.target_idx is not None: 
-    target_idx = [int(i) for i in args.target_idx]
-else: 
+if args.target_idx is not None:
+    target_idx = [int(i) for i in args.target_idx.split(',')]
+else:
     target_idx = args.target_idx
 
+
 device = 'cpu'
 if args.gpu: 
     device = 'cuda'

+ 2 - 2
utils/dataloader.py

@@ -27,10 +27,10 @@ class DataLoader:
 
         dm = torch.as_tensor(mean)[:, None, None].to(self.device)
         ds = torch.as_tensor(std)[:, None, None].to(self.device)
-        data_root = 'data/cifar_data'
+        data_root = '/Users/shellmiao/Documents/adversarial-attack-from-leakage/data/cifar_data'
 #         data_root = '~/.torch'
         if self.data == 'cifar10': 
-            dataset = datasets.CIFAR10(root=data_root, download=True, train=False, transform=transform)
+            dataset = datasets.CIFAR10(root=data_root, download=False, train=False, transform=transform)
         elif self.data ==  'cifar100': 
             dataset = datasets.CIFAR100(root=data_root, download=True, train=False, transform=transform)
         elif self.data == 'mnist':