123456789101112131415161718192021222324252627282930313233343536373839404142434445 |
- from PIL import Image
- from torch.utils.data import Dataset
class ImageDataset(Dataset):
    """Map-style dataset that loads each image with PIL when it is indexed.

    Only the image sources are stored; pixels are read in ``__getitem__``.

    Args:
        images: Indexable collection of values accepted by ``PIL.Image.open``
            (presumably file paths — confirm against callers).
        labels: Indexable collection of labels, parallel to ``images``.
        transform_x: Optional callable applied to the loaded PIL image.
        transform_y: Optional callable applied to the label.
    """

    def __init__(self, images, labels, transform_x=None, transform_y=None):
        self.images = images
        self.labels = labels
        self.transform_x = transform_x
        self.transform_y = transform_y

    def __len__(self):
        # Length comes from ``labels``; ``images`` is assumed to be the same
        # length (this is not checked).
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index`` with any transforms applied."""
        source, label = self.images[index], self.labels[index]
        # Image.open is lazy: it keeps the underlying file open until the
        # pixel data is read or the image is closed. Left unclosed, every
        # sample leaks a file descriptor until garbage collection. Read the
        # pixels eagerly inside the context manager so the file is released
        # before the image is handed to the transform / caller.
        with Image.open(source) as image:
            image.load()
        if self.transform_x is not None:
            image = self.transform_x(image)
        if self.transform_y is not None:
            label = self.transform_y(label)
        return image, label
class TransformDataset(Dataset):
    """Map-style dataset over indexable samples with optional transforms.

    Args:
        images: Indexable collection of samples, stored as ``self.data``.
        labels: Indexable collection of targets parallel to ``images``,
            stored as ``self.targets``.
        transform_x: Optional callable applied to each sample.
        transform_y: Optional callable applied to each target.
    """

    def __init__(self, images, labels, transform_x=None, transform_y=None):
        self.data = images
        self.targets = labels
        self.transform_x = transform_x
        self.transform_y = transform_y

    def __len__(self):
        # Length comes from ``data``; ``targets`` is assumed to match.
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, target)`` for ``idx`` with any transforms applied."""
        sample = self.data[idx]
        target = self.targets[idx]
        # Test against None (consistent with ImageDataset) instead of
        # truthiness: a callable transform that defines __len__ or __bool__
        # and evaluates falsy must still be applied, not silently skipped.
        if self.transform_x is not None:
            sample = self.transform_x(sample)
        if self.transform_y is not None:
            target = self.transform_y(target)
        return sample, target
|