|
- import os
- import os.path
- import random
- import warnings
- import numpy as np
- import torch
- import torch.utils.data as data
- import torchvision.transforms as transforms
- from PIL import Image, ImageOps, ImageFile
- ImageFile.LOAD_TRUNCATED_IMAGES = True
- from easyfl.datasets import FederatedTorchDataset
# Task names that TaskonomyLoader loads when no ``label_set`` is given.
DEFAULT_TASKS = [
    'depth_zbuffer',
    'normal',
    'segment_semantic',
    'edge_occlusion',
    'reshading',
    'keypoints2d',
    'edge_texture',
]
# Validation split: keep the first VAL_LIMIT sorted images of every evaluation model.
VAL_LIMIT = 100
# Test split: keep the (start, stop) slice of the sorted images of every evaluation model.
TEST_LIMIT = (1000, 2000)
def _read_client_ids(path):
    """Return the distinct, whitespace-stripped, non-blank lines of ``path`` in file order."""
    with open(path) as f:
        # Blank lines (e.g. a trailing empty line) would otherwise become clients with no data.
        return list(dict.fromkeys(s for s in (line.strip() for line in f) if s))


def get_dataset(data_dir, train_client_file, test_client_file, tasks, image_size, model_limit=None,
                half_sized_output=False, augment=False):
    """Build the federated train, validation and test sets.

    Every id listed in ``train_client_file`` becomes one training client backed by its own
    ``TaskonomyLoader`` restricted to that model. The ids in ``test_client_file`` are pooled
    into a validation set (``model_limit=VAL_LIMIT``) and a test set (``model_limit=TEST_LIMIT``).

    Args:
        data_dir: Dataset root passed to ``TaskonomyLoader``.
        train_client_file: Text file with one training client id per line.
        test_client_file: Text file with one evaluation client id per line.
        tasks: Task names to load, passed as ``label_set``.
        image_size: Side length of the square ``output_size``.
        model_limit: Optional per-model limit for the training clients.
        half_sized_output: Produce labels at half resolution.
        augment: ``False``, truthy (random left/right flips) or ``"aggressive"``
            (additionally random up/down flips).

    Returns:
        ``(train_set, val_set, test_set)``, each a ``FederatedTorchDataset``.
    """
    train_ids = _read_client_ids(train_client_file)
    dataset = {}
    for client_id in train_ids:
        dataset[client_id] = TaskonomyLoader(data_dir,
                                             label_set=tasks,
                                             model_whitelist=[client_id],
                                             model_limit=model_limit,
                                             output_size=(image_size, image_size),
                                             half_sized_output=half_sized_output,
                                             augment=augment)
        print(f'Client {client_id}: {len(dataset[client_id])} instances.')
    train_set = FederatedTorchDataset(dataset, set(train_ids))
    if augment == "aggressive":
        print('Data augmentation is on (aggressive).')
    elif augment:
        print('Data augmentation is on (flip).')
    else:
        print('no data augmentation')
    test_client_ids = set(_read_client_ids(test_client_file))
    val_set = get_validation_data(data_dir, test_client_ids, tasks, image_size, VAL_LIMIT, half_sized_output)
    test_set = get_validation_data(data_dir, test_client_ids, tasks, image_size, TEST_LIMIT, half_sized_output)
    return train_set, val_set, test_set
def get_validation_data(data_dir, client_ids, tasks, image_size, model_limit, half_sized_output=False):
    """Pool the images of all ``client_ids`` into one evaluation dataset.

    ``model_limit`` is forwarded to ``TaskonomyLoader`` (int cap or ``(start, stop)`` slice per
    model). The log line calls the split "validation" when ``model_limit`` equals ``VAL_LIMIT``
    and "test" otherwise. Augmentation is always off.

    Returns:
        A ``FederatedTorchDataset`` wrapping the single pooled loader.
    """
    loader = TaskonomyLoader(
        data_dir,
        label_set=tasks,
        model_whitelist=client_ids,
        model_limit=model_limit,
        output_size=(image_size, image_size),
        half_sized_output=half_sized_output,
        augment=False,
    )
    if model_limit != VAL_LIMIT:
        print(f'Found {len(loader)} test instances.')
    else:
        print(f'Found {len(loader)} validation instances.')
    return FederatedTorchDataset(loader, client_ids)
class TaskonomyLoader(data.Dataset):
    """Taskonomy multi-task dataset yielding an RGB input and a dict of label tensors.

    Records are the files found in the leaf directories of ``<root>/rgb``; the name of the
    leaf directory is the model name used for whitelisting. The label of task ``t`` for a
    record is read from the record path with every ``'rgb'`` replaced by ``t``.
    """

    # Tasks for which a validity mask is also loaded (see ``_labels_to_load``).
    _MASKED_TASKS = frozenset(['edge_occlusion', 'normal', 'reshading', 'principal_curvature'])

    def __init__(self,
                 root,
                 label_set=DEFAULT_TASKS,
                 model_whitelist=None,
                 model_limit=None,
                 output_size=None,
                 convert_to_tensor=True,
                 return_filename=False,
                 half_sized_output=False,
                 augment=False):
        """Index the RGB files under ``root``.

        Args:
            root: Dataset root directory containing an ``rgb`` sub-tree.
            label_set: Task names to load for every record. It is only read, never modified.
            model_whitelist: ``None`` to keep every model, a ``str`` path to a file with one
                model name per line, or any container of model names supporting ``in``.
            model_limit: ``None`` (keep all files), an int ``n`` (first ``n`` files of each
                model after sorting) or a ``(start, stop)`` tuple (that slice of the sorted files).
            output_size: Size passed to ``PIL.Image.resize``; ``None`` keeps the original size.
            convert_to_tensor: If true, ``process_image`` converts images to tensors.
            return_filename: If true, ``__getitem__`` also returns the RGB file path.
            half_sized_output: If true, labels are resized to half of ``output_size``
                (``(128, 128)`` when ``output_size`` is ``None``).
            augment: ``False`` for none, any truthy value for random left/right flips,
                ``"aggressive"`` to also apply random up/down flips.
        """
        self.root = root
        self.model_limit = model_limit
        self.records = []
        if model_whitelist is None:
            self.model_whitelist = None
        elif isinstance(model_whitelist, str):
            with open(model_whitelist) as f:
                self.model_whitelist = {line.strip() for line in f}
        else:
            self.model_whitelist = model_whitelist

        for where, subdirs, files in os.walk(os.path.join(root, 'rgb')):
            if subdirs:
                continue  # only leaf directories are indexed
            model = os.path.basename(where)
            if self.model_whitelist is not None and model not in self.model_whitelist:
                continue
            full_paths = [os.path.join(where, f) for f in files]
            if model_limit is not None:
                # Sort first so the selected subset does not depend on directory listing order.
                full_paths.sort()
                if isinstance(model_limit, tuple):
                    full_paths = full_paths[model_limit[0]:model_limit[1]]
                else:
                    full_paths = full_paths[:model_limit]
            self.records += full_paths

        self.label_set = label_set
        self.output_size = output_size
        self.half_sized_output = half_sized_output
        self.convert_to_tensor = convert_to_tensor
        self.return_filename = return_filename
        self.to_tensor = transforms.ToTensor()
        self.augment = augment
        # Last successfully read PIL image per task; used as a fallback when a label fails to load.
        self.last = {}

    def process_image(self, im, input=False):
        """Resize a PIL image and optionally convert it to a tensor.

        Labels (``input=False``) are shrunk to half size when ``half_sized_output`` is set;
        the RGB input (``input=True``) always uses ``output_size``.

        Returns:
            The (possibly resized) PIL image when ``convert_to_tensor`` is false. Otherwise a
            tensor: for images whose first band is ``'L'`` the raw pixel values with a leading
            channel axis added, else the output of ``transforms.ToTensor``.
        """
        output_size = self.output_size
        if self.half_sized_output and not input:
            if output_size is None:
                output_size = (128, 128)
            else:
                output_size = output_size[0] // 2, output_size[1] // 2
        if output_size is not None and output_size != im.size:
            im = im.resize(output_size, Image.BILINEAR)
        bands = im.getbands()
        if self.convert_to_tensor:
            if bands[0] == 'L':
                # Keep raw pixel values (ToTensor would rescale them); add a channel axis.
                im = np.array(im)
                im.setflags(write=1)
                im = torch.from_numpy(im).unsqueeze(0)
            else:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    im = self.to_tensor(im)
        return im

    def _labels_to_load(self, file_name):
        """Return the task names to read for ``file_name`` as a new list.

        ``self.label_set`` is copied, not extended in place: appending to it would grow the
        caller's list (or ``DEFAULT_TASKS``) on every call and leak a ``'mask'`` entry into
        later records that have no mask file.
        """
        to_load = list(self.label_set)
        if self._MASKED_TASKS.intersection(to_load):
            if os.path.isfile(file_name.replace('rgb', 'mask')):
                to_load.append('mask')
            elif 'depth_zbuffer' not in to_load:
                # No mask file: load depth instead, from which a mask is derived below.
                to_load.append('depth_zbuffer')
        return to_load

    def __getitem__(self, index):
        """Load record ``index`` together with its labels.

        Args:
            index (int): Index into ``self.records``.

        Returns:
            ``(im, ys)``, or ``(im, ys, file_name)`` when ``return_filename`` is set. ``im`` is
            the processed RGB input. ``ys`` maps each loaded task name to its normalized label
            tensor, plus ``'mask'`` when a mask was loaded or derived from ``depth_zbuffer``,
            plus ``'rgb'`` (the input itself) when ``'rgb'`` is not in ``label_set``.

        If a label image cannot be opened, that label and every later label of this record
        are taken from ``self.last`` (the last image read for the same task); a ``KeyError``
        is raised if nothing is cached for that task yet.
        """
        with torch.no_grad():
            file_name = self.records[index]
            # Flips are drawn once per record and applied to the input and to every label.
            flip_lr = (random.randint(0, 1) > .5 and self.augment)
            flip_ud = (random.randint(0, 1) > .5 and (self.augment == "aggressive"))
            pil_im = Image.open(file_name)
            if flip_lr:
                pil_im = ImageOps.mirror(pil_im)
            if flip_ud:
                pil_im = ImageOps.flip(pil_im)
            im = self.process_image(pil_im, input=True)

            error = False  # sticky: once set, remaining labels come from the cache
            ys = {}
            mask = None
            for i in self._labels_to_load(file_name):
                if i == 'mask' and mask is not None:
                    continue  # a mask was already obtained for this record
                yfilename = file_name.replace('rgb', i)
                try:
                    yim = Image.open(yfilename)
                except Exception:
                    yim = self.last[i].copy()
                    error = True
                # Fall back to the cached image after an error or when the band layout differs.
                if (i in self.last and yim.getbands() != self.last[i].getbands()) or error:
                    yim = self.last[i].copy()
                try:
                    self.last[i] = yim.copy()
                except Exception:
                    pass  # best effort: keep the previous cache entry
                if flip_lr:
                    try:
                        yim = ImageOps.mirror(yim)
                    except Exception:
                        pass
                if flip_ud:
                    try:
                        yim = ImageOps.flip(yim)
                    except Exception:
                        pass
                try:
                    yim = self.process_image(yim)
                except Exception:
                    yim = self.last[i].copy()
                    yim = self.process_image(yim)

                # Per-task shift/scale constants; presumably dataset statistics (cf. ``test()``).
                if i == 'depth_zbuffer':
                    yim = yim.float()
                    mask = yim < (2 ** 13)  # depth values >= 2**13 are excluded from the mask
                    yim -= 1500.0
                    yim /= 1000.0
                elif i == 'edge_occlusion':
                    yim = yim.float()
                    yim -= 56.0248
                    yim /= 239.1265
                elif i == 'keypoints2d':
                    yim = yim.float()
                    yim -= 50.0
                    yim /= 100.0
                elif i == 'edge_texture':
                    yim = yim.float()
                    yim -= 718.0
                    yim /= 1070.0
                elif i == 'normal':
                    yim = yim.float()
                    yim -= .5
                    yim *= 2.0
                    # Negate channel 0 / 1 to match a horizontal / vertical image flip.
                    if flip_lr:
                        yim[0] *= -1.0
                    if flip_ud:
                        yim[1] *= -1.0
                elif i == 'reshading':
                    yim = yim.mean(dim=0, keepdim=True)
                    yim -= .4962
                    yim /= 0.2846
                elif i == 'principal_curvature':
                    yim = yim[:2]
                    yim -= torch.tensor([0.5175, 0.4987]).view(2, 1, 1)
                    yim /= torch.tensor([0.1373, 0.0359]).view(2, 1, 1)
                elif i == 'mask':
                    mask = yim.bool()
                    yim = mask
                ys[i] = yim

            if mask is not None:
                ys['mask'] = mask
            if 'rgb' not in self.label_set:
                ys['rgb'] = im
            if self.return_filename:
                return im, ys, file_name
            return im, ys

    def __len__(self):
        """Number of indexed RGB records."""
        return len(self.records)
class DataPrefetcher:
    """Cycle over a loader endlessly, moving the next batch to ``device`` ahead of time.

    With CUDA available the copy is issued on a side stream so it can overlap with work on
    the current stream; without CUDA the batch is moved synchronously. When the loader is
    exhausted it is restarted, so ``next()`` never raises ``StopIteration``. Each batch must
    be an ``(input_tensor, {name: tensor})`` pair.
    """

    def __init__(self, loader, device):
        self.inital_loader = loader  # (sic) attribute name kept for backward compatibility
        self.device = device
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.preload()

    def preload(self):
        """Fetch the next batch and start moving it to ``self.device``.

        Raises:
            ValueError: if the loader yields no batch even right after being restarted.
        """
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            # End of an epoch: restart once. An empty loader fails fast instead of recursing forever.
            self.loader = iter(self.inital_loader)
            try:
                self.next_input, self.next_target = next(self.loader)
            except StopIteration:
                raise ValueError('DataPrefetcher: loader produced no batches') from None
        if self.stream is None:
            self._move_next_to_device()
        else:
            with torch.cuda.stream(self.stream):
                self._move_next_to_device()

    def _move_next_to_device(self):
        """Move the prefetched input and every target tensor to ``self.device``."""
        self.next_input = self.next_input.to(self.device, non_blocking=True)
        self.next_target = {key: val.to(self.device, non_blocking=True)
                            for key, val in self.next_target.items()}

    def next(self):
        """Return the prefetched ``(input, target)`` pair and begin prefetching the next one."""
        if self.stream is not None:
            torch.cuda.current_stream().wait_stream(self.stream)
        batch_input = self.next_input
        batch_target = self.next_target
        if self.stream is not None:
            # These tensors were allocated on the side stream; record their use on the current
            # stream so the caching allocator does not hand their memory out again too early.
            current = torch.cuda.current_stream()
            for tensor in (batch_input, *batch_target.values()):
                if tensor.is_cuda:
                    tensor.record_stream(current)
        self.preload()
        return batch_input, batch_target

    def update_device(self, device):
        """Use ``device`` for batches prefetched from now on (the current one is already moved)."""
        self.device = device
def show(im, ys):
    """Plot the RGB input and every label of ``ys`` in a 4x3 matplotlib grid (debug helper).

    Args:
        im: Channels-first image tensor (permuted to HWC for ``imshow``).
        ys: Mapping of label name -> tensor, as returned by ``TaskonomyLoader``.

    The tensors in ``ys`` are not modified, so callers can keep using them afterwards.
    """
    from matplotlib import pyplot as plt
    plt.figure(figsize=(30, 30))
    plt.subplot(4, 3, 1).set_title('RGB')
    plt.imshow(im.permute([1, 2, 0]))
    for i, (name, yim) in enumerate(ys.items()):
        plt.subplot(4, 3, 2 + i).set_title(name)
        if name == 'normal':
            # Out of place on purpose: in-place ops would rescale the caller's tensor.
            yim = (yim + 1) / 2
        if yim.shape[0] == 2:
            # Pad two-channel labels with a zero channel so they can be shown as RGB.
            yim = torch.cat([yim, torch.zeros((1, yim.shape[1], yim.shape[2]))], dim=0)
        yim = yim.permute([1, 2, 0])
        yim = yim.squeeze()
        plt.imshow(np.array(yim))
    plt.show()
def test():
    """Debug routine: show random samples and print running masked mean/std per label channel.

    Sums are accumulated only over pixels where ``ys['mask']`` is true, and the running
    statistics are printed after every label of every sample.
    """
    loader = TaskonomyLoader(
        '/home/tstand/Desktop/lite_taskonomy/',
        label_set=['normal', 'reshading', 'principal_curvature', 'edge_occlusion', 'depth_zbuffer'],
        augment='aggressive')
    sums, sq_sums, counts = {}, {}, {}
    order = list(range(len(loader)))
    random.shuffle(order)
    for step, index in enumerate(order):
        im, ys = loader[index]
        show(im, ys)
        valid = ys['mask']
        print(index)
        for name in ys:
            values = ys[name].float()
            if name not in sums:
                sums[name] = 0
                sq_sums[name] = 0
                counts[name] = 0
            sums[name] += (values * valid).sum(dim=[1, 2])
            sq_sums[name] += ((values ** 2) * valid).sum(dim=[1, 2])
            counts[name] += (torch.ones_like(values) * valid).sum(dim=[1, 2])
            # Population std from the running sums: sqrt((E[x^2]*n - (sum x)^2/n) / n).
            std = torch.sqrt((sq_sums[name] - (sums[name] ** 2) / counts[name]) / counts[name])
            print(step, '/', len(loader), name, 'mean:', sums[name] / counts[name], 'std:', std)
def output_mask(index, loader):
    """Save the mask of ``loader[index]`` as a 1-bit PNG unless the target file already exists.

    The target path is the record path with ``'rgb'`` replaced by ``'mask'`` and the
    ``/intel_nvme/taskonomy_data/`` prefix redirected to ``/run/shm/``. Missing parent
    directories are created.
    """
    target = loader.records[index].replace('rgb', 'mask').replace('/intel_nvme/taskonomy_data/', '/run/shm/')
    if os.path.isfile(target):
        return
    print(target)
    _, ys = loader[index]
    mask_im = Image.fromarray(ys['mask'].squeeze().numpy()).convert(mode='1')
    os.makedirs(os.path.dirname(target), exist_ok=True)
    mask_im.save(target, bits=1, optimize=True)
def get_masks():
    """Write a mask file for every record of the local dataset copy, visiting records in random order."""
    loader = TaskonomyLoader('/intel_nvme/taskonomy_data/', label_set=['depth_zbuffer'], augment=False)
    order = list(range(len(loader)))
    random.shuffle(order)
    total = len(order)
    for done, index in enumerate(order):
        print(done, total)
        output_mask(index, loader)
if __name__ == "__main__":
    # Smoke check: open one RGB frame, mirror it, downsample it to 128x128 and print it.
    sample_path = "/Users/weiming/personal-projects/taskonomy_dataset/rgb/cosmos/point_512_view_7_domain_rgb.png"
    mirrored = ImageOps.mirror(Image.open(sample_path))
    resized = mirrored.resize((128, 128), Image.BILINEAR)
    print(resized)
    print("Completed")
|