dataset_util.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. from PIL import Image
  2. from torch.utils.data import Dataset
  3. class ImageDataset(Dataset):
  4. def __init__(self, images, labels, transform_x=None, transform_y=None):
  5. self.images = images
  6. self.labels = labels
  7. self.transform_x = transform_x
  8. self.transform_y = transform_y
  9. def __len__(self):
  10. return len(self.labels)
  11. def __getitem__(self, index):
  12. data, label = self.images[index], self.labels[index]
  13. if self.transform_x is not None:
  14. data = self.transform_x(Image.open(data))
  15. else:
  16. data = Image.open(data)
  17. if self.transform_y is not None:
  18. label = self.transform_y(label)
  19. return data, label
  20. class TransformDataset(Dataset):
  21. def __init__(self, images, labels, transform_x=None, transform_y=None):
  22. self.data = images
  23. self.targets = labels
  24. self.transform_x = transform_x
  25. self.transform_y = transform_y
  26. def __len__(self):
  27. return len(self.data)
  28. def __getitem__(self, idx):
  29. sample = self.data[idx]
  30. target = self.targets[idx]
  31. if self.transform_x:
  32. sample = self.transform_x(sample)
  33. if self.transform_y:
  34. target = self.transform_y(target)
  35. return sample, target