import os
import os.path
import random
import warnings

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image, ImageOps, ImageFile

from easyfl.datasets import FederatedTorchDataset

# Allow PIL to load image files that are truncated on disk instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

DEFAULT_TASKS = ['depth_zbuffer', 'normal', 'segment_semantic', 'edge_occlusion',
                 'reshading', 'keypoints2d', 'edge_texture']

# model_limit values for the validation/test splits: an int keeps the first N
# images per building; a (start, stop) tuple keeps that slice of the sorted list.
VAL_LIMIT = 100
TEST_LIMIT = (1000, 2000)


def get_dataset(data_dir, train_client_file, test_client_file, tasks, image_size,
                model_limit=None, half_sized_output=False, augment=False):
    dataset = {}  # each building in the taskonomy dataset is a client
    client_ids = set()
    with open(train_client_file) as f:
        for line in f:
            client_id = line.strip()
            client_ids.add(client_id)
            dataset[client_id] = TaskonomyLoader(data_dir,
                                                 label_set=tasks,
                                                 model_whitelist=[client_id],
                                                 model_limit=model_limit,
                                                 output_size=(image_size, image_size),
                                                 half_sized_output=half_sized_output,
                                                 augment=augment)
            print(f'Client {client_id}: {len(dataset[client_id])} instances.')
    train_set = FederatedTorchDataset(dataset, client_ids)

    if augment == "aggressive":
        print('Data augmentation is on (aggressive).')
    elif augment:
        print('Data augmentation is on (flip).')
    else:
        print('No data augmentation.')

    test_client_ids = set()
    with open(test_client_file) as f:
        for line in f:
            test_client_ids.add(line.strip())
    val_set = get_validation_data(data_dir, test_client_ids, tasks, image_size, VAL_LIMIT, half_sized_output)
    test_set = get_validation_data(data_dir, test_client_ids, tasks, image_size, TEST_LIMIT, half_sized_output)

    return train_set, val_set, test_set


def get_validation_data(data_dir, client_ids, tasks, image_size, model_limit, half_sized_output=False):
    dataset = TaskonomyLoader(data_dir,
                              label_set=tasks,
                              model_whitelist=client_ids,
                              model_limit=model_limit,
                              output_size=(image_size, image_size),
                              half_sized_output=half_sized_output,
                              augment=False)
    if model_limit == VAL_LIMIT:
        print(f'Found {len(dataset)} validation instances.')
    else:
        print(f'Found {len(dataset)} test instances.')
    return FederatedTorchDataset(dataset, client_ids)
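
# A minimal usage sketch (assuming a taskonomy-style data directory and split
# files with one building id per line; all paths below are hypothetical):
#
#   train_set, val_set, test_set = get_dataset(
#       data_dir='/data/taskonomy',
#       train_client_file='splits/train_buildings.txt',
#       test_client_file='splits/test_buildings.txt',
#       tasks=['depth_zbuffer', 'normal'],
#       image_size=256,
#       augment=True)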


class TaskonomyLoader(data.Dataset):
    def __init__(self,
                 root,
                 label_set=DEFAULT_TASKS,
                 model_whitelist=None,
                 model_limit=None,
                 output_size=None,
                 convert_to_tensor=True,
                 return_filename=False,
                 half_sized_output=False,
                 augment=False):
        self.root = root
        self.model_limit = model_limit
        self.records = []
        # model_whitelist may be a file of building ids (one per line) or any
        # collection of ids; normalize it to a set (or None for "no filter").
        if model_whitelist is None:
            self.model_whitelist = None
        elif isinstance(model_whitelist, str):
            self.model_whitelist = set()
            with open(model_whitelist) as f:
                for line in f:
                    self.model_whitelist.add(line.strip())
        else:
            self.model_whitelist = model_whitelist
        # Index the rgb images of every whitelisted building (model), applying
        # the per-building model_limit (int prefix or (start, stop) slice).
        for where, subdirs, files in os.walk(os.path.join(root, 'rgb')):
            if subdirs:
                continue
            model = os.path.basename(where)
            if self.model_whitelist is None or model in self.model_whitelist:
                full_paths = [os.path.join(where, f) for f in files]
                if isinstance(model_limit, tuple):
                    full_paths.sort()
                    full_paths = full_paths[model_limit[0]:model_limit[1]]
                elif model_limit is not None:
                    full_paths.sort()
                    full_paths = full_paths[:model_limit]
                self.records += full_paths
        # self.records = manager.list(self.records)
        self.label_set = label_set
        self.output_size = output_size
        self.half_sized_output = half_sized_output
        self.convert_to_tensor = convert_to_tensor
        self.return_filename = return_filename
        self.to_tensor = transforms.ToTensor()
        self.augment = augment
        self.last = {}  # per-task cache of the last successfully loaded label

    def process_image(self, im, input=False):
        output_size = self.output_size
        # Labels can optionally be produced at half the input resolution.
        if self.half_sized_output and not input:
            if output_size is None:
                output_size = (128, 128)
            else:
                output_size = output_size[0] // 2, output_size[1] // 2
        if output_size is not None and output_size != im.size:
            im = im.resize(output_size, Image.BILINEAR)
        bands = im.getbands()
        if self.convert_to_tensor:
            if bands[0] == 'L':
                # Single-channel ('L') label images keep their raw integer
                # values; ToTensor would rescale them to [0, 1].
                im = np.array(im)
                im.setflags(write=1)
                im = torch.from_numpy(im).unsqueeze(0)
            else:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    im = self.to_tensor(im)
        return im

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is a uint8 matrix of integers
            with the same width and height. If there is an error loading an
            image or its labels, simply return the previous example.
        """
        with torch.no_grad():
            file_name = self.records[index]
            save_filename = file_name

            # With augmentation on, flip left-right half the time; "aggressive"
            # augmentation also flips up-down half the time.
            flip_lr = bool(self.augment) and random.random() < 0.5
            flip_ud = self.augment == "aggressive" and random.random() < 0.5

            pil_im = Image.open(file_name)
            if flip_lr:
                pil_im = ImageOps.mirror(pil_im)
            if flip_ud:
                pil_im = ImageOps.flip(pil_im)

            im = self.process_image(pil_im, input=True)

            error = False
            ys = {}
            mask = None
            # Copy the task list so the per-sample additions below do not
            # mutate self.label_set across calls.
            to_load = list(self.label_set)
            # Tasks defined only on valid-depth pixels need a validity mask:
            # use a precomputed mask image if one exists, else derive the mask
            # from depth_zbuffer.
            if {'edge_occlusion', 'normal', 'reshading', 'principal_curvature'}.intersection(self.label_set):
                if os.path.isfile(file_name.replace('rgb', 'mask')):
                    to_load.append('mask')
                elif 'depth_zbuffer' not in to_load:
                    to_load.append('depth_zbuffer')
            for i in to_load:
                if i == 'mask' and mask is not None:
                    continue
                yfilename = file_name.replace('rgb', i)
                try:
                    yim = Image.open(yfilename)
                except Exception:
                    # Fall back to the last good label seen for this task.
                    yim = self.last[i].copy()
                    error = True
                if (i in self.last and yim.getbands() != self.last[i].getbands()) or error:
                    yim = self.last[i].copy()
                try:
                    self.last[i] = yim.copy()
                except Exception:
                    pass
                if flip_lr:
                    try:
                        yim = ImageOps.mirror(yim)
                    except Exception:
                        pass
                if flip_ud:
                    try:
                        yim = ImageOps.flip(yim)
                    except Exception:
                        pass
                try:
                    yim = self.process_image(yim)
                except Exception:
                    yim = self.last[i].copy()
                    yim = self.process_image(yim)
                # Per-task normalization with precomputed dataset statistics
                # (see test() below, which estimates these means and stds).
                if i == 'depth_zbuffer':
                    yim = yim.float()
                    mask = yim < (2 ** 13)
                    yim -= 1500.0
                    yim /= 1000.0
                elif i == 'edge_occlusion':
                    yim = yim.float()
                    yim -= 56.0248
                    yim /= 239.1265
                elif i == 'keypoints2d':
                    yim = yim.float()
                    yim -= 50.0
                    yim /= 100.0
                elif i == 'edge_texture':
                    yim = yim.float()
                    yim -= 718.0
                    yim /= 1070.0
                elif i == 'normal':
                    yim = yim.float()
                    yim -= .5
                    yim *= 2.0
                    if flip_lr:
                        yim[0] *= -1.0
                    if flip_ud:
                        yim[1] *= -1.0
                elif i == 'reshading':
                    yim = yim.mean(dim=0, keepdim=True)
                    yim -= .4962
                    yim /= 0.2846
                elif i == 'principal_curvature':
                    yim = yim[:2]
                    yim -= torch.tensor([0.5175, 0.4987]).view(2, 1, 1)
                    yim /= torch.tensor([0.1373, 0.0359]).view(2, 1, 1)
                elif i == 'mask':
                    mask = yim.bool()
                    yim = mask
                ys[i] = yim

            if mask is not None:
                ys['mask'] = mask
            if 'rgb' not in self.label_set:
                ys['rgb'] = im

            if self.return_filename:
                return im, ys, file_name
            else:
                return im, ys

    def __len__(self):
        return len(self.records)
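
# A quick sanity-check sketch of the __getitem__ contract (the root path is
# hypothetical):
#
#   loader = TaskonomyLoader('/data/taskonomy', label_set=['depth_zbuffer', 'normal'])
#   im, ys = loader[0]    # im: 3xHxW rgb tensor
#   print(sorted(ys))     # ['depth_zbuffer', 'mask', 'normal', 'rgb']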


class DataPrefetcher:
    """Overlaps host-to-device copies with compute using a side CUDA stream."""

    def __init__(self, loader, device):
        self.initial_loader = loader
        self.device = device
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            # The epoch is exhausted; restart from a fresh iterator.
            self.loader = iter(self.initial_loader)
            self.preload()
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.to(self.device, non_blocking=True)
            self.next_target = {key: val.to(self.device, non_blocking=True)
                                for (key, val) in self.next_target.items()}

    def next(self):
        # Make sure the async copies issued in preload() have finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        self.preload()
        return input, target

    def update_device(self, device):
        self.device = device
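
# A minimal prefetching-loop sketch (assuming a CUDA device and a DataLoader
# built from one of the datasets above; batch size, worker count, and step
# count are illustrative):
#
#   loader = data.DataLoader(train_dataset, batch_size=32, num_workers=4,
#                            pin_memory=True, shuffle=True)
#   prefetcher = DataPrefetcher(loader, torch.device('cuda:0'))
#   for _ in range(steps_per_epoch):
#       im, ys = prefetcher.next()   # batch is already on the device
#       ...                          # forward/backward pass goes here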


def show(im, ys):
    from matplotlib import pyplot as plt
    plt.figure(figsize=(30, 30))
    plt.subplot(4, 3, 1).set_title('RGB')
    im = im.permute([1, 2, 0])
    plt.imshow(im)
    for i, y in enumerate(ys):
        yim = ys[y]
        plt.subplot(4, 3, 2 + i).set_title(y)
        if y == 'normal':
            yim += 1
            yim /= 2
        if yim.shape[0] == 2:
            yim = torch.cat([yim, torch.zeros((1, yim.shape[1], yim.shape[2]))], dim=0)
        yim = yim.permute([1, 2, 0])
        yim = yim.squeeze()
        plt.imshow(np.array(yim))
    plt.show()


def test():
    loader = TaskonomyLoader(
        '/home/tstand/Desktop/lite_taskonomy/',
        label_set=['normal', 'reshading', 'principal_curvature', 'edge_occlusion', 'depth_zbuffer'],
        augment='aggressive')

    totals = {}
    totals2 = {}
    count = {}
    indices = list(range(len(loader)))
    random.shuffle(indices)
    for data_count, index in enumerate(indices):
        im, ys = loader[index]
        show(im, ys)
        mask = ys['mask']
        # mask = ~mask
        print(index)
        for y in ys:
            yim = ys[y]
            yim = yim.float()
            if y not in totals:
                totals[y] = 0
                totals2[y] = 0
                count[y] = 0
            # Accumulate masked sums and squared sums to track running
            # per-channel mean and std estimates.
            totals[y] += (yim * mask).sum(dim=[1, 2])
            totals2[y] += ((yim ** 2) * mask).sum(dim=[1, 2])
            count[y] += (torch.ones_like(yim) * mask).sum(dim=[1, 2])
            std = torch.sqrt((totals2[y] - (totals[y] ** 2) / count[y]) / count[y])
            print(data_count, '/', len(loader), y, 'mean:', totals[y] / count[y], 'std:', std)


def output_mask(index, loader):
    filename = loader.records[index]
    filename = filename.replace('rgb', 'mask')
    filename = filename.replace('/intel_nvme/taskonomy_data/', '/run/shm/')
    if os.path.isfile(filename):
        return
    print(filename)
    x, ys = loader[index]
    mask = ys['mask']
    mask = mask.squeeze()
    # Save the validity mask as a 1-bit image next to the other task labels.
    mask_im = Image.fromarray(mask.numpy())
    mask_im = mask_im.convert(mode='1')
    path, _ = os.path.split(filename)
    os.makedirs(path, exist_ok=True)
    mask_im.save(filename, bits=1, optimize=True)


def get_masks():
    loader = TaskonomyLoader(
        '/intel_nvme/taskonomy_data/',
        label_set=['depth_zbuffer'],
        augment=False)
    indices = list(range(len(loader)))
    random.shuffle(indices)
    for count, index in enumerate(indices):
        print(count, len(indices))
        output_mask(index, loader)


if __name__ == "__main__":
    # Minimal smoke test: load one rgb image, mirror it, and resize it.
    file_name = "/Users/weiming/personal-projects/taskonomy_dataset/rgb/cosmos/point_512_view_7_domain_rgb.png"
    pil_im = Image.open(file_name)
    pil_im = ImageOps.mirror(pil_im)
    output_size = (128, 128)
    pil_im = pil_im.resize(output_size, Image.BILINEAR)
    print(pil_im)
    print("Completed")