1234567891011121314151617181920212223242526272829303132333435 |
- import os
- from reid.utils.transform.transforms import TRANSFORM_TRAIN_LIST, TRANSFORM_VAL_LIST
- from easyfl.datasets import FederatedImageDataset
def prepare_train_data(db_names, data_dir):
    """Build the federated training dataset for the given dataset names.

    Each name in ``db_names`` is used as a client id, and its images are
    read from ``<data_dir>/<name>/pytorch/train_all``.

    Args:
        db_names: Iterable of dataset directory names found under
            ``data_dir``.
        data_dir: Base directory that contains one sub-directory per dataset.

    Returns:
        A ``FederatedImageDataset`` constructed over the per-dataset
        ``train_all`` roots with ``TRANSFORM_TRAIN_LIST`` as its transform.
    """
    # Materialise once so a one-shot iterable is consumed a single time;
    # the names double as the client ids, in the same order as the roots.
    client_ids = list(db_names)
    roots = [
        os.path.join(data_dir, name, 'pytorch', 'train_all')
        for name in client_ids
    ]
    # NOTE(review): simulated=True / do_simulate=False presumably mark the
    # data as already partitioned per client -- confirm against easyfl docs.
    return FederatedImageDataset(
        root=roots,
        simulated=True,
        do_simulate=False,
        transform=TRANSFORM_TRAIN_LIST,
        client_ids=client_ids,
    )
def prepare_test_data(db_names, data_dir):
    """Build the federated test dataset for the given dataset names.

    For every name in ``db_names`` two entries are produced, in this order:
    ``<data_dir>/<name>/pytorch/gallery`` with client id ``<name>_gallery``,
    then ``<data_dir>/<name>/pytorch/query`` with client id ``<name>_query``.

    Args:
        db_names: Iterable of dataset directory names found under
            ``data_dir``.
        data_dir: Base directory that contains one sub-directory per dataset.

    Returns:
        A ``FederatedImageDataset`` constructed over the gallery/query roots
        with ``TRANSFORM_VAL_LIST`` as its transform.
    """
    roots = []
    client_ids = []
    for name in db_names:
        split_base = os.path.join(data_dir, name, 'pytorch')
        # Gallery must precede query so roots and client ids stay aligned
        # in the same interleaved order for every dataset.
        for split in ('gallery', 'query'):
            roots.append(os.path.join(split_base, split))
            client_ids.append(f"{name}_{split}")
    # NOTE(review): simulated=True / do_simulate=False presumably mark the
    # data as already partitioned per client -- confirm against easyfl docs.
    return FederatedImageDataset(
        root=roots,
        simulated=True,
        do_simulate=False,
        transform=TRANSFORM_VAL_LIST,
        client_ids=client_ids,
    )
|