dataset.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435
  1. import os
  2. from reid.utils.transform.transforms import TRANSFORM_TRAIN_LIST, TRANSFORM_VAL_LIST
  3. from easyfl.datasets import FederatedImageDataset
  4. def prepare_train_data(db_names, data_dir):
  5. client_ids = []
  6. roots = []
  7. for d in db_names:
  8. client_ids.append(d)
  9. data_path = os.path.join(data_dir, d, 'pytorch')
  10. roots.append(os.path.join(data_path, 'train_all'))
  11. data = FederatedImageDataset(root=roots,
  12. simulated=True,
  13. do_simulate=False,
  14. transform=TRANSFORM_TRAIN_LIST,
  15. client_ids=client_ids)
  16. return data
  17. def prepare_test_data(db_names, data_dir):
  18. roots = []
  19. client_ids = []
  20. for d in db_names:
  21. test_gallery = os.path.join(data_dir, d, 'pytorch', 'gallery')
  22. test_query = os.path.join(data_dir, d, 'pytorch', 'query')
  23. roots.extend([test_gallery, test_query])
  24. client_ids.extend(["{}_{}".format(d, "gallery"), "{}_{}".format(d, "query")])
  25. data = FederatedImageDataset(root=roots,
  26. simulated=True,
  27. do_simulate=False,
  28. transform=TRANSFORM_VAL_LIST,
  29. client_ids=client_ids)
  30. return data