import os

from torchvision import transforms

from easyfl.datasets import FederatedImageDataset

# Default dataset names. Each name becomes its own federated client, with its
# data read from ``<data_dir>/<name>/pytorch/...``.
DB_NAMES = ["MSMT17", "Duke", "Market", "cuhk03", "prid", "cuhk01", "viper", "3dpes", "ilids"]

# Training pipeline: resize to (256, 128), pad 10 px, random-crop back to
# (256, 128), random horizontal flip, convert to tensor, then normalize with
# the ImageNet channel mean/std. ``interpolation=3`` is the integer code for
# bicubic resampling (PIL.Image.BICUBIC).
TRANSFORM_TRAIN_LIST = transforms.Compose([
    transforms.Resize((256, 128), interpolation=3),
    transforms.Pad(10),
    transforms.RandomCrop((256, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Evaluation pipeline: deterministic resize to (256, 128), convert to tensor,
# and apply the same normalization as training (no augmentation).
TRANSFORM_VAL_LIST = transforms.Compose([
    transforms.Resize(size=(256, 128), interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def _as_db_list(db_names):
    """Normalize a ``db_names`` argument to a list of dataset names.

    ``None`` selects ``DB_NAMES``. A single name passed as a plain ``str`` is
    wrapped in a list so it is not iterated character by character.
    """
    if db_names is None:
        return list(DB_NAMES)
    if isinstance(db_names, str):
        return [db_names]
    return list(db_names)


def prepare_train_data(data_dir, db_names=None, *, transform=None):
    """Build the federated training dataset with one client per dataset.

    Args:
        data_dir: Root directory. Training images for dataset ``db`` are read
            from ``<data_dir>/<db>/pytorch/train_all``.
        db_names: Iterable of dataset names, or a single name as ``str``.
            Defaults to ``DB_NAMES``.
        transform: Transform applied to each training image. Defaults to
            ``TRANSFORM_TRAIN_LIST``.

    Returns:
        A ``FederatedImageDataset`` whose client ids are the dataset names, in
        the same order as ``db_names``.
    """
    db_names = _as_db_list(db_names)
    if transform is None:
        transform = TRANSFORM_TRAIN_LIST

    client_ids = list(db_names)
    roots = [os.path.join(data_dir, db, "pytorch", "train_all") for db in db_names]

    # NOTE(review): simulated=True with do_simulate=False presumably tells
    # easyfl that the data is already split per client (one root per client
    # id), so no further simulated partitioning happens - confirm in easyfl.
    data = FederatedImageDataset(root=roots,
                                 simulated=True,
                                 do_simulate=False,
                                 transform=transform,
                                 client_ids=client_ids)
    return data


def prepare_test_data(data_dir, db_names=None, *, transform=None):
    """Build the federated test dataset with gallery and query splits.

    For every dataset ``db`` two clients are created, in this order:
    ``"<db>_gallery"`` reading ``<data_dir>/<db>/pytorch/gallery`` and
    ``"<db>_query"`` reading ``<data_dir>/<db>/pytorch/query``.

    Args:
        data_dir: Root directory containing one sub-directory per dataset.
        db_names: Iterable of dataset names, or a single name as ``str``.
            Defaults to ``DB_NAMES``.
        transform: Transform applied to each evaluation image. Defaults to
            ``TRANSFORM_VAL_LIST``.

    Returns:
        A ``FederatedImageDataset`` with ``2 * len(db_names)`` clients.
    """
    db_names = _as_db_list(db_names)
    if transform is None:
        transform = TRANSFORM_VAL_LIST

    roots = []
    client_ids = []
    for db in db_names:
        base = os.path.join(data_dir, db, "pytorch")
        # Gallery must precede query for each dataset to keep client order stable.
        for split in ("gallery", "query"):
            roots.append(os.path.join(base, split))
            client_ids.append(f"{db}_{split}")

    data = FederatedImageDataset(root=roots,
                                 simulated=True,
                                 do_simulate=False,
                                 transform=transform,
                                 client_ids=client_ids)
    return data