stackeddata.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import os
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. import json
  5. import torch
  6. from inversefed import consts
  7. from torchvision import datasets, transforms
  8. from torchvision.utils import save_image
  9. class StackedData:
  10. def __init__(self, stack_size, model_name, dataset_name, dataset, save_output, device):
  11. self.stack_size = stack_size
  12. self.model_name = model_name
  13. self.dataset_name = dataset_name
  14. self.dataset = dataset
  15. self.save_output = save_output
  16. self.device = device
  17. def create_stacked_data(self, index_ls):
  18. batch_data = {'gt_img': [], 'gt_label': [], 'img_index': []}
  19. for index in index_ls:
  20. gt_img, gt_label = self.dataset[index]
  21. gt_images, gt_labels = [], []
  22. for i in range(self.stack_size):
  23. gt_images.append(gt_img)
  24. gt_labels.append(torch.as_tensor((gt_label,), device=self.device))
  25. gt_images_ = torch.stack(gt_images).to(self.device)
  26. gt_labels_ = torch.cat(gt_labels)
  27. batch_data['gt_img'].append(gt_images_)
  28. batch_data['gt_label'].append(gt_labels_)
  29. batch_data['img_index'].append(index)
  30. return batch_data
  31. def grid_plot(self, img_idx, tensors, logit, dm, ds):
  32. _, indices = torch.max(logit, 1)
  33. accuracy, _ = torch.max(torch.softmax(logit, dim=1).cpu(), dim=1)
  34. labels = list(zip(indices.cpu().numpy(), list(np.around(accuracy.detach().numpy(),4))))
  35. # un-normalize before plotting
  36. tensors = tensors.clone().detach()
  37. tensors.mul_(ds).add_(dm).clamp_(0, 1)
  38. if self.save_output:
  39. if os.path.exists('output_rec_images')==False:
  40. os.makedirs('output_rec_images')
  41. if self.model_name is None:
  42. saved_name = '{}_{}_{}'.format(self.dataset_name, self.stack_size, img_idx)
  43. else:
  44. saved_name = '{}_{}_{}_{}'.format(self.model_name, self.dataset_name, self.stack_size, img_idx)
  45. for i, tensor in enumerate(tensors):
  46. extension = '.png'
  47. saved_name_ = "{}_{}_{}{}".format(saved_name, self.dataset.classes[indices[i]], i, extension)
  48. save_image(tensor, os.path.join('output_rec_images', saved_name_))
  49. if tensors.shape[0]==1:
  50. tensors = tensors[0]
  51. plt.figure(figsize=(4,4))
  52. plt.imshow(tensors.permute(1,2,0).cpu())
  53. plt.title(self.dataset[labels])
  54. else:
  55. grid_width = int(np.ceil(len(labels)**0.5))
  56. grid_height = int(np.ceil(len(labels) / grid_width))
  57. fig, axes = plt.subplots(grid_height, grid_width, figsize=(3, 3))
  58. for im, l, ax in zip(tensors, labels, axes.flatten()):
  59. ax.imshow(im.permute(1, 2, 0).cpu());
  60. ax.set_title(l)
  61. ax.axis('off')
  62. plt.show()