1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- import os
- from torchvision import transforms
- from easyfl.datasets import FederatedImageDataset
# Database names used as federated clients by default. Each name is joined onto
# ``data_dir`` by prepare_train_data / prepare_test_data to locate its images.
DB_NAMES = ["MSMT17", "Duke", "Market", "cuhk03", "prid", "cuhk01", "viper", "3dpes", "ilids"]

# Target (height, width) that every image is resized to.
IMAGE_SIZE = (256, 128)

# Per-channel normalization statistics (the commonly used ImageNet mean/std),
# shared by the train and validation pipelines so they cannot drift apart.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: bicubic resize, then pad by 10 px and randomly crop back to
# IMAGE_SIZE (a random-translation augmentation), random horizontal flip, and
# finally tensor conversion + normalization.
# NOTE: the original passed the magic int ``interpolation=3``; torchvision maps
# 3 to BICUBIC and has deprecated int values in favour of the enum used here.
TRANSFORM_TRAIN_LIST = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.Pad(10),
    transforms.RandomCrop(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation pipeline: deterministic bicubic resize to IMAGE_SIZE, then tensor
# conversion + the same normalization as training (no augmentation).
TRANSFORM_VAL_LIST = transforms.Compose([
    transforms.Resize(size=IMAGE_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
def prepare_train_data(data_dir, db_names=None):
    """Create the federated training dataset with one client per database.

    The client for database ``db`` has id ``db`` and reads its images from
    ``<data_dir>/<db>/pytorch/train_all``.

    Args:
        data_dir: Directory holding one sub-directory per database.
        db_names: Iterable of database names to use; ``None`` selects ``DB_NAMES``.

    Returns:
        A ``FederatedImageDataset`` constructed with ``simulated=True``,
        ``do_simulate=False`` and ``TRANSFORM_TRAIN_LIST``.
    """
    # Materialise the names once into a fresh list: one-shot iterables still feed
    # both the roots and the client ids, and the global DB_NAMES is never aliased.
    selected = list(DB_NAMES if db_names is None else db_names)
    train_roots = [os.path.join(data_dir, name, 'pytorch', 'train_all') for name in selected]
    return FederatedImageDataset(
        root=train_roots,
        client_ids=selected,
        transform=TRANSFORM_TRAIN_LIST,
        simulated=True,
        do_simulate=False,
    )
def prepare_test_data(data_dir, db_names=None):
    """Create the federated evaluation dataset with two clients per database.

    For each database ``db`` (in order), this adds a gallery client and then a
    query client:

    * ``"<db>_gallery"`` reading from ``<data_dir>/<db>/pytorch/gallery``
    * ``"<db>_query"`` reading from ``<data_dir>/<db>/pytorch/query``

    Args:
        data_dir: Directory holding one sub-directory per database.
        db_names: Iterable of database names to use; ``None`` selects ``DB_NAMES``.

    Returns:
        A ``FederatedImageDataset`` constructed with ``simulated=True``,
        ``do_simulate=False`` and ``TRANSFORM_VAL_LIST``.
    """
    selected = DB_NAMES if db_names is None else db_names
    eval_roots, eval_ids = [], []
    for name in selected:
        base = os.path.join(data_dir, name, 'pytorch')
        # Gallery must precede query for each database; roots and ids stay aligned.
        for split in ("gallery", "query"):
            eval_roots.append(os.path.join(base, split))
            eval_ids.append("{}_{}".format(name, split))
    return FederatedImageDataset(
        root=eval_roots,
        client_ids=eval_ids,
        transform=TRANSFORM_VAL_LIST,
        simulated=True,
        do_simulate=False,
    )
|