
first commit

Jiaqi0602 3 years ago
commit
ed4dd791b5

+ 3 - 0
.gitignore

@@ -0,0 +1,3 @@
+__pycache__/
+.ipynb_checkpoints
+models/ 

+ 31 - 0
README.md

@@ -0,0 +1,31 @@
+# From Gradient Leakage to Adversarial Attacks in Federated Learning
+
+Gradient inversion is an existing privacy-breaking technique that reconstructs a model's input data from the gradients it shares. The data recovered by inverting gradients exposes how vulnerable models are at the level of their learned representations.
+
+In this work, we use the inverting-gradients algorithm proposed in [Inverting Gradients - How easy is it to break Privacy in Federated Learning?](https://arxiv.org/pdf/2003.14053.pdf) to reconstruct data that poses a possible threat to classification tasks. We stack a single wrongly predicted image into batches of different sizes and feed the stacked batch to the existing gradient-inversion algorithm; the reconstructions are distorted images that the attacked model nevertheless predicts correctly.
+
+<pic> 
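+
+The pipeline behind the `main.py` command below can be sketched in a few lines. This is a minimal sketch rather than the full script: `model`, a misclassified image `img` with its true `label`, the normalization constants `dm`/`ds`, and the reconstruction `config` are assumed to be set up as in `main.py`, and device handling is omitted:
+
+```python
+import torch
+from torch.autograd import grad
+import inversefed
+
+stack_size = 4
+# stack one misclassified image into a batch of identical copies
+gt_images = torch.stack([img] * stack_size)
+gt_labels = torch.as_tensor([label] * stack_size)
+
+# gradients that a federated client would share for this batch
+loss = torch.nn.CrossEntropyLoss()(model(gt_images), gt_labels)
+input_grad = [g.detach() for g in grad(loss, model.parameters())]
+
+# invert the shared gradients back into images
+rec_machine = inversefed.GradientReconstructor(model, (dm, ds), config, num_images=stack_size)
+output_img, stats = rec_machine.reconstruct(input_grad, gt_labels, gt_images, img_shape=(3, 32, 32))
+```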
+
+## Prerequisites
+Required libraries:
+```bash
+Python>=3.7
+pytorch=1.5.0
+torchvision=0.6.0
+```
+## Code
+```bash
+python main.py --model "resnet18" --data "cifar10" --stack_size 4 -l 1001 770 123 --save True --gpu True
+```
+
+The implementation for ResNet-18 trained on CIFAR-10 can be found [HERE](link to cifar notebook), and the VGGFACE2 version [HERE](link to vgg notebook).
+
+#### Quick reproduction for the CIFAR-10 dataset
+You can download a pretrained model from [HERE](https://github.com/huyvnphan/PyTorch_CIFAR10) and load it in place of the torchvision model.
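+
+A minimal loading sketch (the checkpoint path `models/resnet18.pt` is only an example, and the state-dict keys must match the model definition you use; adjust both to the checkpoint you actually download):
+```python
+import torch
+import torch.nn as nn
+import torchvision.models as models
+
+model = models.resnet18(pretrained=False)
+model.fc = nn.Linear(512, 10)  # 10 CIFAR-10 classes, as in main.py
+state_dict = torch.load('models/resnet18.pt', map_location='cpu')  # hypothetical checkpoint path
+model.load_state_dict(state_dict)
+model.eval()
+```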
+
+
+## References
+- [Inverting Gradients - How easy is it to break Privacy in Federated Learning?](https://github.com/JonasGeiping/invertinggradients)
+- [Deep Leakage From Gradients](https://github.com/mit-han-lab/dlg) 
+- [PyTorch models trained on CIFAR-10 dataset](https://github.com/huyvnphan/PyTorch_CIFAR10)

+ 3 - 0
data/.gitignore

@@ -0,0 +1,3 @@
+*
+*/
+!.gitignore

File diff suppressed because it is too large
+ 69 - 0
demo - CIFAR10.ipynb


File diff suppressed because it is too large
+ 71 - 0
demo - VGGFACE2.ipynb


BIN
image/graph1.jpg


BIN
image/graph2.jpg


BIN
image/rec_output.JPG


+ 9 - 0
inversefed/__init__.py

@@ -0,0 +1,9 @@
+"""Library of routines."""
+
+from inversefed import utils
+from .modules import MetaMonkey
+from .optimization_strategy import training_strategy
+from .reconstruction_algorithms import GradientReconstructor, FedAvgReconstructor
+from inversefed import metrics
+
+__all__ = ['utils', 'MetaMonkey', 'training_strategy', 'GradientReconstructor', 'FedAvgReconstructor']

+ 16 - 0
inversefed/consts.py

@@ -0,0 +1,16 @@
+"""Setup constants, ymmv."""
+
+PIN_MEMORY = True
+NON_BLOCKING = False
+BENCHMARK = True
+MULTITHREAD_DATAPROCESSING = 4
+
+
+cifar10_mean = [0.4914672374725342, 0.4822617471218109, 0.4467701315879822]
+cifar10_std = [0.24703224003314972, 0.24348513782024384, 0.26158785820007324]
+cifar100_mean = [0.5071598291397095, 0.4866936206817627, 0.44120192527770996]
+cifar100_std = [0.2673342823982239, 0.2564384639263153, 0.2761504650115967]
+mnist_mean = (0.13066373765468597,)
+mnist_std = (0.30810782313346863,)
+imagenet_mean = [0.485, 0.456, 0.406]
+imagenet_std = [0.229, 0.224, 0.225]

+ 54 - 0
inversefed/medianfilt.py

@@ -0,0 +1,54 @@
+"""This is code for median pooling from https://gist.github.com/rwightman.
+
+https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598
+"""
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.utils import _pair, _quadruple
+
+
+class MedianPool2d(nn.Module):
+    """Median pool (usable as median filter when stride=1) module.
+
+    Args:
+         kernel_size: size of pooling kernel, int or 2-tuple
+         stride: pool stride, int or 2-tuple
+         padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
+         same: override padding and enforce same padding, boolean
+    """
+
+    def __init__(self, kernel_size=3, stride=1, padding=0, same=True):
+        """Initialize with kernel_size, stride, padding."""
+        super().__init__()
+        self.k = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _quadruple(padding)  # convert to l, r, t, b
+        self.same = same
+
+    def _padding(self, x):
+        if self.same:
+            ih, iw = x.size()[2:]
+            if ih % self.stride[0] == 0:
+                ph = max(self.k[0] - self.stride[0], 0)
+            else:
+                ph = max(self.k[0] - (ih % self.stride[0]), 0)
+            if iw % self.stride[1] == 0:
+                pw = max(self.k[1] - self.stride[1], 0)
+            else:
+                pw = max(self.k[1] - (iw % self.stride[1]), 0)
+            pl = pw // 2
+            pr = pw - pl
+            pt = ph // 2
+            pb = ph - pt
+            padding = (pl, pr, pt, pb)
+        else:
+            padding = self.padding
+        return padding
+
+    def forward(self, x):
+        # using existing pytorch functions and tensor ops so that we get autograd,
+        # would likely be more efficient to implement from scratch at C/Cuda level
+        x = F.pad(x, self._padding(x), mode='reflect')
+        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
+        x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
+        return x

+ 106 - 0
inversefed/metrics.py

@@ -0,0 +1,106 @@
+"""This is code based on https://sudomake.ai/inception-score-explained/."""
+import torch
+import torchvision
+
+from collections import defaultdict
+
+class InceptionScore(torch.nn.Module):
+    """Class that manages and returns the inception score of images."""
+
+    def __init__(self, batch_size=32, setup=dict(device=torch.device('cpu'), dtype=torch.float)):
+        """Initialize with setup and target inception batch size."""
+        super().__init__()
+        self.preprocessing = torch.nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False)
+        self.model = torchvision.models.inception_v3(pretrained=True).to(**setup)
+        self.model.eval()
+        self.batch_size = batch_size
+
+    def forward(self, image_batch):
+        """Image batch should have dimensions BCHW and should be normalized.
+
+        B should be divisible by self.batch_size.
+        """
+        B, C, H, W = image_batch.shape
+        batches = B // self.batch_size
+        scores = []
+        for batch in range(batches):
+            input = self.preprocessing(image_batch[batch * self.batch_size: (batch + 1) * self.batch_size])
+            scores.append(self.model(input))
+        prob_yx = torch.nn.functional.softmax(torch.cat(scores, 0), dim=1)
+        entropy = torch.where(prob_yx > 0, -prob_yx * prob_yx.log(), torch.zeros_like(prob_yx))
+        return entropy.mean()
+
+
+def psnr(img_batch, ref_batch, batched=False, factor=1.0):
+    """Standard PSNR."""
+    def get_psnr(img_in, img_ref):
+        mse = ((img_in - img_ref)**2).mean()
+        if mse > 0 and torch.isfinite(mse):
+            return (10 * torch.log10(factor**2 / mse))
+        elif not torch.isfinite(mse):
+            return img_batch.new_tensor(float('nan'))
+        else:
+            return img_batch.new_tensor(float('inf'))
+
+    if batched:
+        psnr = get_psnr(img_batch.detach(), ref_batch)
+    else:
+        [B, C, m, n] = img_batch.shape
+        psnrs = []
+        for sample in range(B):
+            psnrs.append(get_psnr(img_batch.detach()[sample, :, :, :], ref_batch[sample, :, :, :]))
+        psnr = torch.stack(psnrs, dim=0).mean()
+
+    return psnr.item()
+
+
+def total_variation(x):
+    """Anisotropic TV."""
+    dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
+    dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))
+    return dx + dy
+
+
+
+def activation_errors(model, x1, x2):
+    """Compute activation-level error metrics for every module in the network."""
+    model.eval()
+
+    device = next(model.parameters()).device
+
+    hooks = []
+    data = defaultdict(dict)
+    inputs = torch.cat((x1, x2), dim=0)
+    separator = x1.shape[0]
+
+    def check_activations(self, input, output):
+        module_name = str(*[name for name, mod in model.named_modules() if self is mod])
+        try:
+            layer_inputs = input[0].detach()
+            residual = (layer_inputs[:separator] - layer_inputs[separator:]).pow(2)
+            se_error = residual.sum()
+            mse_error = residual.mean()
+            sim = torch.nn.functional.cosine_similarity(layer_inputs[:separator].flatten(),
+                                                        layer_inputs[separator:].flatten(),
+                                                        dim=0, eps=1e-8).detach()
+            data['se'][module_name] = se_error.item()
+            data['mse'][module_name] = mse_error.item()
+            data['sim'][module_name] = sim.item()
+        except (KeyboardInterrupt, SystemExit):
+            raise
+        except AttributeError:
+            pass
+
+    for name, module in model.named_modules():
+        hooks.append(module.register_forward_hook(check_activations))
+
+    try:
+        outputs = model(inputs.to(device))
+        for hook in hooks:
+            hook.remove()
+    except Exception as e:
+        for hook in hooks:
+            hook.remove()
+        raise
+
+    return data

+ 98 - 0
inversefed/modules.py

@@ -0,0 +1,98 @@
+"""For monkey-patching into meta-learning frameworks."""
+import torch
+import torch.nn.functional as F
+from collections import OrderedDict
+from functools import partial
+import warnings
+
+from .consts import BENCHMARK
+torch.backends.cudnn.benchmark = BENCHMARK
+
+DEBUG = False  # Emit warning messages when patching. Use this to bootstrap new architectures.
+
+class MetaMonkey(torch.nn.Module):
+    """Trace a networks and then replace its module calls with functional calls.
+
+    This allows for backpropagation w.r.t to weights for "normal" PyTorch networks.
+    """
+
+    def __init__(self, net):
+        """Init with network."""
+        super().__init__()
+        self.net = net
+        self.parameters = OrderedDict(net.named_parameters())
+
+
+    def forward(self, inputs, parameters=None):
+        """Live Patch ... :> ..."""
+        # If no parameter dictionary is given, everything is normal
+        if parameters is None:
+            return self.net(inputs)
+
+        # But if not ...
+        param_gen = iter(parameters.values())
+        method_pile = []
+        counter = 0
+
+        for name, module in self.net.named_modules():
+            if isinstance(module, torch.nn.Conv2d):
+                ext_weight = next(param_gen)
+                if module.bias is not None:
+                    ext_bias = next(param_gen)
+                else:
+                    ext_bias = None
+
+                method_pile.append(module.forward)
+                module.forward = partial(F.conv2d, weight=ext_weight, bias=ext_bias, stride=module.stride,
+                                         padding=module.padding, dilation=module.dilation, groups=module.groups)
+            elif isinstance(module, torch.nn.BatchNorm2d):
+                if module.momentum is None:
+                    exponential_average_factor = 0.0
+                else:
+                    exponential_average_factor = module.momentum
+
+                if module.training and module.track_running_stats:
+                    if module.num_batches_tracked is not None:
+                        module.num_batches_tracked += 1
+                        if module.momentum is None:  # use cumulative moving average
+                            exponential_average_factor = 1.0 / float(module.num_batches_tracked)
+                        else:  # use exponential moving average
+                            exponential_average_factor = module.momentum
+
+                ext_weight = next(param_gen)
+                ext_bias = next(param_gen)
+                method_pile.append(module.forward)
+                module.forward = partial(F.batch_norm, running_mean=module.running_mean, running_var=module.running_var,
+                                         weight=ext_weight, bias=ext_bias,
+                                         training=module.training or not module.track_running_stats,
+                                         momentum=exponential_average_factor, eps=module.eps)
+
+            elif isinstance(module, torch.nn.Linear):
+                lin_weights = next(param_gen)
+                lin_bias = next(param_gen)
+                method_pile.append(module.forward)
+                module.forward = partial(F.linear, weight=lin_weights, bias=lin_bias)
+
+            elif next(module.parameters(), None) is None:
+                # Pass over modules that do not contain parameters
+                pass
+            elif isinstance(module, torch.nn.Sequential):
+                # Pass containers
+                pass
+            else:
+                # Warn for other containers
+                if DEBUG:
+                    warnings.warn(f'Patching for module {module.__class__} is not implemented.')
+
+        output = self.net(inputs)
+
+        # Undo Patch
+        for name, module in self.net.named_modules():
+            if isinstance(module, torch.nn.modules.conv.Conv2d):
+                module.forward = method_pile.pop(0)
+            elif isinstance(module, torch.nn.BatchNorm2d):
+                module.forward = method_pile.pop(0)
+            elif isinstance(module, torch.nn.Linear):
+                module.forward = method_pile.pop(0)
+
+        return output
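
A short usage sketch of `MetaMonkey` (assuming any `net` built from `Conv2d`/`BatchNorm2d`/`Linear` modules and a matching input batch `x`); this functional forward is what `loss_steps` in `inversefed/reconstruction_algorithms.py` relies on:

```python
import torch
from inversefed.modules import MetaMonkey

patched = MetaMonkey(net)
params = patched.parameters                     # OrderedDict of the network's named parameters
out = patched(x, params)                        # forward pass routed through F.conv2d / F.batch_norm / F.linear
# gradients can now be taken w.r.t. the parameter dictionary, with a graph for higher-order terms
grads = torch.autograd.grad(out.sum(), list(params.values()), create_graph=True)
```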

+ 78 - 0
inversefed/optimization_strategy.py

@@ -0,0 +1,78 @@
+"""Optimization setups."""
+
+from dataclasses import dataclass
+
+
+def training_strategy(strategy, lr=None, epochs=None, dryrun=False):
+    """Parse training strategy."""
+    if strategy == 'conservative':
+        defs = ConservativeStrategy(lr, epochs, dryrun)
+    elif strategy == 'adam':
+        defs = AdamStrategy(lr, epochs, dryrun)
+    else:
+        raise ValueError('Unknown training strategy.')
+    return defs
+
+
+@dataclass
+class Strategy:
+    """Default usual parameters, not intended for parsing."""
+
+    epochs : int
+    batch_size : int
+    optimizer : str
+    lr : float
+    scheduler : str
+    weight_decay : float
+    validate : int
+    warmup: bool
+    dryrun : bool
+    dropout : float
+    augmentations : bool
+
+    def __init__(self, lr=None, epochs=None, dryrun=False):
+        """Defaulted parameters. Apply overwrites from args."""
+        if epochs is not None:
+            self.epochs = epochs
+        if lr is not None:
+            self.lr = lr
+        if dryrun:
+            self.dryrun = dryrun
+        self.validate = 10
+
+@dataclass
+class ConservativeStrategy(Strategy):
+    """Default usual parameters, defines a config object."""
+
+    def __init__(self, lr=None, epochs=None, dryrun=False):
+        """Initialize training hyperparameters."""
+        self.lr = 0.1
+        self.epochs = 120
+        self.batch_size = 128
+        self.optimizer = 'SGD'
+        self.scheduler = 'linear'
+        self.warmup = False
+        self.weight_decay : float = 5e-4
+        self.dropout = 0.0
+        self.augmentations = True
+        self.dryrun = False
+        super().__init__(lr=lr, epochs=epochs, dryrun=dryrun)
+
+
+@dataclass
+class AdamStrategy(Strategy):
+    """Start slowly. Use a tame Adam."""
+
+    def __init__(self, lr=None, epochs=None, dryrun=False):
+        """Initialize training hyperparameters."""
+        self.lr = 1e-3 / 10
+        self.epochs = 120
+        self.batch_size = 32
+        self.optimizer = 'AdamW'
+        self.scheduler = 'linear'
+        self.warmup = True
+        self.weight_decay : float = 5e-4
+        self.dropout = 0.0
+        self.augmentations = True
+        self.dryrun = False
+        super().__init__(lr=lr, epochs=epochs, dryrun=dryrun)

+ 405 - 0
inversefed/reconstruction_algorithms.py

@@ -0,0 +1,405 @@
+"""Mechanisms for image reconstruction from parameter gradients."""
+
+import torch
+from collections import defaultdict, OrderedDict
+from .modules import MetaMonkey
+from .metrics import total_variation as TV
+from .metrics import InceptionScore
+from .medianfilt import MedianPool2d
+from copy import deepcopy
+import time
+
+DEFAULT_CONFIG = dict(signed=True,
+                      boxed=True,
+                      cost_fn='sim',
+                      indices='topk-1',
+                      norm='none', 
+                      weights='equal',
+                      lr=0.01,
+                      optim='adam',
+                      restarts=128,
+                      max_iterations=8_000,
+                      total_variation=0,
+                      init='randn',
+                      filter='none',
+                      lr_decay=False,
+                      scoring_choice='loss')
+
+def _label_to_onehot(target, num_classes=100):
+    target = torch.unsqueeze(target, 1)
+    onehot_target = torch.zeros(target.size(0), num_classes, device=target.device)
+    onehot_target.scatter_(1, target, 1)
+    return onehot_target
+
+def _validate_config(config):
+    for key in DEFAULT_CONFIG.keys():
+        if config.get(key) is None:
+            config[key] = DEFAULT_CONFIG[key]
+    for key in config.keys():
+        if DEFAULT_CONFIG.get(key) is None:
+            raise ValueError(f'Deprecated key in config dict: {key}!')
+    return config
+
+
+class GradientReconstructor():
+    """Instantiate a reconstruction algorithm."""
+
+    def __init__(self, model, mean_std=(0.0, 1.0), config=DEFAULT_CONFIG, num_images=1):
+        """Initialize with algorithm setup."""
+        self.config = _validate_config(config)
+        self.model = model
+        self.setup = dict(device=next(model.parameters()).device, dtype=next(model.parameters()).dtype)
+
+        self.mean_std = mean_std
+        self.num_images = num_images
+
+        if self.config['scoring_choice'] == 'inception':
+            self.inception = InceptionScore(batch_size=1, setup=self.setup)
+
+        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
+#         self.iDLG = True
+#         self.grad_log = defaultdict(dict)
+
+    def reconstruct(self, input_data, labels, input_img, img_shape=(3, 32, 32), dryrun=False, eval=True):
+        """Reconstruct image from gradient."""
+        start_time = time.time()
+        if eval:
+            self.model.eval()
+            
+
+        stats = defaultdict(list)
+#         grad_log = self.grad_log
+        x = self._init_images(img_shape)
+        scores = torch.zeros(self.config['restarts'])
+        
+#         if labels is None:
+#             if self.num_images == 1 and self.iDLG:
+#                 # iDLG trick:
+#                 last_weight_min = torch.argmin(torch.sum(input_data[-2], dim=-1), dim=-1)
+#                 labels = last_weight_min.detach().reshape((1,)).requires_grad_(False)
+#                 self.reconstruct_label = False
+#             else:
+#                 # DLG label recovery
+#                 # However this also improves conditioning for some LBFGS cases
+#                 self.reconstruct_label = True
+
+#                 def loss_fn(pred, labels):
+#                     labels = torch.nn.functional.softmax(labels, dim=-1)
+#                     return torch.mean(torch.sum(- labels * torch.nn.functional.log_softmax(pred, dim=-1), 1))
+#                 self.loss_fn = loss_fn
+
+        assert labels.shape[0] == self.num_images
+        self.reconstruct_label = False
+
+        try:
+            for trial in range(self.config['restarts']):
+                x_trial, labels = self._run_trial(x[trial], input_data, labels, input_img, dryrun=dryrun)
+                # Finalize
+                scores[trial] = self._score_trial(x_trial, input_data, labels)
+                x[trial] = x_trial
+                if dryrun:
+                    break
+        except KeyboardInterrupt:
+            print('Trial procedure manually interrupted.')
+            pass
+
+        # Choose optimal result:
+        if self.config['scoring_choice'] in ['pixelmean', 'pixelmedian']:
+            x_optimal, stats = self._average_trials(x, labels, input_data, stats)
+        else:
+            print('Choosing optimal result ...')
+            scores = scores[torch.isfinite(scores)]  # guard against NaN/-Inf scores?
+            optimal_index = torch.argmin(scores)
+            print(f'Optimal result score: {scores[optimal_index]:2.4f}')
+            stats['opt'] = scores[optimal_index].item()
+            x_optimal = x[optimal_index]
+
+        print(f'Total time: {time.time()-start_time}.')
+        return x_optimal.detach(), stats
+#         return x_optimal.detach(), stats, grad_log
+
+    def _init_images(self, img_shape):
+        if self.config['init'] == 'randn':
+            return torch.randn((self.config['restarts'], self.num_images, *img_shape), **self.setup)
+        elif self.config['init'] == 'rand':
+            return (torch.rand((self.config['restarts'], self.num_images, *img_shape), **self.setup) - 0.5) * 2
+        elif self.config['init'] == 'zeros':
+            return torch.zeros((self.config['restarts'], self.num_images, *img_shape), **self.setup)
+        else:
+            raise ValueError()
+
+    def _run_trial(self, x_trial, input_data, labels, x_input, dryrun=False):  # x_input - ground truth
+        x_trial.requires_grad = True
+#         grad_log = self.grad_log
+        if self.reconstruct_label:
+            output_test = self.model(x_trial)
+            labels = torch.randn(output_test.shape[1]).to(**self.setup).requires_grad_(True)
+
+            if self.config['optim'] == 'adam':
+                optimizer = torch.optim.Adam([x_trial, labels], lr=self.config['lr'])
+            elif self.config['optim'] == 'sgd':  # actually gd
+                optimizer = torch.optim.SGD([x_trial, labels], lr=0.01, momentum=0.9, nesterov=True)
+            elif self.config['optim'] == 'LBFGS':
+                optimizer = torch.optim.LBFGS([x_trial, labels])
+            else:
+                raise ValueError()
+        else:
+            if self.config['optim'] == 'adam':
+                optimizer = torch.optim.Adam([x_trial], lr=self.config['lr'])
+            elif self.config['optim'] == 'sgd':  # actually gd
+                optimizer = torch.optim.SGD([x_trial], lr=0.01, momentum=0.9, nesterov=True)
+            elif self.config['optim'] == 'LBFGS':
+                optimizer = torch.optim.LBFGS([x_trial])
+            else:
+                raise ValueError()
+        
+        max_iterations = self.config['max_iterations']
+        dm, ds = self.mean_std
+        if self.config['lr_decay']:
+            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
+                                                             milestones=[max_iterations // 2.667, max_iterations // 1.6,
+                                                                         max_iterations // 1.142], gamma=0.1)   # 3/8 5/8 7/8
+        try:
+            for iteration in range(max_iterations):
+                closure = self._gradient_closure(optimizer, x_trial, input_data, labels)
+                rec_loss = optimizer.step(closure)
+
+                if self.config['lr_decay']:
+                    scheduler.step()
+
+                with torch.no_grad():
+                    # Project into image space
+                    if self.config['boxed']:
+                        x_trial.data = torch.max(torch.min(x_trial, (1 - dm) / ds), -dm / ds)
+                    if iteration + 1 == max_iterations:   # or iteration % 100 == 0:
+                        print(f'It: {iteration + 1}. Rec. loss: {rec_loss.item():2.4f}.')
+
+                    if (iteration + 1) % 500 == 0:
+                        if self.config['filter'] == 'none':
+                            pass
+                        elif self.config['filter'] == 'median':
+                            x_trial.data = MedianPool2d(kernel_size=3, stride=1, padding=1, same=False)(x_trial)
+                        else:
+                            raise ValueError()
+                            
+#                     grad_log['rec_loss'][iteration] = rec_loss.item()
+#                     grad_log['input_mse'][iteration] = (x_trial.detach() - x_input).pow(2).mean().item()
+                    
+                      
+
+                if dryrun:
+                    break
+        except KeyboardInterrupt:
+            print(f'Recovery interrupted manually in iteration {iteration}!')
+            pass
+        return x_trial.detach(), labels
+
+    def _gradient_closure(self, optimizer, x_trial, input_gradient, label):
+
+        def closure():
+            optimizer.zero_grad()
+            self.model.zero_grad()
+            loss = self.loss_fn(self.model(x_trial), label)
+            gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
+            rec_loss = reconstruction_costs([gradient], input_gradient, 
+                                        cost_fn=self.config['cost_fn'], indices=self.config['indices'],
+                                        norm=self.config['norm'], weights=self.config['weights'])
+
+            if self.config['total_variation'] > 0:
+                rec_loss += self.config['total_variation'] * TV(x_trial)
+            rec_loss.backward()  # second derivative --> difference between gradients
+            if self.config['signed']:
+                x_trial.grad.sign_()
+            return rec_loss
+        return closure
+
+    def _score_trial(self, x_trial, input_gradient, label):
+        if self.config['scoring_choice'] == 'loss':
+            self.model.zero_grad()
+            x_trial.grad = None
+            loss = self.loss_fn(self.model(x_trial), label)
+            gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False)
+            return reconstruction_costs([gradient], input_gradient,
+                                        cost_fn=self.config['cost_fn'], indices=self.config['indices'],
+                                        norm=self.config['norm'], weights=self.config['weights'])
+        elif self.config['scoring_choice'] == 'tv':
+            return TV(x_trial)
+        elif self.config['scoring_choice'] == 'inception':
+            # We do not care about diversity here!
+            return self.inception(x_trial)
+        elif self.config['scoring_choice'] in ['pixelmean', 'pixelmedian']:
+            return 0.0
+        else:
+            raise ValueError()
+
+    def _average_trials(self, x, labels, input_data, stats):
+        print(f'Computing a combined result via {self.config["scoring_choice"]} ...')
+        if self.config['scoring_choice'] == 'pixelmedian':
+            x_optimal, _ = x.median(dim=0, keepdims=False)
+        elif self.config['scoring_choice'] == 'pixelmean':
+            x_optimal = x.mean(dim=0, keepdims=False)
+
+        self.model.zero_grad()
+        if self.reconstruct_label:
+            labels = self.model(x_optimal).softmax(dim=1)
+        loss = self.loss_fn(self.model(x_optimal), labels)
+        gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False)
+        stats['opt'] = reconstruction_costs([gradient], input_data,
+                                            cost_fn=self.config['cost_fn'],
+                                            indices=self.config['indices'],
+                                            norm=self.config['norm'], 
+                                            weights=self.config['weights'])
+        print(f'Optimal result score: {stats["opt"]:2.4f}')
+        return x_optimal, stats
+
+
+
+class FedAvgReconstructor(GradientReconstructor):
+    """Reconstruct an image from weights after n gradient descent steps."""
+
+    def __init__(self, model, mean_std=(0.0, 1.0), local_steps=2, local_lr=1e-4,
+                 config=DEFAULT_CONFIG, num_images=1, use_updates=True):
+        """Initialize with model, (mean, std) and config."""
+        super().__init__(model, mean_std, config, num_images)
+        self.local_steps = local_steps
+        self.local_lr = local_lr
+        self.use_updates = use_updates
+
+    def _gradient_closure(self, optimizer, x_trial, input_parameters, labels):
+        def closure():
+            optimizer.zero_grad()
+            self.model.zero_grad()
+            parameters = loss_steps(self.model, x_trial, labels, loss_fn=self.loss_fn,
+                                    local_steps=self.local_steps, lr=self.local_lr, use_updates=self.use_updates)
+            rec_loss = reconstruction_costs([parameters], input_parameters,
+                                        cost_fn=self.config['cost_fn'], indices=self.config['indices'],
+                                        norm=self.config['norm'], weights=self.config['weights'])
+
+            if self.config['total_variation'] > 0:
+                rec_loss += self.config['total_variation'] * TV(x_trial)
+            rec_loss.backward()
+            if self.config['signed']:
+                x_trial.grad.sign_()
+            return rec_loss
+        return closure
+
+    def _score_trial(self, x_trial, input_parameters, labels):
+        if self.config['scoring_choice'] == 'loss':
+            self.model.zero_grad()
+            parameters = loss_steps(self.model, x_trial, labels, loss_fn=self.loss_fn,
+                                    local_steps=self.local_steps, lr=self.local_lr, use_updates=self.use_updates)
+            return reconstruction_costs([parameters], input_parameters,
+                                        cost_fn=self.config['cost_fn'], indices=self.config['indices'],
+                                        norm=self.config['norm'], weights=self.config['weights'])
+        elif self.config['scoring_choice'] == 'tv':
+            return TV(x_trial)
+        elif self.config['scoring_choice'] == 'inception':
+            # We do not care about diversity here!
+            return self.inception(x_trial)
+
+
+def loss_steps(model, inputs, labels, loss_fn=torch.nn.CrossEntropyLoss(), lr=1e-4, local_steps=4, use_updates=True, batch_size=0):
+    """Take a few gradient descent steps to fit the model to the given input."""
+    patched_model = MetaMonkey(model)
+    if use_updates:
+        patched_model_origin = deepcopy(patched_model)
+    for i in range(local_steps):
+        if batch_size == 0:
+            outputs = patched_model(inputs, patched_model.parameters)
+            labels_ = labels
+        else:
+            idx = i % (inputs.shape[0] // batch_size)
+            outputs = patched_model(inputs[idx * batch_size:(idx + 1) * batch_size], patched_model.parameters)
+            labels_ = labels[idx * batch_size:(idx + 1) * batch_size]
+        loss = loss_fn(outputs, labels_).sum()
+        grad = torch.autograd.grad(loss, patched_model.parameters.values(),
+                                   retain_graph=True, create_graph=True, only_inputs=True)
+
+        patched_model.parameters = OrderedDict((name, param - lr * grad_part)
+                                               for ((name, param), grad_part)
+                                               in zip(patched_model.parameters.items(), grad))
+
+    if use_updates:
+        patched_model.parameters = OrderedDict((name, param - param_origin)
+                                               for ((name, param), (name_origin, param_origin))
+                                               in zip(patched_model.parameters.items(), patched_model_origin.parameters.items()))
+    return list(patched_model.parameters.values())
+
+def reconstruction_costs(gradients, input_gradient, cost_fn='l2', indices='def', norm='none', weights='equal'):
+    """Input gradient is given data."""
+
+    if isinstance(indices, list):
+        pass
+    elif indices == 'def':
+        indices = torch.arange(len(input_gradient))
+    elif indices == 'batch':
+        indices = torch.randperm(len(input_gradient))[:8]
+    elif indices == 'topk-1':
+        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 4)
+    elif indices == 'top10':
+        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 10)
+    elif indices == 'top50':
+        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 50)
+    elif indices in ['first', 'first4']:
+        indices = torch.arange(0, 4)
+    elif indices == 'first5':
+        indices = torch.arange(0, 5)
+    elif indices == 'first10':
+        indices = torch.arange(0, 10)
+    elif indices == 'first50':
+        indices = torch.arange(0, 50)
+    elif indices == 'last5':
+        indices = torch.arange(len(input_gradient))[-5:]
+    elif indices == 'last10':
+        indices = torch.arange(len(input_gradient))[-10:]
+    elif indices == 'last50':
+        indices = torch.arange(len(input_gradient))[-50:]
+    # custom selection of layers by gradient norm; the number of layers is passed through the `norm` argument
+    elif indices == 'custom_bottom': 
+        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), norm, largest=False)
+    elif indices == 'custom_top': 
+        _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), norm, largest=True)
+    else:
+        raise ValueError()
+
+    ex = input_gradient[0]
+    if weights == 'linear':
+        weights = torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device) / len(input_gradient)
+    elif weights == 'exp':
+        weights = torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device)
+        weights = weights.softmax(dim=0)
+        weights = weights / weights[0]
+    else:
+        # [1.,1.,1......]
+        weights = input_gradient[0].new_ones(len(input_gradient))
+
+    total_costs = 0
+    for trial_gradient in gradients:
+        pnorm = [0, 0]
+        costs = 0
+        
+        for i in indices:
+            if cost_fn == 'l2':
+                costs += ((trial_gradient[i] - input_gradient[i]).pow(2)).sum() * weights[i]
+            elif cost_fn == 'l1':
+                costs += ((trial_gradient[i] - input_gradient[i]).abs()).sum() * weights[i]
+            elif cost_fn == 'max':
+                costs += ((trial_gradient[i] - input_gradient[i]).abs()).max() * weights[i]
+            elif cost_fn == 'sim':
+                costs -= (trial_gradient[i] * input_gradient[i]).sum() * weights[i]
+                pnorm[0] += trial_gradient[i].pow(2).sum() * weights[i]
+                pnorm[1] += input_gradient[i].pow(2).sum() * weights[i]
+                
+            elif cost_fn == 'simlocal':
+                costs += 1 - torch.nn.functional.cosine_similarity(trial_gradient[i].flatten(),
+                                                                   input_gradient[i].flatten(),
+                                                                   0, 1e-10) * weights[i]
+                
+        if cost_fn == 'sim':
+            costs = 1 + costs / pnorm[0].sqrt() / pnorm[1].sqrt()
+
+        # Accumulate final costs
+        total_costs += costs
+        
+    return total_costs / len(gradients)

+ 70 - 0
inversefed/utils.py

@@ -0,0 +1,70 @@
+"""Various utilities."""
+
+import os
+import csv
+
+import torch
+import random
+import numpy as np
+
+import socket
+import datetime
+
+
+def system_startup(args=None, defs=None):
+    """Print useful system information."""
+    # Choose GPU device and print status information:
+    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+    setup = dict(device=device, dtype=torch.float)  # non_blocking=NON_BLOCKING
+    print('Currently evaluating -------------------------------:')
+    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
+    print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()} on {socket.gethostname()}.')
+    if args is not None:
+        print(args)
+    if defs is not None:
+        print(repr(defs))
+    if torch.cuda.is_available():
+        print(f'GPU : {torch.cuda.get_device_name(device=device)}')
+    return setup
+
+def save_to_table(out_dir, name, dryrun, **kwargs):
+    """Save keys to .csv files. Function adapted from Micah."""
+    # Check for file
+    if not os.path.isdir(out_dir):
+        os.makedirs(out_dir)
+    fname = os.path.join(out_dir, f'table_{name}.csv')
+    fieldnames = list(kwargs.keys())
+
+    # Read or write header
+    try:
+        with open(fname, 'r') as f:
+            reader = csv.reader(f, delimiter='\t')
+            header = [line for line in reader][0]
+    except Exception as e:
+        print('Creating a new .csv table...')
+        with open(fname, 'w') as f:
+            writer = csv.DictWriter(f, delimiter='\t', fieldnames=fieldnames)
+            writer.writeheader()
+    if not dryrun:
+        # Add row for this experiment
+        with open(fname, 'a') as f:
+            writer = csv.DictWriter(f, delimiter='\t', fieldnames=fieldnames)
+            writer.writerow(kwargs)
+        print('\nResults saved to ' + fname + '.')
+    else:
+        print(f'Would save results to {fname}.')
+        print(f'Would save these keys: {fieldnames}.')
+
+def set_random_seed(seed=233):
+    """233 = 144 + 89 is my favorite number."""
+    torch.manual_seed(seed + 1)
+    torch.cuda.manual_seed(seed + 2)
+    torch.cuda.manual_seed_all(seed + 3)
+    np.random.seed(seed + 4)
+    torch.cuda.manual_seed_all(seed + 5)
+    random.seed(seed + 6)
+
+def set_deterministic():
+    """Switch pytorch into a deterministic computation mode."""
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False

+ 127 - 0
main.py

@@ -0,0 +1,127 @@
+import argparse 
+import os 
+import numpy as np 
+import torch 
+import torch.nn as nn 
+from torch.autograd import grad
+import torchvision 
+from torchvision import datasets, transforms 
+from torchvision.utils import save_image 
+import torchvision.models as models 
+import inversefed 
+from utils.dataloader import DataLoader
+from utils.stackeddata import StackedData
+
+# inverting gradients algorithm from https://github.com/JonasGeiping/invertinggradients
+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+    
+parser = argparse.ArgumentParser(description='Adversarial attack from gradient leakage')
+parser.add_argument('--model', type=str, help='model to perform adversarial attack')
+parser.add_argument('--data', type=str, help='dataset used')
+parser.add_argument('--stack_size', default=4, type=int, help='size use to stack images')
+parser.add_argument('-l', '--target_idx', nargs='+', help='list of data indices to reconstruct')
+parser.add_argument('--save', type=str2bool, nargs='?', const=False, default=True, help='save')
+parser.add_argument('--gpu', type=str2bool, nargs='?', const=False, default=True, help='use gpu')
+
+
+args = parser.parse_args()
+model_name = args.model
+data = args.data
+stack_size = args.stack_size
+save_output = args.save 
+if args.target_idx is not None: 
+    target_idx = [int(i) for i in args.target_idx]
+else: 
+    target_idx = args.target_idx
+
+device = 'cpu'
+if args.gpu: 
+    device = 'cuda'
+print("Running on %s" % device)
+
+
+def val_model(dataset, model, criterion):
+    # evaluate trained model, record wrongly predicted index
+    model.eval() 
+    # record wrong pred index
+    index_ls = [] 
+    with torch.no_grad(): 
+        val_loss, val_corrects = 0, 0 
+        for batch_idx, (inputs, labels) in enumerate(dataset): 
+            inputs = inputs.unsqueeze(dim=0).to(device)
+            labels = torch.as_tensor([labels]).to(device)
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+            _, preds = torch.max(outputs, 1)
+            val_loss += loss.item() * inputs.size(0) # multiply by the batch size
+            val_corrects += torch.sum(preds == labels.data)
+            if (preds != labels.data): 
+                index_ls.append(batch_idx)
+            if batch_idx == 100:
+                break
+
+        num_evaluated = batch_idx + 1  # the loop stops early, so average over the samples actually seen
+        total_loss = val_loss / num_evaluated
+        total_acc = val_corrects.double() / num_evaluated
+        print('{} Loss: {:.4f} Acc: {:.4f}'.format('val', total_loss, total_acc))
+        return index_ls
+
+
+dataloader = DataLoader(data, device)
+dataset, data_shape, classes, (dm, ds) = dataloader.get_data_info() 
+model = models.resnet18(pretrained=True) # use pretrained model from torchvision
+model.fc = nn.Linear(512, len(classes)) # reinitialize model output: https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
+model = model.to(device)
+model.eval()
+criterion = nn.CrossEntropyLoss() 
+
+stack_data = StackedData(stack_size=stack_size, model_name=model_name, dataset_name=data, dataset=dataset, save_output=save_output, device=device)
+
+if target_idx is None:
+    wrong_pred_idx = val_model(dataset, model, criterion)
+else:
+    if not isinstance(target_idx, list):
+        wrong_pred_idx = [target_idx]
+    else: 
+        wrong_pred_idx = target_idx
+    
+
+stacked_data_d = stack_data.create_stacked_data(wrong_pred_idx)
+for i in range(len(stacked_data_d['gt_img'])): 
+    gt_img, gt_label, img_idx = stacked_data_d['gt_img'][i], stacked_data_d['gt_label'][i], stacked_data_d['img_index'][i]
+    stack_pred = model(gt_img)
+    target_loss = criterion(stack_pred, gt_label)
+    input_grad = grad(target_loss, model.parameters())
+    input_grad = [g.detach() for g in input_grad]
+    # default configuration from inversefed
+    config = dict(signed=True,
+              boxed=False,
+              cost_fn='sim',
+              indices='def',
+              norm='none',
+              weights='equal',
+              lr=0.1, 
+              optim='adam',
+              restarts=1,
+              max_iterations=200,
+              total_variation=0.1,
+              init='randn',
+              filter='none',
+              lr_decay=True,
+              scoring_choice='loss')
+    
+    rec_machine = inversefed.GradientReconstructor(model, (dm, ds), config, num_images=gt_img.shape[0])
+    results = rec_machine.reconstruct(input_grad, gt_label, gt_img, img_shape=data_shape)
+    output_img, stats = results
+    rec_pred = model(output_img)
+    print('Predictions for reconstructed images: ', [classes[l] for l in torch.max(rec_pred, axis=1)[1]])
+    stack_data.grid_plot(img_idx, output_img, rec_pred, dm, ds)

BIN
output_rec_images/resnet18_cifar10_4_770_bird_0.png


BIN
output_rec_images/resnet18_cifar10_4_770_bird_1.png


BIN
output_rec_images/resnet18_cifar10_4_770_bird_2.png


BIN
output_rec_images/resnet18_cifar10_4_770_bird_3.png


+ 0 - 0
utils/__init__.py


+ 48 - 0
utils/dataloader.py

@@ -0,0 +1,48 @@
+from inversefed import consts
+import torch
+from torchvision import datasets, transforms 
+
+class DataLoader: 
+    def __init__(self, data, device): 
+        self.data = data 
+        self.device = device
+        
+    def get_mean_std(self): 
+        if self.data == 'cifar10': 
+            mean, std = consts.cifar10_mean, consts.cifar10_std 
+        elif self.data ==  'cifar100': 
+            mean, std = consts.cifar100_mean, consts.cifar100_std 
+        elif self.data == 'mnist': 
+            mean, std = consts.mnist_mean, consts.mnist_std 
+        elif self.data == 'imagenet':
+            mean, std = consts.imagenet_mean, consts.imagenet_std 
+        else: 
+            raise Exception("dataset not found")
+        return mean, std
+
+    def get_data_info(self):
+        mean, std = self.get_mean_std()
+        transform = transforms.Compose([transforms.ToTensor(),
+                                        transforms.Normalize(mean, std)])
+
+        dm = torch.as_tensor(mean)[:, None, None].to(self.device)
+        ds = torch.as_tensor(std)[:, None, None].to(self.device)
+        data_root = 'data/cifar_data'
+#         data_root = '~/.torch'
+        if self.data == 'cifar10': 
+            dataset = datasets.CIFAR10(root=data_root, download=True, train=False, transform=transform)
+        elif self.data ==  'cifar100': 
+            dataset = datasets.CIFAR100(root=data_root, download=True, train=False, transform=transform)
+        elif self.data == 'mnist': 
+            dataset = datasets.MNIST(root=data_root, download=True, train=False, transform=transform)
+        elif self.data == 'imagenet':
+            # torchvision's ImageNet cannot be downloaded automatically and takes `split` instead of `train`
+            dataset = datasets.ImageNet(root=data_root, split='val', transform=transform)
+        else: 
+            raise Exception("dataset not found, load your own datasets")
+
+        data_shape = dataset[0][0].shape 
+        classes = dataset.classes 
+
+        return dataset, data_shape, classes, (dm, ds)
+    
+

+ 75 - 0
utils/stackeddata.py

@@ -0,0 +1,75 @@
+import os 
+import numpy as np 
+import matplotlib.pyplot as plt 
+import json
+import torch
+from inversefed import consts
+from torchvision import datasets, transforms 
+from torchvision.utils import save_image
+
+class StackedData: 
+    def __init__(self, stack_size, model_name, dataset_name, dataset, save_output, device): 
+        self.stack_size = stack_size
+        self.model_name = model_name 
+        self.dataset_name = dataset_name 
+        self.dataset = dataset
+        self.save_output = save_output 
+        self.device = device 
+        
+    def create_stacked_data(self, index_ls): 
+        batch_data = {'gt_img': [], 'gt_label': [], 'img_index': []} 
+        for index in index_ls: 
+            gt_img, gt_label = self.dataset[index]
+            gt_images, gt_labels = [], [] 
+            for i in range(self.stack_size): 
+                gt_images.append(gt_img) 
+                gt_labels.append(torch.as_tensor((gt_label,), device=self.device))
+
+            gt_images_ = torch.stack(gt_images).to(self.device)
+            gt_labels_ = torch.cat(gt_labels)
+            batch_data['gt_img'].append(gt_images_)
+            batch_data['gt_label'].append(gt_labels_)
+            batch_data['img_index'].append(index)
+
+        return batch_data
+
+    def grid_plot(self, img_idx, tensors, logit, dm, ds):
+        _, indices = torch.max(logit, 1)
+        confidence, _ = torch.max(torch.softmax(logit, dim=1).cpu(), dim=1)  # softmax probability of the predicted class
+        labels = list(zip(indices.cpu().numpy(), list(np.around(confidence.detach().numpy(), 4))))
+        
+        # un-normalize before plotting 
+        tensors = tensors.clone().detach()
+        tensors.mul_(ds).add_(dm).clamp_(0, 1)
+
+        if self.save_output: 
+            if not os.path.exists('output_rec_images'):
+                os.makedirs('output_rec_images')
+
+            if self.model_name is None: 
+                saved_name = '{}_{}_{}'.format(self.dataset_name, self.stack_size, img_idx)
+            else: 
+                saved_name = '{}_{}_{}_{}'.format(self.model_name, self.dataset_name, self.stack_size, img_idx)
+
+            for i, tensor in enumerate(tensors): 
+                extension = '.png'
+                saved_name_ = "{}_{}_{}{}".format(saved_name, self.dataset.classes[indices[i]], i, extension) 
+                save_image(tensor, os.path.join('output_rec_images', saved_name_))
+
+        if tensors.shape[0]==1: 
+            tensors = tensors[0]
+            plt.figure(figsize=(4,4))
+            plt.imshow(tensors.permute(1,2,0).cpu())
+            plt.title(labels[0])  # (predicted class index, confidence)
+
+        else: 
+            grid_width = int(np.ceil(len(labels)**0.5))
+            grid_height = int(np.ceil(len(labels) / grid_width))
+
+            fig, axes = plt.subplots(grid_height, grid_width, figsize=(3, 3))
+            for im, l, ax in zip(tensors, labels, axes.flatten()):
+                ax.imshow(im.permute(1, 2, 0).cpu());
+                ax.set_title(l)
+                ax.axis('off')
+        plt.show() 
+    

Some files were not shown because too many files changed in this diff