dataset.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import os
  2. from torchvision import transforms
  3. from easyfl.datasets import FederatedImageDataset
  4. DB_NAMES = ["MSMT17", "Duke", "Market", "cuhk03", "prid", "cuhk01", "viper", "3dpes", "ilids"]
  5. TRANSFORM_TRAIN_LIST = transforms.Compose([
  6. transforms.Resize((256, 128), interpolation=3),
  7. transforms.Pad(10),
  8. transforms.RandomCrop((256, 128)),
  9. transforms.RandomHorizontalFlip(),
  10. transforms.ToTensor(),
  11. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  12. ])
  13. TRANSFORM_VAL_LIST = transforms.Compose([
  14. transforms.Resize(size=(256, 128), interpolation=3),
  15. transforms.ToTensor(),
  16. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  17. ])
  18. def prepare_train_data(data_dir, db_names=None):
  19. if db_names is None:
  20. db_names = DB_NAMES
  21. client_ids = []
  22. roots = []
  23. for db in db_names:
  24. client_ids.append(db)
  25. data_path = os.path.join(data_dir, db, 'pytorch')
  26. roots.append(os.path.join(data_path, 'train_all'))
  27. data = FederatedImageDataset(root=roots,
  28. simulated=True,
  29. do_simulate=False,
  30. transform=TRANSFORM_TRAIN_LIST,
  31. client_ids=client_ids)
  32. return data
  33. def prepare_test_data(data_dir, db_names=None):
  34. if db_names is None:
  35. db_names = DB_NAMES
  36. roots = []
  37. client_ids = []
  38. for db in db_names:
  39. test_gallery = os.path.join(data_dir, db, 'pytorch', 'gallery')
  40. test_query = os.path.join(data_dir, db, 'pytorch', 'query')
  41. roots.extend([test_gallery, test_query])
  42. client_ids.extend(["{}_{}".format(db, "gallery"), "{}_{}".format(db, "query")])
  43. data = FederatedImageDataset(root=roots,
  44. simulated=True,
  45. do_simulate=False,
  46. transform=TRANSFORM_VAL_LIST,
  47. client_ids=client_ids)
  48. return data