# image.py (3.7 KB)
import os

import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder

from federatedml.nn.dataset.base import Dataset
  6. class ImageDataset(Dataset):
  7. """
  8. A basic Image Dataset built on pytorch ImageFolder, supports simple image transform
  9. Given a folder path, ImageDataset will load images from this folder, images in this
  10. folder need to be organized in a Torch-ImageFolder format, see
  11. https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html for details.
  12. Image name will be automatically taken as the sample id.
  13. Parameters
  14. ----------
  15. center_crop : bool, use center crop transformer
  16. center_crop_shape: tuple or list
  17. generate_id_from_file_name: bool, whether to take image name as sample id
  18. file_suffix: str, default is '.jpg', if generate_id_from_file_name is True, will remove this suffix from file name,
  19. result will be the sample id
  20. return_label: bool, return label or not, this option is for host dataset, when running hetero-NN
  21. float64: bool, returned image tensors will be transformed to double precision
  22. label_dtype: str, long, float, or double, the dtype of return label
  23. """
  24. def __init__(self, center_crop=False, center_crop_shape=None,
  25. generate_id_from_file_name=True, file_suffix='.jpg',
  26. return_label=True, float64=False, label_dtype='long'):
  27. super(ImageDataset, self).__init__()
  28. self.image_folder: ImageFolder = None
  29. self.center_crop = center_crop
  30. self.size = center_crop_shape
  31. self.return_label = return_label
  32. self.generate_id_from_file_name = generate_id_from_file_name
  33. self.file_suffix = file_suffix
  34. self.float64 = float64
  35. self.dtype = torch.float32 if not self.float64 else torch.float64
  36. avail_label_type = ['float', 'long', 'double']
  37. self.sample_ids = None
  38. assert label_dtype in avail_label_type, 'available label dtype : {}'.format(
  39. avail_label_type)
  40. if label_dtype == 'double':
  41. self.label_dtype = torch.float64
  42. elif label_dtype == 'long':
  43. self.label_dtype = torch.int64
  44. else:
  45. self.label_dtype = torch.float32
  46. def load(self, folder_path):
  47. # read image from folders
  48. if self.center_crop:
  49. transformer = transforms.Compose(
  50. [transforms.CenterCrop(size=self.size), transforms.ToTensor()])
  51. else:
  52. transformer = transforms.Compose([transforms.ToTensor()])
  53. if folder_path.endswith('/'):
  54. folder_path = folder_path[: -1]
  55. image_folder_path = folder_path
  56. folder = ImageFolder(root=image_folder_path, transform=transformer)
  57. self.image_folder = folder
  58. if self.generate_id_from_file_name:
  59. # use image name as its sample id
  60. file_name = self.image_folder.imgs
  61. ids = []
  62. for name in file_name:
  63. sample_id = name[0].split(
  64. '/')[-1].replace(self.file_suffix, '')
  65. ids.append(sample_id)
  66. self.sample_ids = ids
  67. def __getitem__(self, item):
  68. if self.return_label:
  69. item = self.image_folder[item]
  70. return item[0].type(
  71. self.dtype), torch.tensor(
  72. item[1]).type(
  73. self.label_dtype)
  74. else:
  75. return self.image_folder[item][0].type(self.dtype)
  76. def __len__(self):
  77. return len(self.image_folder)
  78. def __repr__(self):
  79. return self.image_folder.__repr__()
  80. def get_classes(self):
  81. return np.unique(self.image_folder.targets).tolist()
  82. def get_sample_ids(self):
  83. return self.sample_ids
  84. if __name__ == '__main__':
  85. pass