浏览代码

feat: add implementation of iccv2023

Weiming 1 年之前
父节点
当前提交
c17c736a64

+ 1 - 0
README.md

@@ -59,6 +59,7 @@ For more advanced usage, we provide a list of tutorials on:
 
 We have released the source code for the following papers under the `applications` folder:
 
+- Federated Multiple Task Learning: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/mas) for [MAS: Towards Resource-Efficient Federated Multiple-Task Learning](https://arxiv.org/abs/2307.11285) (_ICCV'2023_)
 - FedSSL: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/fedssl) for two papers: [Divergence-aware Federated Self-Supervised Learning](https://openreview.net/forum?id=oVE1z8NlNe) (_ICLR'2022_)  and [Collaborative Unsupervised Visual Representation Learning From Decentralized Data](https://openaccess.thecvf.com/content/ICCV2021/html/Zhuang_Collaborative_Unsupervised_Visual_Representation_Learning_From_Decentralized_Data_ICCV_2021_paper.html) (_ICCV'2021_)
 - FedReID: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/fedreid) for two papers: [Performance Optimization for Federated Person Re-identification via Benchmark Analysis](https://dl.acm.org/doi/10.1145/3394171.3413814) (_ACMMM'2020_) and [Optimizing Performance of Federated Person Re-identification: Benchmarking and Analysis](https://dl.acm.org/doi/10.1145/3531013) (_TOMM_)
 - FedUReID: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/fedureid) for [Joint Optimization in Edge-Cloud Continuum for Federated Unsupervised Person Re-identification](https://arxiv.org/abs/2108.06493) (_ACMMM'2021_)

+ 131 - 0
applications/mas/README.md

@@ -0,0 +1,131 @@
+# MAS: Towards Resource-Efficient Federated Multiple-Task Learning
+
+
+This repository is the official release of the paper accepted to ICCV 2023 - **[MAS: Towards Resource-Efficient Federated Multiple-Task Learning](https://arxiv.org/abs/2307.11285)**
+
+In this work, we propose the first federated learning (FL) system to effectively coordinate and train multiple simultaneous FL tasks. We first formalize the problem of training simultaneous FL tasks. Then, we present our new approach, MAS (Merge and Split), to optimize the performance of training multiple simultaneous FL tasks. MAS starts by merging FL tasks into an all-in-one FL task with a multi-task architecture. After training for a few rounds, MAS splits the all-in-one FL task into two or more FL tasks by using the affinities among tasks measured during the all-in-one training. It then continues training each split of FL tasks based on model parameters from the all-in-one training. Extensive experiments demonstrate that MAS outperforms other methods while reducing training time by 2x and reducing energy consumption by 40%. We hope this work will inspire the community to further study and optimize training simultaneous FL tasks.
+
+
+| <img src="images/mas.png" width="700"> | 
+|:--:| 
+| *The architecture and workflow of our proposed <u>M</u>erge <u>a</u>nd <u>S</u>plit (MAS)* |
+
+## Prerequisite
+
+It requires the following Python libraries:
+```
+torch
+torchvision
+easyfl
+```
+
+Please refer to the [documentation](https://easyfl.readthedocs.io/en/latest/get_started.html#installation) to install `easyfl`.
+
+Note that some other libraries should be installed to start training.
+
+We referenced some implementations from:
+* https://github.com/tstandley/taskgrouping
+* https://github.com/google-research/google-research/tree/master/tag
+
+
+## Datasets
+
+We use [Taskonomy](http://taskonomy.stanford.edu/) dataset, a large and challenging computer vision dataset of indoor scenes of buildings, for all experiments in this paper. 
+
+We run experiments with 32 clients, where each client contains a dataset of one building to simulate the statistical heterogeneity. More specifically, we use the `tiny` subset of Taskonomy dataset and select 32 buildings as in [clients.txt](clients.txt). The nine tasks used in training are depth_zbuffer, edge_occlusion, edge_texture, keypoints2d, normal, principal_curvature, reshading, rgb (autoencoder), and segment_semantic. You can reference [taskonomy-tiny-data.txt](taskonomy-tiny-data.txt) or their [official repository](http://taskonomy.stanford.edu/) to download the datasets. 
+
+The following is the structure of datasets. There are 32 buildings (clients) under each task.
+```
+`-- taskonomy_datasets
+    |-- depth_zbuffer
+    |-- edge_occlusion
+    |-- edge_texture
+    |-- keypoints2d
+    |-- normal
+    |-- principal_curvature
+    |-- reshading
+    |-- rgb
+    |-- segment_semantic
+    |   |-- allensville
+    |   |-- beechwood
+    |   |-- benevolence
+    |   |-- coffeen
+    |   |-- collierville
+    |   |-- corozal
+    |   |-- cosmos
+    |   |-- darden
+    |   |-- forkland
+    |   |-- hanson
+    |   |-- hiteman
+    |   |-- ihlen
+    |   |-- klickitat
+    |   |-- lakeville
+    |   |-- leonardo
+    |   |-- lindenwood
+    |   |-- markleeville
+    |   |-- marstons
+    |   |-- mcdade
+    |   |-- merom
+    |   |-- mifflinburg
+    |   |-- muleshoe
+    |   |-- newfields
+    |   |-- noxapater
+    |   |-- onaga
+    |   |-- pinesdale
+    |   |-- pomaria
+    |   |-- ranchester
+    |   |-- shelbyville
+    |   |-- stockman
+    |   |-- tolstoy
+    |   `-- uvalda
+```
+
+We also provide several python scripts under [scripts folder](scripts/) to preprocess and sanity check the datasets after download.
+
+
+## Run the experiments
+
+We provide the following examples to run with 5 tasks `sdnkt`.
+
+> Note: the script uses `srun` from slurm server. Please adapt it to run it with Python directly. 
+
+1. **Merge**: train multiple tasks in FL, gather affinity scores among these tasks.
+
+```
+bash run.sh --tasks sdnkt --rounds 70
+```
+
+2. **Split**: get the splits of the task set based on affinity scores.
+```
+python split.py --filename <filename of training log> --split 2 --rounds 10 -p -a
+```
+
+
+3. Train each split.
+
+For example, if the task set {sdnkt} is splitted into {sd,nkt}, we can then further train each split to measure their performances.
+
+```
+bash run.sh --tasks sd --rounds 30 --pretrained y --pretrained_tasks sdnkt
+bash run.sh --tasks nkt --rounds 30 --pretrained y --pretrained_tasks sdnkt
+```
+
+You can refer to the `main.py` to run experiments with more options and configurations.
+
+> Note: 1. The codes are revised and simplified from our implementation for release, not thoroughly tested again yet. Please submit Pull Requests or Issues if you encounter any problem. 2. You can run experiments with multiple GPUs by setting `--gpu`. The default implementation supports running with multiple GPUs in a _slurm cluster_. You may need to modify `main.py` to use `multiprocess`.
+
+## Results
+
+<img src="images/results.png" width="700">
+
+## Citation
+```
+@inproceedings{zhuang2023mas,
+  title={MAS: Towards Resource-Efficient Federated Multiple-Task Learning},
+  author={Zhuang, Weiming and Wen, Yonggang and Lyu, Lingjuan and Zhang, Shuai},
+  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
+  pages={},
+  year={2023}
+}
+```
+

+ 0 - 0
applications/mas/__init__.py


+ 70 - 0
applications/mas/client.py

@@ -0,0 +1,70 @@
+import gc
+import logging
+
+import torch
+import torch._utils
+
+from losses import get_losses
+from trainer import Trainer, LR_POLY
+from easyfl.client.base import BaseClient
+from easyfl.distributed.distributed import CPU
+
+logger = logging.getLogger(__name__)
+
+
+class MASClient(BaseClient):
+    def __init__(self, cid, conf, train_data, test_data, device, sleep_time=0):
+        super(MASClient, self).__init__(cid, conf, train_data, test_data, device, sleep_time)
+        self._local_model = None
+        criteria = self.load_loss_fn(conf)
+        train_loader = self.load_loader(conf)
+        self.trainer = Trainer(self.cid, conf, train_loader, self.model, optimizer=None, criteria=criteria, device=device)
+
+    def decompression(self):
+        if self.model is None:
+            # Initialization at beginning of the task
+            self.model = self.compressed_model
+
+    def train(self, conf, device=CPU):
+        self.model.to(device)
+        optimizer = self.load_optimizer(conf)
+        self.trainer.update(self.model, optimizer, device)
+        transference = self.trainer.train()
+        if conf.lookahead == 'y':
+            logger.info(f"Round {conf.round_id} - Client {self.cid} transference: {transference}")
+
+    def load_loss_fn(self, conf):
+        criteria = get_losses(conf.task_str, conf.rotate_loss, conf.task_weights)
+        return criteria
+
+    def load_loader(self, conf):
+        train_loader = self.train_data.loader(conf.batch_size,
+                                              self.cid,
+                                              shuffle=True,
+                                              num_workers=conf.num_workers,
+                                              seed=conf.seed)
+        return train_loader
+
+    def load_optimizer(self, conf, lr=None):
+        if conf.optimizer.lr_type == LR_POLY:
+            lr = conf.optimizer.lr * pow(1 - (conf.round_id / conf.rounds), 0.9)
+        else:
+            if self.trainer.lr:
+                lr = self.trainer.lr
+            else:
+                lr = conf.optimizer.lr
+
+        optimizer = torch.optim.SGD(self.model.parameters(),
+                                    lr=lr,
+                                    momentum=conf.optimizer.momentum,
+                                    weight_decay=conf.optimizer.weight_decay)
+        return optimizer
+
+    def post_upload(self):
+        del self.model
+        del self.compressed_model
+        self.model = None
+        self.compressed_model = None
+        assert self.model is None
+        assert self.compressed_model is None
+        gc.collect()

+ 32 - 0
applications/mas/clients.txt

@@ -0,0 +1,32 @@
+allensville
+beechwood
+benevolence
+coffeen
+collierville
+corozal
+cosmos
+darden
+forkland
+hanson
+hiteman
+ihlen
+klickitat
+lakeville
+leonardo
+lindenwood
+markleeville
+marstons
+mcdade
+merom
+mifflinburg
+muleshoe
+newfields
+noxapater
+onaga
+pinesdale
+pomaria
+ranchester
+shelbyville
+stockman
+tolstoy
+uvalda

+ 410 - 0
applications/mas/dataset.py

@@ -0,0 +1,410 @@
+import os
+import os.path
+import random
+import warnings
+
+import numpy as np
+import torch
+import torch.utils.data as data
+import torchvision.transforms as transforms
+from PIL import Image, ImageOps, ImageFile
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+from easyfl.datasets import FederatedTorchDataset
+
+DEFAULT_TASKS = ['depth_zbuffer', 'normal', 'segment_semantic', 'edge_occlusion', 'reshading', 'keypoints2d', 'edge_texture']
+
+
+VAL_LIMIT = 100
+TEST_LIMIT = (1000, 2000)
+
+
+def get_dataset(data_dir, train_client_file, test_client_file, tasks, image_size, model_limit=None, half_sized_output=False, augment=False):
+    dataset = {}  # each building in taskonomy dataset is a client
+    client_ids = set()
+    with open(train_client_file) as f:
+        for line in f:
+            client_id = line.strip()
+            client_ids.add(client_id)
+            dataset[client_id] = TaskonomyLoader(data_dir,
+                                                 label_set=tasks,
+                                                 model_whitelist=[client_id],
+                                                 model_limit=model_limit,
+                                                 output_size=(image_size, image_size),
+                                                 half_sized_output=half_sized_output,
+                                                 augment=augment)
+            print(f'Client {client_id}: {len(dataset[client_id])} instances.')
+    train_set = FederatedTorchDataset(dataset, client_ids)
+
+    if augment == "aggressive":
+        print('Data augmentation is on (aggressive).')
+    elif augment:
+        print('Data augmentation is on (flip).')
+    else:
+        print('no data augmentation')
+
+    test_client_ids = set()
+    with open(test_client_file) as f:
+        for line in f:
+            test_client_ids.add(line.strip())
+
+    val_set = get_validation_data(data_dir, test_client_ids, tasks, image_size, VAL_LIMIT, half_sized_output)
+    test_set = get_validation_data(data_dir, test_client_ids, tasks, image_size, TEST_LIMIT, half_sized_output)
+
+    return train_set, val_set, test_set
+
+
+def get_validation_data(data_dir, client_ids, tasks, image_size, model_limit, half_sized_output=False):
+    dataset = TaskonomyLoader(data_dir,
+                              label_set=tasks,
+                              model_whitelist=client_ids,
+                              model_limit=model_limit,
+                              output_size=(image_size, image_size),
+                              half_sized_output=half_sized_output,
+                              augment=False)
+    if model_limit == VAL_LIMIT:
+        print(f'Found {len(dataset)} validation instances.')
+    else:
+        print(f'Found {len(dataset)} test instances.')
+    return FederatedTorchDataset(dataset, client_ids)
+
+
+class TaskonomyLoader(data.Dataset):
+    def __init__(self,
+                 root,
+                 label_set=DEFAULT_TASKS,
+                 model_whitelist=None,
+                 model_limit=None,
+                 output_size=None,
+                 convert_to_tensor=True,
+                 return_filename=False,
+                 half_sized_output=False,
+                 augment=False):
+        self.root = root
+        self.model_limit = model_limit
+        self.records = []
+        if model_whitelist is None:
+            self.model_whitelist = None
+        elif type(model_whitelist) is str:
+            self.model_whitelist = set()
+            with open(model_whitelist) as f:
+                for line in f:
+                    self.model_whitelist.add(line.strip())
+        else:
+            self.model_whitelist = model_whitelist
+
+        for i, (where, subdirs, files) in enumerate(os.walk(os.path.join(root, 'rgb'))):
+            if subdirs:
+                continue
+            model = where.split('/')[-1]
+            if self.model_whitelist is None or model in self.model_whitelist:
+                full_paths = [os.path.join(where, f) for f in files]
+                if isinstance(model_limit, tuple):
+                    full_paths.sort()
+                    full_paths = full_paths[model_limit[0]:model_limit[1]]
+                elif model_limit is not None:
+                    full_paths.sort()
+                    full_paths = full_paths[:model_limit]
+                self.records += full_paths
+
+        # self.records = manager.list(self.records)
+        self.label_set = label_set
+        self.output_size = output_size
+        self.half_sized_output = half_sized_output
+        self.convert_to_tensor = convert_to_tensor
+        self.return_filename = return_filename
+        self.to_tensor = transforms.ToTensor()
+        self.augment = augment
+
+        self.last = {}
+
+    def process_image(self, im, input=False):
+        output_size = self.output_size
+        if self.half_sized_output and not input:
+            if output_size is None:
+                output_size = (128, 128)
+            else:
+                output_size = output_size[0] // 2, output_size[1] // 2
+        if output_size is not None and output_size != im.size:
+            im = im.resize(output_size, Image.BILINEAR)
+
+        bands = im.getbands()
+        if self.convert_to_tensor:
+            if bands[0] == 'L':
+                im = np.array(im)
+                im.setflags(write=1)
+                im = torch.from_numpy(im).unsqueeze(0)
+            else:
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore")
+                    im = self.to_tensor(im)
+
+        return im
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is an uint8 matrix of integers with the same width and height.
+        If there is an error loading an image or its labels, simply return the previous example.
+        """
+        with torch.no_grad():
+            file_name = self.records[index]
+            save_filename = file_name
+
+            flip_lr = (random.randint(0, 1) > .5 and self.augment)
+
+            flip_ud = (random.randint(0, 1) > .5 and (self.augment == "aggressive"))
+
+            pil_im = Image.open(file_name)
+
+            if flip_lr:
+                pil_im = ImageOps.mirror(pil_im)
+            if flip_ud:
+                pil_im = ImageOps.flip(pil_im)
+
+            im = self.process_image(pil_im, input=True)
+
+            error = False
+
+            ys = {}
+            mask = None
+            to_load = self.label_set
+            if len(set(['edge_occlusion', 'normal', 'reshading', 'principal_curvature']).intersection(
+                    self.label_set)) != 0:
+                if os.path.isfile(file_name.replace('rgb', 'mask')):
+                    to_load.append('mask')
+                elif 'depth_zbuffer' not in to_load:
+                    to_load.append('depth_zbuffer')
+
+            for i in to_load:
+                if i == 'mask' and mask is not None:
+                    continue
+
+                yfilename = file_name.replace('rgb', i)
+                try:
+                    yim = Image.open(yfilename)
+                except:
+                    yim = self.last[i].copy()
+                    error = True
+                if (i in self.last and yim.getbands() != self.last[i].getbands()) or error:
+                    yim = self.last[i].copy()
+                try:
+                    self.last[i] = yim.copy()
+                except:
+                    pass
+                if flip_lr:
+                    try:
+                        yim = ImageOps.mirror(yim)
+                    except:
+                        pass
+                if flip_ud:
+                    try:
+                        yim = ImageOps.flip(yim)
+                    except:
+                        pass
+                try:
+                    yim = self.process_image(yim)
+                except:
+                    yim = self.last[i].copy()
+                    yim = self.process_image(yim)
+
+                if i == 'depth_zbuffer':
+                    yim = yim.float()
+                    mask = yim < (2 ** 13)
+                    yim -= 1500.0
+                    yim /= 1000.0
+                elif i == 'edge_occlusion':
+                    yim = yim.float()
+                    yim -= 56.0248
+                    yim /= 239.1265
+                elif i == 'keypoints2d':
+                    yim = yim.float()
+                    yim -= 50.0
+                    yim /= 100.0
+                elif i == 'edge_texture':
+                    yim = yim.float()
+                    yim -= 718.0
+                    yim /= 1070.0
+                elif i == 'normal':
+                    yim = yim.float()
+                    yim -= .5
+                    yim *= 2.0
+                    if flip_lr:
+                        yim[0] *= -1.0
+                    if flip_ud:
+                        yim[1] *= -1.0
+                elif i == 'reshading':
+                    yim = yim.mean(dim=0, keepdim=True)
+                    yim -= .4962
+                    yim /= 0.2846
+                    # print('reshading',yim.shape,yim.max(),yim.min())
+                elif i == 'principal_curvature':
+                    yim = yim[:2]
+                    yim -= torch.tensor([0.5175, 0.4987]).view(2, 1, 1)
+                    yim /= torch.tensor([0.1373, 0.0359]).view(2, 1, 1)
+                    # print('principal_curvature',yim.shape,yim.max(),yim.min())
+                elif i == 'mask':
+                    mask = yim.bool()
+                    yim = mask
+
+                ys[i] = yim
+
+            if mask is not None:
+                ys['mask'] = mask
+
+            if not 'rgb' in self.label_set:
+                ys['rgb'] = im
+
+            if self.return_filename:
+                return im, ys, file_name
+            else:
+                return im, ys
+
+    def __len__(self):
+        return len(self.records)
+
+
+class DataPrefetcher:
+    def __init__(self, loader, device):
+        self.inital_loader = loader
+        self.device = device
+        self.loader = iter(loader)
+        self.stream = torch.cuda.Stream()
+        self.preload()
+
+    def preload(self):
+        try:
+            self.next_input, self.next_target = next(self.loader)
+        except StopIteration:
+            # self.next_input = None
+            # self.next_target = None
+            self.loader = iter(self.inital_loader)
+            self.preload()
+            return
+        with torch.cuda.stream(self.stream):
+            self.next_input = self.next_input.to(self.device, non_blocking=True)
+            # self.next_target = self.next_target.cuda(async=True)
+            self.next_target = {key: val.to(self.device, non_blocking=True) for (key, val) in self.next_target.items()}
+
+    def next(self):
+        torch.cuda.current_stream().wait_stream(self.stream)
+        input = self.next_input
+        target = self.next_target
+        self.preload()
+        return input, target
+
+    def update_device(self, device):
+        self.device = device
+
+
+def show(im, ys):
+    from matplotlib import pyplot as plt
+    plt.figure(figsize=(30, 30))
+    plt.subplot(4, 3, 1).set_title('RGB')
+    im = im.permute([1, 2, 0])
+    plt.imshow(im)
+    for i, y in enumerate(ys):
+        yim = ys[y]
+        plt.subplot(4, 3, 2 + i).set_title(y)
+        if y == 'normal':
+            yim += 1
+            yim /= 2
+        if yim.shape[0] == 2:
+            yim = torch.cat([yim, torch.zeros((1, yim.shape[1], yim.shape[2]))], dim=0)
+        yim = yim.permute([1, 2, 0])
+        yim = yim.squeeze()
+        plt.imshow(np.array(yim))
+
+    plt.show()
+
+
+def test():
+    loader = TaskonomyLoader(
+        '/home/tstand/Desktop/lite_taskonomy/',
+        label_set=['normal', 'reshading', 'principal_curvature', 'edge_occlusion', 'depth_zbuffer'],
+        augment='aggressive')
+
+    totals = {}
+    totals2 = {}
+    count = {}
+    indices = list(range(len(loader)))
+    random.shuffle(indices)
+    for data_count, index in enumerate(indices):
+        im, ys = loader[index]
+        show(im, ys)
+        mask = ys['mask']
+        # mask = ~mask
+        print(index)
+        for i, y in enumerate(ys):
+            yim = ys[y]
+            yim = yim.float()
+            if y not in totals:
+                totals[y] = 0
+                totals2[y] = 0
+                count[y] = 0
+
+            totals[y] += (yim * mask).sum(dim=[1, 2])
+            totals2[y] += ((yim ** 2) * mask).sum(dim=[1, 2])
+            count[y] += (torch.ones_like(yim) * mask).sum(dim=[1, 2])
+
+            # print(y,yim.shape)
+            std = torch.sqrt((totals2[y] - (totals[y] ** 2) / count[y]) / count[y])
+            print(data_count, '/', len(loader), y, 'mean:', totals[y] / count[y], 'std:', std)
+
+
+def output_mask(index, loader):
+    filename = loader.records[index]
+    filename = filename.replace('rgb', 'mask')
+    filename = filename.replace('/intel_nvme/taskonomy_data/', '/run/shm/')
+    if os.path.isfile(filename):
+        return
+
+    print(filename)
+
+    x, ys = loader[index]
+
+    mask = ys['mask']
+    mask = mask.squeeze()
+    mask_im = Image.fromarray(mask.numpy())
+    mask_im = mask_im.convert(mode='1')
+    # plt.subplot(2,1,1)
+    # plt.imshow(mask)
+    # plt.subplot(2,1,2)
+    # plt.imshow(mask_im)
+    # plt.show()
+    path, _ = os.path.split(filename)
+    os.makedirs(path, exist_ok=True)
+    mask_im.save(filename, bits=1, optimize=True)
+
+
+def get_masks():
+    loader = TaskonomyLoader(
+        '/intel_nvme/taskonomy_data/',
+        label_set=['depth_zbuffer'],
+        augment=False)
+
+    indices = list(range(len(loader)))
+
+    random.shuffle(indices)
+
+    for count, index in enumerate(indices):
+        print(count, len(indices))
+        output_mask(index, loader)
+
+
+if __name__ == "__main__":
+    file_name = "/Users/weiming/personal-projects/taskonomy_dataset/rgb/cosmos/point_512_view_7_domain_rgb.png"
+
+    pil_im = Image.open(file_name)
+
+    pil_im = ImageOps.mirror(pil_im)
+
+    output_size = (128, 128)
+    pil_im = pil_im.resize(output_size, Image.BILINEAR)
+
+    print(pil_im)
+    print("Completed")

二进制
applications/mas/images/mas.png


二进制
applications/mas/images/results.png


+ 295 - 0
applications/mas/losses.py

@@ -0,0 +1,295 @@
+import collections
+
+import torch
+
+sl = 0
+nl = 0
+nl2 = 0
+nl3 = 0
+dl = 0
+el = 0
+rl = 0
+kl = 0
+tl = 0
+al = 0
+cl = 0
+popular_offsets = collections.defaultdict(int)
+batch_number = 0
+
+TASKS = {
+    's': 'segment_semantic',
+    'd': 'depth_zbuffer',
+    'n': 'normal',
+    'N': 'normal2',
+    'k': 'keypoints2d',
+    'e': 'edge_occlusion',
+    'r': 'reshading',
+    't': 'edge_texture',
+    'a': 'rgb',
+    'c': 'principal_curvature'
+}
+
+LOSSES = {
+    "ss_l": 's',
+    "edge2d_l": 't',
+    "depth_l": 'd',
+    "norm_l": 'n',
+    "key_l": 'k',
+    "edge_l": 'e',
+    "shade_l": 'r',
+    "rgb_l": 'a',
+    "pc_l": 'c',
+}
+
+
+def parse_tasks(task_str):
+    tasks = []
+    for char in task_str:
+        tasks.append(TASKS[char])
+    return tasks
+
+
+def parse_loss_names(loss_names):
+    tasks = []
+    for l in loss_names:
+        tasks.append(LOSSES[l])
+    return tasks
+
+
+def segment_semantic_loss(output, target, mask):
+    global sl
+    sl = torch.nn.functional.cross_entropy(output.float(), target.long().squeeze(dim=1), ignore_index=0,
+                                           reduction='mean')
+    return sl
+
+
+def normal_loss(output, target, mask):
+    global nl
+    nl = rotate_loss(output, target, mask, normal_loss_base)
+    return nl
+
+
+def normal_loss_simple(output, target, mask):
+    global nl
+    out = torch.nn.functional.l1_loss(output, target, reduction='none')
+    out *= mask.float()
+    nl = out.mean()
+    return nl
+
+
+def rotate_loss(output, target, mask, loss_name):
+    global popular_offsets
+    target = target[:, :, 1:-1, 1:-1].float()
+    mask = mask[:, :, 1:-1, 1:-1].float()
+    output = output.float()
+    val1 = loss = loss_name(output[:, :, 1:-1, 1:-1], target, mask)
+
+    val2 = loss_name(output[:, :, 0:-2, 1:-1], target, mask)
+    loss = torch.min(loss, val2)
+    val3 = loss_name(output[:, :, 1:-1, 0:-2], target, mask)
+    loss = torch.min(loss, val3)
+    val4 = loss_name(output[:, :, 2:, 1:-1], target, mask)
+    loss = torch.min(loss, val4)
+    val5 = loss_name(output[:, :, 1:-1, 2:], target, mask)
+    loss = torch.min(loss, val5)
+    val6 = loss_name(output[:, :, 0:-2, 0:-2], target, mask)
+    loss = torch.min(loss, val6)
+    val7 = loss_name(output[:, :, 2:, 2:], target, mask)
+    loss = torch.min(loss, val7)
+    val8 = loss_name(output[:, :, 0:-2, 2:], target, mask)
+    loss = torch.min(loss, val8)
+    val9 = loss_name(output[:, :, 2:, 0:-2], target, mask)
+    loss = torch.min(loss, val9)
+
+    # lst = [val1,val2,val3,val4,val5,val6,val7,val8,val9]
+
+    # print(loss.size())
+    loss = loss.mean()
+    # print(loss)
+    return loss
+
+
+def normal_loss_base(output, target, mask):
+    out = torch.nn.functional.l1_loss(output, target, reduction='none')
+    out *= mask
+    out = out.mean(dim=(1, 2, 3))
+    return out
+
+
+def normal2_loss(output, target, mask):
+    global nl3
+    diff = output.float() - target.float()
+    out = torch.abs(diff)
+    out = out * mask.float()
+    nl3 = out.mean()
+    return nl3
+
+
+def depth_loss_simple(output, target, mask):
+    global dl
+    out = torch.nn.functional.l1_loss(output, target, reduction='none')
+    out *= mask.float()
+    dl = out.mean()
+    return dl
+
+
+def depth_loss(output, target, mask):
+    global dl
+    dl = rotate_loss(output, target, mask, depth_loss_base)
+    return dl
+
+
+def depth_loss_base(output, target, mask):
+    out = torch.nn.functional.l1_loss(output, target, reduction='none')
+    out *= mask.float()
+    out = out.mean(dim=(1, 2, 3))
+    return out
+
+
+def edge_loss_simple(output, target, mask):
+    global el
+
+    out = torch.nn.functional.l1_loss(output, target, reduction='none')
+    out *= mask
+    el = out.mean()
+    return el
+
+
+def reshade_loss(output, target, mask):
+    global rl
+    out = torch.nn.functional.l1_loss(output, target, reduction='none')
+    out *= mask
+    rl = out.mean()
+    return rl
+
+
+def keypoints2d_loss(output, target, mask):
+    global kl
+    kl = torch.nn.functional.l1_loss(output, target)
+    return kl
+
+
+def edge2d_loss(output, target, mask):
+    global tl
+    tl = torch.nn.functional.l1_loss(output, target)
+    return tl
+
+
+def auto_loss(output, target, mask):
+    global al
+    al = torch.nn.functional.l1_loss(output, target)
+    return al
+
+
+def pc_loss(output, target, mask):
+    global cl
+    out = torch.nn.functional.l1_loss(output, target, reduction='none')
+    out *= mask
+    cl = out.mean()
+    return cl
+
+
+def edge_loss(output, target, mask):
+    global el
+    out = torch.nn.functional.l1_loss(output, target, reduction='none')
+    out *= mask
+    el = out.mean()
+    return el
+
+
+def get_taskonomy_loss(losses):
+    def taskonomy_loss(output, target):
+        if 'mask' in target:
+            mask = target['mask']
+        else:
+            mask = None
+
+        sum_loss = None
+        num = 0
+        for n, t in target.items():
+            if n in losses:
+                o = output[n].float()
+                this_loss = losses[n](o, t, mask)
+                num += 1
+                if sum_loss:
+                    sum_loss = sum_loss + this_loss
+                else:
+                    sum_loss = this_loss
+
+        return sum_loss  # /num # should not take average when using xception_taskonomy_new
+
+    return taskonomy_loss
+
+
+def get_losses(task_str, is_rotate_loss, task_weights=None):
+    losses = {}
+    criteria = {}
+
+    if 's' in task_str:
+        losses['segment_semantic'] = segment_semantic_loss
+        criteria['ss_l'] = lambda x, y: sl
+
+    if 'd' in task_str:
+        if not is_rotate_loss:
+            losses['depth_zbuffer'] = depth_loss_simple
+        else:
+            losses['depth_zbuffer'] = depth_loss
+        criteria['depth_l'] = lambda x, y: dl
+
+    if 'n' in task_str:
+        if not is_rotate_loss:
+            losses['normal'] = normal_loss_simple
+        else:
+            losses['normal'] = normal_loss
+        criteria['norm_l'] = lambda x, y: nl
+        # criteria['norm_l2']=lambda x,y : nl2
+
+    if 'N' in task_str:
+        losses['normal2'] = normal2_loss
+        criteria['norm2'] = lambda x, y: nl3
+
+    if 'k' in task_str:
+        losses['keypoints2d'] = keypoints2d_loss
+        criteria['key_l'] = lambda x, y: kl
+
+    if 'e' in task_str:
+        if not is_rotate_loss:
+            losses['edge_occlusion'] = edge_loss_simple
+        else:
+            losses['edge_occlusion'] = edge_loss
+        # losses['edge_occlusion']=edge_loss
+        criteria['edge_l'] = lambda x, y: el
+
+    if 'r' in task_str:
+        losses['reshading'] = reshade_loss
+        criteria['shade_l'] = lambda x, y: rl
+
+    if 't' in task_str:
+        losses['edge_texture'] = edge2d_loss
+        criteria['edge2d_l'] = lambda x, y: tl
+
+    if 'a' in task_str:
+        losses['rgb'] = auto_loss
+        criteria['rgb_l'] = lambda x, y: al
+
+    if 'c' in task_str:
+        losses['principal_curvature'] = pc_loss
+        criteria['pc_l'] = lambda x, y: cl
+
+    if task_weights:
+        weights = [float(x) for x in task_weights.split(',')]
+        losses2 = {}
+        criteria2 = {}
+
+        for l, w, c in zip(losses.items(), weights, criteria.items()):
+            losses[l[0]] = lambda x, y, z, l=l[1], w=w: l(x, y, z) * w
+            criteria[c[0]] = lambda x, y, c=c[1], w=w: c(x, y) * w
+
+    taskonomy_loss = get_taskonomy_loss(losses)
+
+    criteria2 = {'Loss': taskonomy_loss}
+    for key, value in criteria.items():
+        criteria2[key] = value
+    criteria = criteria2
+
+    return criteria

+ 182 - 0
applications/mas/main.py

@@ -0,0 +1,182 @@
+import argparse
+import os
+
+import torch
+
+import easyfl
+from easyfl.distributed import slurm
+
+from client import MASClient
+from server import MASServer
+from dataset import get_dataset
+from losses import parse_tasks
+from models.model import get_model
+
+
+STANDALONE_CLIENT_FOLDER = "standalone_clients"
+DEFAULT_CLIENT_ID = "NA"
+
+def construct_parser(parser):
+    parser.add_argument("--task_id", type=str, default="")
+    parser.add_argument('--tasks', default='s', help='which tasks to train, options: sdnkt')
+    parser.add_argument('--task_groups', default='', help='e.g., groups of tasks separtely by comma, "sd,nkt"')
+    parser.add_argument("--dataset", type=str, default='taskonomy', help='')
+    parser.add_argument("--arch", type=str, default='xception', help='model architecture')
+    parser.add_argument('--data_dir', type=str, help='directory to load data')
+    parser.add_argument('--client_file', type=str, default='clients.txt', help='directory to load data')
+    parser.add_argument('--client_id', type=str, default=DEFAULT_CLIENT_ID, help='client id for standalone training')
+
+    parser.add_argument('--image_size', default=256, type=int, help='size of image side (images are square)')
+    parser.add_argument('--batch_size', default=64, type=int)
+    parser.add_argument('--local_epoch', default=5, type=int)
+    parser.add_argument('--rounds', default=100, type=int)
+    parser.add_argument('--num_of_clients', default=32, type=int)
+    parser.add_argument('--clients_per_round', default=5, type=int)
+    parser.add_argument('--optimizer_type', default='SGD', type=str, help='optimizer type')
+    parser.add_argument('--random_selection', action='store_true', help='whether randomly select clients')
+    parser.add_argument('--lr', default=0.1, type=float, metavar='LR', help='initial learning rate')
+    parser.add_argument('--lr_type', default="poly", type=str,
+                        help='learning rate schedule type: poly or custom, custom lr requires stateful client.')
+    parser.add_argument('--minimum_learning_rate', default=3e-5, type=float,
+                        metavar='LR', help='End training when learning rate falls below this value.')
+    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
+    parser.add_argument('--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
+
+    parser.add_argument('--test_every', default=10, type=int, help='test every x rounds')
+    parser.add_argument('--save_model_every', default=10, type=int, help='save model every x rounds')
+    parser.add_argument("--aggregation_content", type=str, default="all", help="aggregation content")
+    parser.add_argument("--aggregation_strategy", type=str, default="FedAvg", help="aggregation strategy")
+    
+    parser.add_argument('--lookahead', default='y', type=str, help='whether use lookahead optimizer')
+    parser.add_argument('--lookahead_step', default=10, type=int, help='lookahead every x step')
+    parser.add_argument('--num_workers', default=4, type=int, help='number of data loading workers (default: 4)')
+    parser.add_argument('--rotate_loss', dest='rotate_loss', action='store_true', help='should loss rotation occur')
+    parser.add_argument('--pretrained', default='n', help='use pretrained model')
+    parser.add_argument('--pretrained_tasks', default='sdnkt', help='tasks for pretrained')
+    parser.add_argument('--load_decoder', default='y', help='whether load pretrained decoder')
+    parser.add_argument('--fp16', action='store_true', help='Run model fp16 mode.')
+    parser.add_argument('--half', default='n', help='whether use half output')
+    parser.add_argument('--half_sized_output', action='store_true', help='output 128x128 rather than 256x256.')
+    parser.add_argument('--no_augment', action='store_true', help='Run model fp16 mode.')
+    parser.add_argument('--model_limit', default=None, type=int,
+                        help='Limit the number of training instances from a single 3d building model.')
+    parser.add_argument('--task_weights', default=None, type=str,
+                        help='a comma separated list of numbers one for each task to multiply the loss by.')
+    parser.add_argument('-vb', '--virtual_batch_multiplier', default=1, type=int,
+                        metavar='N', help='number of forward/backward passes per parameter update')
+    parser.add_argument('--dist_port', default=23344, type=int)
+    parser.add_argument('--run_count', default=0, type=int)
+
+    # Not effective arguments, to be deleted
+    parser.add_argument('--maximum_loss_tracking_window', default=2000000, type=int,
+                        help='maximum loss tracking window (default: 2000000)')
+
+    return parser
+
+
+def run(args):
+    rank, local_rank, world_size, host_addr = slurm.setup(args.dist_port)
+    task_id = args.task_id
+    if task_id == "":
+        task_id = f"{args.arch}_{args.tasks}_{args.clients_per_round}c{args.num_of_clients}_run{args.run_count}"
+
+    tasks = parse_tasks(args.tasks)
+
+    config = {
+        "task_id": task_id,
+        "model": args.arch,
+        "gpu": world_size,
+        "distributed": {"rank": rank, "local_rank": local_rank, "world_size": world_size, "init_method": host_addr},
+        "test_mode": "test_in_server",
+        "server": {
+            "batch_size": args.batch_size,
+            "rounds": args.rounds,
+            "test_every": args.test_every,
+            "save_model_every": args.save_model_every,
+            "clients_per_round": args.clients_per_round,
+            "test_all": False,  # False means do not test clients in the start of training
+            "random_selection": args.random_selection,
+            "aggregation_content": args.aggregation_content,
+            "aggregation_stragtegy": args.aggregation_strategy,
+            "track": False,
+        },
+        "client": {
+            "track": False,
+            "drop_last": True,
+            "batch_size": args.batch_size,
+            "local_epoch": args.local_epoch,
+            "rounds": args.rounds,
+
+            "optimizer": {
+                "type": args.optimizer_type,
+                "lr_type": args.lr_type,
+                "lr": args.lr,
+                "momentum": args.momentum,
+                "weight_decay": args.weight_decay,
+            },
+            "minimum_learning_rate": args.minimum_learning_rate,
+            
+            "tasks": tasks,
+            "task_str": args.tasks,
+            "task_weights": args.task_weights,
+            "rotate_loss": args.rotate_loss,
+
+            "lookahead": args.lookahead,
+            "lookahead_step": args.lookahead_step,
+            "num_workers": args.num_workers,
+            "fp16": args.fp16,
+            "virtual_batch_multiplier": args.virtual_batch_multiplier,
+            "maximum_loss_tracking_window": args.maximum_loss_tracking_window,
+        },
+        "tracking": {"database": os.path.join(os.getcwd(), "tracker", task_id)},
+    }
+
+    model = get_model(args.arch, tasks)
+    if args.pretrained != "n":
+        pretrained_tasks = parse_tasks(args.pretrained_tasks)
+        pretrained_model = get_model(args.arch, pretrained_tasks)
+        pretrained_path = os.path.join(os.getcwd(), "saved_models", "mas", args.pretrained)
+
+        checkpoint = torch.load(pretrained_path)
+        pretrained_model.load_state_dict(checkpoint['state_dict'])
+
+        model.encoder.load_state_dict(pretrained_model.encoder.state_dict())
+        if not args.load_decoder == "n":
+            print("load decoder")
+            pretrained_decoder_keys = list(pretrained_model.task_to_decoder.keys())
+            for i, task in enumerate(model.task_to_decoder.keys()):
+                pi = pretrained_decoder_keys.index(task)
+                model.decoders[i].load_state_dict(pretrained_model.decoders[pi].state_dict())
+
+    augment = not args.no_augment
+    client_file = args.client_file
+    if args.client_id != DEFAULT_CLIENT_ID:
+        client_file = f"{STANDALONE_CLIENT_FOLDER}/{args.client_id}.txt"
+        with open(client_file, "w") as f:
+            f.write(args.client_id)
+    if args.half == 'y':
+        args.half_sized_output = True
+    print("train client file:", client_file)
+    print("test client file:", args.client_file)
+    train_data, val_data, test_data = get_dataset(args.data_dir,
+                                                  train_client_file=client_file,
+                                                  test_client_file=args.client_file,
+                                                  tasks=tasks,
+                                                  image_size=args.image_size,
+                                                  model_limit=args.model_limit,
+                                                  half_sized_output=args.half_sized_output,
+                                                  augment=augment)
+    easyfl.register_dataset(train_data, test_data, val_data)
+    easyfl.register_model(model)
+    easyfl.register_client(MASClient)
+    easyfl.register_server(MASServer)
+    easyfl.init(config, init_all=True)
+    easyfl.run()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='MAS')
+    parser = construct_parser(parser)
+    args = parser.parse_args()
+    print("arguments: ", args)
+    run(args)

+ 2 - 0
applications/mas/models/__init__.py

@@ -0,0 +1,2 @@
+from .resnet import *
+from .xception import *

+ 14 - 0
applications/mas/models/model.py

@@ -0,0 +1,14 @@
+from mas import models
+from easyfl.tracking.evaluation import count_model_params
+
+
+def get_model(arch, tasks, pretrained=False):
+    model = models.__dict__[arch](pretrained=pretrained, tasks=tasks)
+    print(f"Model has {count_model_params(model)} parameters")
+    try:
+        print(f"Encoder has {count_model_params(model.encoder)} parameters")
+    except:
+        print(f"Each encoder has {count_model_params(model.encoders[0])} parameters")
+    for decoder in model.task_to_decoder.values():
+        print(f"Decoder has {count_model_params(decoder)} parameters")
+    return model

+ 205 - 0
applications/mas/models/ozan_min_norm_solvers.py

@@ -0,0 +1,205 @@
+import math
+
+import numpy as np
+import torch
+
+
+class MinNormSolver:
+    MAX_ITER = 250
+    STOP_CRIT = 1e-5
+
+    def _min_norm_element_from2(v1v1, v1v2, v2v2):
+        """
+        Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2
+        d is the distance (objective) optimzed
+        v1v1 = <x1,x1>
+        v1v2 = <x1,x2>
+        v2v2 = <x2,x2>
+        """
+        if v1v2 >= v1v1:
+            # Case: Fig 1, third column
+            gamma = 0.999
+            cost = v1v1
+            return gamma, cost
+        if v1v2 >= v2v2:
+            # Case: Fig 1, first column
+            gamma = 0.001
+            cost = v2v2
+            return gamma, cost
+        # Case: Fig 1, second column
+        gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2))
+        cost = v2v2 + gamma * (v1v2 - v2v2)
+        return gamma, cost
+
+    def _min_norm_2d(vecs, dps):
+        """
+        Find the minimum norm solution as combination of two points
+        This is correct only in 2D
+        ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j
+        """
+        dmin = 1e99
+        sol = None
+        for i in range(len(vecs)):
+            for j in range(i + 1, len(vecs)):
+                if (i, j) not in dps:
+                    dps[(i, j)] = 0.0
+                    for k in range(len(vecs[i])):
+                        dps[(i, j)] += torch.dot(vecs[i][k], vecs[j][k]).item()  # .data[0]
+                    dps[(j, i)] = dps[(i, j)]
+                if (i, i) not in dps:
+                    dps[(i, i)] = 0.0
+                    for k in range(len(vecs[i])):
+                        dps[(i, i)] += torch.dot(vecs[i][k], vecs[i][k]).item()  # .data[0]
+                if (j, j) not in dps:
+                    dps[(j, j)] = 0.0
+                    for k in range(len(vecs[i])):
+                        dps[(j, j)] += torch.dot(vecs[j][k], vecs[j][k]).item()  # .data[0]
+                c, d = MinNormSolver._min_norm_element_from2(dps[(i, i)], dps[(i, j)], dps[(j, j)])
+                # print('c,d',c,d)
+                if d < dmin:
+                    dmin = d
+                    sol = [(i, j), c, d]
+
+        if sol is None or math.isnan(c):
+            raise ValueError('A numeric instability occured in ozan_min_norm_solvers.')
+        return sol, dps
+
+    def _projection2simplex(y):
+        """
+        Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i
+        """
+        m = len(y)
+        sorted_y = np.flip(np.sort(y), axis=0)
+        tmpsum = 0.0
+        tmax_f = (np.sum(y) - 1.0) / m
+        for i in range(m - 1):
+            tmpsum += sorted_y[i]
+            tmax = (tmpsum - 1) / (i + 1.0)
+            if tmax > sorted_y[i + 1]:
+                tmax_f = tmax
+                break
+        return np.maximum(y - tmax_f, np.zeros(y.shape))
+
+    def _next_point(cur_val, grad, n):
+        proj_grad = grad - (np.sum(grad) / n)
+        tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0]
+        tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0])
+
+        skippers = np.sum(tm1 < 1e-7) + np.sum(tm2 < 1e-7)
+        t = 1
+        if len(tm1[tm1 > 1e-7]) > 0:
+            t = np.min(tm1[tm1 > 1e-7])
+        if len(tm2[tm2 > 1e-7]) > 0:
+            t = min(t, np.min(tm2[tm2 > 1e-7]))
+
+        next_point = proj_grad * t + cur_val
+        next_point = MinNormSolver._projection2simplex(next_point)
+        return next_point
+
+    def find_min_norm_element(vecs):
+        """
+        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
+        as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
+        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
+        Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence
+        """
+        # Solution lying at the combination of two points
+        dps = {}
+        init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
+
+        n = len(vecs)
+        sol_vec = np.zeros(n)
+        sol_vec[init_sol[0][0]] = init_sol[1]
+        sol_vec[init_sol[0][1]] = 1 - init_sol[1]
+
+        if n < 3:
+            # This is optimal for n=2, so return the solution
+            return sol_vec, init_sol[2]
+
+        iter_count = 0
+
+        grad_mat = np.zeros((n, n))
+        for i in range(n):
+            for j in range(n):
+                grad_mat[i, j] = dps[(i, j)]
+
+        while iter_count < MinNormSolver.MAX_ITER:
+            grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
+            new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
+            # Re-compute the inner products for line search
+            v1v1 = 0.0
+            v1v2 = 0.0
+            v2v2 = 0.0
+            for i in range(n):
+                for j in range(n):
+                    v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)]
+                    v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)]
+                    v2v2 += new_point[i] * new_point[j] * dps[(i, j)]
+            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
+            new_sol_vec = nc * sol_vec + (1 - nc) * new_point
+            change = new_sol_vec - sol_vec
+            if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
+                return sol_vec, nd
+            sol_vec = new_sol_vec
+
+    def find_min_norm_element_FW(vecs):
+        """
+        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
+        as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
+        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
+        Hence, we find the best 2-task solution, and then run the Frank Wolfe until convergence
+        """
+        # Solution lying at the combination of two points
+        dps = {}
+        init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
+
+        n = len(vecs)
+        sol_vec = np.zeros(n)
+        sol_vec[init_sol[0][0]] = init_sol[1]
+        sol_vec[init_sol[0][1]] = 1 - init_sol[1]
+
+        if n < 3:
+            # This is optimal for n=2, so return the solution
+            return sol_vec, init_sol[2]
+
+        iter_count = 0
+
+        grad_mat = np.zeros((n, n))
+        for i in range(n):
+            for j in range(n):
+                grad_mat[i, j] = dps[(i, j)]
+
+        while iter_count < MinNormSolver.MAX_ITER:
+            t_iter = np.argmin(np.dot(grad_mat, sol_vec))
+
+            v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
+            v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
+            v2v2 = grad_mat[t_iter, t_iter]
+
+            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
+            new_sol_vec = nc * sol_vec
+            new_sol_vec[t_iter] += 1 - nc
+
+            change = new_sol_vec - sol_vec
+            if np.sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
+                return sol_vec, nd
+            sol_vec = new_sol_vec
+
+
+def gradient_normalizers(grads, losses, normalization_type):
+    gn = {}
+    if normalization_type == 'l2':
+        for t in grads:
+            gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data[0] for gr in grads[t]]))
+    elif normalization_type == 'loss':
+        for t in grads:
+            gn[t] = losses[t]
+    elif normalization_type == 'loss+':
+        for t in grads:
+            gn[t] = losses[t] * np.sqrt(np.sum([gr.pow(2).sum().data[0] for gr in grads[t]]))
+    elif normalization_type == 'none':
+        for t in grads:
+            gn[t] = 1.0
+    else:
+        print('ERROR: Invalid Normalization Type')
+    return gn

+ 222 - 0
applications/mas/models/ozan_rep_fun.py

@@ -0,0 +1,222 @@
+import math
+import statistics
+
+import torch.autograd
+
+from .ozan_min_norm_solvers import MinNormSolver
+
+
+class OzanRepFunction(torch.autograd.Function):
+    # def __init__(self,copies,noop=False):
+    #     super(OzanRepFunction,self).__init__()
+    #     self.copies=copies
+    #     self.noop=noop
+    n = 5
+
+    def __init__(self):
+        super(OzanRepFunction, self).__init__()
+
+    @staticmethod
+    def forward(ctx, input):
+
+        shape = input.shape
+        ret = input.expand(OzanRepFunction.n, *shape)
+        return ret.clone()  # REASON FOR ERROR: forgot to .clone() here
+
+    # @staticmethod
+    # def backward(ctx, grad_output):
+    #     # print("backward",grad_output.shape)
+    #     # print()
+    #     # print()
+    #     if grad_output.shape[0]==2:
+    #         theta0,theta1=grad_output[0].view(-1).float(), grad_output[1].view(-1).float()
+    #         diff = theta0-theta1
+    #         num = diff.dot(theta0)
+    #         denom = (diff.dot(diff)+.00000001)
+    #         a = num/denom
+    #         a1=float(a)
+    #         a = a.clamp(0,1)
+    #         a = float(a)
+    #         # print(float(a),a1,float(num),float(denom))
+    #         # print()
+    #         # print()
+    #         def get_out_for_a(a):
+    #             return grad_output[0]*(1-a)+grad_output[1]*a
+    #         def get_score_for_a(a):
+    #             out = get_out_for_a(a)
+    #             vec = out.view(-1)
+    #             score = vec.dot(vec)
+    #             return float(score)
+    #         # print(0,get_score_for_a(0),
+    #         #       .1,get_score_for_a(0.1),
+    #         #       .2,get_score_for_a(0.2),
+    #         #       .3,get_score_for_a(0.3),
+    #         #       .4,get_score_for_a(0.4),
+    #         #       .5,get_score_for_a(0.5),
+    #         #       .6,get_score_for_a(0.6),
+    #         #       .7,get_score_for_a(0.7),
+    #         #       .8,get_score_for_a(0.8),
+    #         #       .9,get_score_for_a(0.9),
+    #         #       1,get_score_for_a(1))
+    #         # print(a,get_score_for_a(a))
+    #         # print()
+    #         # print()
+    #         out = get_out_for_a(a)
+    #         #out=out*2
+    #     elif grad_output.shape[0]==1:
+    #         grad_input=grad_output.clone()
+    #         out = grad_input.sum(dim=0)
+    #     else:
+    #         pass
+    #     return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        num_grads = grad_output.shape[0]
+        batch_size = grad_output.shape[1]
+        # print(num_grads)
+        # print(num_grads)
+        # print(num_grads)
+        # print(grad_output.shape)
+        # print(grad_output.shape)
+        # print(grad_output.shape)
+        # print(num_grads)
+        # print(num_grads)
+        if num_grads >= 2:
+            # print ('shape in = ',grad_output[0].view(batch_size,-1).float().shape)
+            try:
+                alphas, score = MinNormSolver.find_min_norm_element(
+                    [grad_output[i].view(batch_size, -1).float() for i in range(num_grads)])
+                # print(alphas)
+            except ValueError as error:
+                alphas = [1 / num_grads for i in range(num_grads)]
+            # print('outs shape',out.shape)
+            # print('alphas shape',alphas.shape)
+
+            # out = out.view()
+            # out = torch.zeros_like(grad_output[0])
+            # print(alphas)
+            # print()
+            # print()
+            grad_outputs = [grad_output[i] * alphas[i] * math.sqrt(num_grads) for i in range(num_grads)]
+            output = grad_outputs[0]
+            for i in range(1, num_grads):
+                output += grad_outputs[i]
+            return output
+
+
+        elif num_grads == 1:
+            grad_input = grad_output.clone()
+            out = grad_input.sum(dim=0)
+        else:
+            pass
+        return out
+
+
+ozan_rep_function = OzanRepFunction.apply
+
+
+class TrevorRepFunction(torch.autograd.Function):
+    n = 5
+
+    def __init__(self):
+        super(TrevorRepFunction, self).__init__()
+
+    @staticmethod
+    def forward(ctx, input):
+        return input.clone()
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # num_grads = grad_output.shape[0]
+        # print(num_grads)
+        grad_input = grad_output.clone()
+        mul = 1.0 / math.sqrt(TrevorRepFunction.n)
+        out = grad_input * mul
+        return out
+
+
+trevor_rep_function = TrevorRepFunction.apply
+
+count = 0
+
+
+class GradNormRepFunction(torch.autograd.Function):
+    n = 5
+    inital_task_losses = None
+    current_task_losses = None
+    current_weights = None
+
+    def __init__(self):
+        super(GradNormRepFunction, self).__init__()
+
+    @staticmethod
+    def forward(ctx, input):
+        shape = input.shape
+        ret = input.expand(GradNormRepFunction.n, *shape)
+        return ret.clone()
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        global count
+        num_grads = grad_output.shape[0]
+        batch_size = grad_output.shape[1]
+        grad_output = grad_output.float()
+        if num_grads >= 2:
+
+            GiW = [torch.sqrt(grad_output[i].reshape(-1).dot(grad_output[i].reshape(-1))) *
+                   GradNormRepFunction.current_weights[i] for i in range(num_grads)]
+            GW_bar = torch.mean(torch.stack(GiW))
+
+            try:
+                Li_ratio = [c / max(i, .0000001) for c, i in
+                            zip(GradNormRepFunction.current_task_losses, GradNormRepFunction.inital_task_losses)]
+                mean_ratio = statistics.mean(Li_ratio)
+                ri = [lir / max(mean_ratio, .00000001) for lir in Li_ratio]
+                target_grad = [float(GW_bar * (max(r_i, .00000001) ** 1.5)) for r_i in ri]
+
+                target_weight = [float(target_grad[i] / float(GiW[i])) for i in range(num_grads)]
+                total_weight = sum(target_weight)
+                total_weight = max(.0000001, total_weight)
+                target_weight = [i * num_grads / total_weight for i in target_weight]
+
+                for i in range(len(GradNormRepFunction.current_weights)):
+                    wi = GradNormRepFunction.current_weights[i]
+                    GradNormRepFunction.current_weights[i] += (.0001 * wi if (wi < target_weight[i]) else -.0001 * wi)
+
+                # print('Li_ratio',Li_ratio)
+                # print('mean_ratio',mean_ratio)
+                # print('ri',ri)
+                # print('target_weight',target_weight)
+                # print('current_weights',GradNormRepFunction.current_weights)
+                # print()
+                # print()
+
+                count += 1
+                if count % 80 == 0:
+                    with open("gradnorm_weights.txt", "a") as myfile:
+                        myfile.write('target: ' + str(target_weight) + '\n')
+
+                total_weight = sum(GradNormRepFunction.current_weights)
+                total_weight = max(.0000001, total_weight)
+
+                GradNormRepFunction.current_weights = [i * num_grads / total_weight for i in
+                                                       GradNormRepFunction.current_weights]
+            except:
+                pass
+
+            grad_outputs = [grad_output[i] * GradNormRepFunction.current_weights[i] * (1 / math.sqrt(num_grads)) for i
+                            in range(num_grads)]
+            output = grad_outputs[0]
+            for i in range(1, num_grads):
+                output += grad_outputs[i]
+            return output.half()
+        elif num_grads == 1:
+            grad_input = grad_output.clone()
+            out = grad_input.sum(dim=0)
+        else:
+            pass
+        return out
+
+
+gradnorm_rep_function = GradNormRepFunction.apply

+ 452 - 0
applications/mas/models/resnet.py

@@ -0,0 +1,452 @@
+import math
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .ozan_rep_fun import ozan_rep_function, trevor_rep_function, OzanRepFunction, TrevorRepFunction
+
+from easyfl.models.model import BaseModel
+
+__all__ = ['resnet18',
+           'resnet18_half',
+           'resnet18_tripple',
+           'resnet34',
+           'resnet50',
+           'resnet101',
+           'resnet152']
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNetEncoder(nn.Module):
+
+    def __init__(self, block, layers, widths=[64, 128, 256, 512], num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
+        super(ResNetEncoder, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, widths[0], layers[0])
+        self.layer2 = self._make_layer(block, widths[1], layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, widths[2], layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, widths[3], layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        return x
+
+
+class Decoder(nn.Module):
+    def __init__(self, output_channels=32, num_classes=None, base_match=512):
+        super(Decoder, self).__init__()
+
+        self.output_channels = output_channels
+        self.num_classes = num_classes
+
+        self.relu = nn.ReLU(inplace=True)
+        if num_classes is not None:
+            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+            self.fc = nn.Linear(512 * block.expansion, num_classes)
+        else:
+            self.upconv0 = nn.ConvTranspose2d(base_match, 256, 2, 2)
+            self.bn_upconv0 = nn.BatchNorm2d(256)
+            self.conv_decode0 = nn.Conv2d(256, 256, 3, padding=1)
+            self.bn_decode0 = nn.BatchNorm2d(256)
+            self.upconv1 = nn.ConvTranspose2d(256, 128, 2, 2)
+            self.bn_upconv1 = nn.BatchNorm2d(128)
+            self.conv_decode1 = nn.Conv2d(128, 128, 3, padding=1)
+            self.bn_decode1 = nn.BatchNorm2d(128)
+            self.upconv2 = nn.ConvTranspose2d(128, 64, 2, 2)
+            self.bn_upconv2 = nn.BatchNorm2d(64)
+            self.conv_decode2 = nn.Conv2d(64, 64, 3, padding=1)
+            self.bn_decode2 = nn.BatchNorm2d(64)
+            self.upconv3 = nn.ConvTranspose2d(64, 48, 2, 2)
+            self.bn_upconv3 = nn.BatchNorm2d(48)
+            self.conv_decode3 = nn.Conv2d(48, 48, 3, padding=1)
+            self.bn_decode3 = nn.BatchNorm2d(48)
+            self.upconv4 = nn.ConvTranspose2d(48, 32, 2, 2)
+            self.bn_upconv4 = nn.BatchNorm2d(32)
+            self.conv_decode4 = nn.Conv2d(32, output_channels, 3, padding=1)
+
+    def forward(self, representation):
+        # batch_size=representation.shape[0]
+        if self.num_classes is None:
+            # x2 = self.conv_decode_res(representation)
+            # x2 = self.bn_conv_decode_res(x2)
+            # x2 = interpolate(x2,size=(256,256))
+
+            x = self.upconv0(representation)
+            x = self.bn_upconv0(x)
+            x = self.relu(x)
+            x = self.conv_decode0(x)
+            x = self.bn_decode0(x)
+            x = self.relu(x)
+
+            x = self.upconv1(x)
+            x = self.bn_upconv1(x)
+            x = self.relu(x)
+            x = self.conv_decode1(x)
+            x = self.bn_decode1(x)
+            x = self.relu(x)
+            x = self.upconv2(x)
+            x = self.bn_upconv2(x)
+            x = self.relu(x)
+            x = self.conv_decode2(x)
+
+            x = self.bn_decode2(x)
+            x = self.relu(x)
+            x = self.upconv3(x)
+            x = self.bn_upconv3(x)
+            x = self.relu(x)
+            x = self.conv_decode3(x)
+            x = self.bn_decode3(x)
+            x = self.relu(x)
+            x = self.upconv4(x)
+            x = self.bn_upconv4(x)
+            # x = torch.cat([x,x2],1)
+            # print(x.shape,self.static.shape)
+            # x = torch.cat([x,x2,input,self.static.expand(batch_size,-1,-1,-1)],1)
+            x = self.relu(x)
+            x = self.conv_decode4(x)
+
+            # z = x[:,19:22,:,:].clone()
+            # y = (z).norm(2,1,True).clamp(min=1e-12)
+            # print(y.shape,x[:,21:24,:,:].shape)
+            # x[:,19:22,:,:]=z/y
+
+        else:
+
+            x = F.adaptive_avg_pool2d(x, (1, 1))
+            x = x.view(x.size(0), -1)
+            x = self.fc(x)
+        return x
+
+
+class ResNet(BaseModel):
+    def __init__(self, block, layers, tasks=None, num_classes=None, ozan=False, size=1, **kwargs):
+        super(ResNet, self).__init__()
+        if size == 1:
+            self.encoder = ResNetEncoder(block, layers, **kwargs)
+        elif size == 2:
+            self.encoder = ResNetEncoder(block, layers, [96, 192, 384, 720], **kwargs)
+        elif size == 3:
+            self.encoder = ResNetEncoder(block, layers, [112, 224, 448, 880], **kwargs)
+        elif size == 0.5:
+            self.encoder = ResNetEncoder(block, layers, [48, 96, 192, 360], **kwargs)
+        self.tasks = tasks
+        self.ozan = ozan
+        self.task_to_decoder = {}
+
+        if tasks is not None:
+            # self.final_conv = nn.Conv2d(728,512,3,1,1)
+            # self.final_conv_bn = nn.BatchNorm2d(512)
+            for task in tasks:
+                if task == 'segment_semantic':
+                    output_channels = 18
+                if task == 'depth_zbuffer':
+                    output_channels = 1
+                if task == 'normal':
+                    output_channels = 3
+                if task == 'edge_occlusion':
+                    output_channels = 1
+                if task == 'reshading':
+                    output_channels = 3
+                if task == 'keypoints2d':
+                    output_channels = 1
+                if task == 'edge_texture':
+                    output_channels = 1
+                if size == 1:
+                    decoder = Decoder(output_channels)
+                elif size == 2:
+                    decoder = Decoder(output_channels, base_match=720)
+                elif size == 3:
+                    decoder = Decoder(output_channels, base_match=880)
+                elif size == 0.5:
+                    decoder = Decoder(output_channels, base_match=360)
+                self.task_to_decoder[task] = decoder
+        else:
+            self.task_to_decoder['classification'] = Decoder(output_channels=0, num_classes=1000)
+
+        self.decoders = nn.ModuleList(self.task_to_decoder.values())
+
+        # ------- init weights --------
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+        # -----------------------------
+
+    def forward(self, input):
+        rep = self.encoder(input)
+
+        if self.tasks is None:
+            return self.decoders[0](rep)
+
+        # rep = self.final_conv(rep)
+        # rep = self.final_conv_bn(rep)
+
+        outputs = {'rep': rep}
+        if self.ozan:
+            OzanRepFunction.n = len(self.decoders)
+            rep = ozan_rep_function(rep)
+            for i, (task, decoder) in enumerate(zip(self.task_to_decoder.keys(), self.decoders)):
+                outputs[task] = decoder(rep[i])
+        else:
+            TrevorRepFunction.n = len(self.decoders)
+            rep = trevor_rep_function(rep)
+            for i, (task, decoder) in enumerate(zip(self.task_to_decoder.keys(), self.decoders)):
+                outputs[task] = decoder(rep)
+
+        return outputs
+
+
+def _resnet(arch, block, layers, pretrained, **kwargs):
+    model = ResNet(block=block, layers=layers, **kwargs)
+    # if pretrained:
+    #     state_dict = load_state_dict_from_url(model_urls[arch],
+    #                                           progress=progress)
+    #     model.load_state_dict(state_dict)
+    return model
+
+
+def resnet18(pretrained=False, **kwargs):
+    """Constructs a ResNet-18 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained,
+                   **kwargs)
+
+
+def resnet18_tripple(pretrained=False, **kwargs):
+    """Constructs a ResNet-18 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, size=3,
+                   **kwargs)
+
+
+def resnet18_half(pretrained=False, **kwargs):
+    """Constructs a ResNet-18 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, size=0.5,
+                   **kwargs)
+
+
+def resnet34(pretrained=False, **kwargs):
+    """Constructs a ResNet-34 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained,
+                   **kwargs)
+
+
+def resnet50(pretrained=False, **kwargs):
+    """Constructs a ResNet-50 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained,
+                   **kwargs)
+
+
+def resnet101(pretrained=False, **kwargs):
+    """Constructs a ResNet-101 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained,
+                   **kwargs)
+
+
+def resnet152(pretrained=False, **kwargs):
+    """Constructs a ResNet-152 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained,
+                   **kwargs)

+ 821 - 0
applications/mas/models/xception.py

@@ -0,0 +1,821 @@
+""" 
+Creates an Xception Model as defined in:
+Xception: Deep Learning with Depthwise Separable Convolutions, https://arxiv.org/pdf/1610.02357.pdf
+"""
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from easyfl.models.model import BaseModel
+from .ozan_rep_fun import ozan_rep_function, OzanRepFunction, gradnorm_rep_function, GradNormRepFunction, \
+    trevor_rep_function, TrevorRepFunction
+
+__all__ = ['xception',
+           'xception_gradnorm',
+           'xception_half_gradnorm',
+           'xception_ozan',
+           'xception_half',
+           'xception_quad',
+           'xception_double',
+           'xception_double_ozan',
+           'xception_half_ozan',
+           'xception_quad_ozan']
+
+
+# model_urls = {
+#     'xception_taskonomy':'file:///home/tstand/Dropbox/taskonomy/xception_taskonomy-a4b32ef7.pth.tar'
+# }
+
+
+class SeparableConv2d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False,
+                 groupsize=1):
+        super(SeparableConv2d, self).__init__()
+
+        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation,
+                               groups=max(1, in_channels // groupsize), bias=bias)
+        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.pointwise(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
+        super(Block, self).__init__()
+
+        if out_filters != in_filters or strides != 1:
+            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
+            self.skipbn = nn.BatchNorm2d(out_filters)
+        else:
+            self.skip = None
+
+        self.relu = nn.ReLU(inplace=True)
+        rep = []
+
+        filters = in_filters
+        if grow_first:
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
+            rep.append(nn.BatchNorm2d(out_filters))
+            filters = out_filters
+
+        for i in range(reps - 1):
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False))
+            rep.append(nn.BatchNorm2d(filters))
+
+        if not grow_first:
+            rep.append(self.relu)
+            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
+            rep.append(nn.BatchNorm2d(out_filters))
+            filters = out_filters
+
+        if not start_with_relu:
+            rep = rep[1:]
+        else:
+            rep[0] = nn.ReLU(inplace=False)
+
+        if strides != 1:
+            # rep.append(nn.AvgPool2d(3,strides,1))
+            rep.append(nn.Conv2d(filters, filters, 2, 2))
+        self.rep = nn.Sequential(*rep)
+
+    def forward(self, inp):
+        x = self.rep(inp)
+
+        if self.skip is not None:
+            skip = self.skip(inp)
+            skip = self.skipbn(skip)
+        else:
+            skip = inp
+        x += skip
+        return x
+
+
+class Encoder(nn.Module):
+    def __init__(self):
+        super(Encoder, self).__init__()
+        self.conv1 = nn.Conv2d(3, 24, 3, 2, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(24)
+        self.relu = nn.ReLU(inplace=True)
+        self.relu2 = nn.ReLU(inplace=False)
+
+        self.conv2 = nn.Conv2d(24, 48, 3, 1, 1, bias=False)
+        self.bn2 = nn.BatchNorm2d(48)
+        # do relu here
+
+        self.block1 = Block(48, 96, 2, 2, start_with_relu=False, grow_first=True)
+        self.block2 = Block(96, 192, 2, 2, start_with_relu=True, grow_first=True)
+        self.block3 = Block(192, 512, 2, 2, start_with_relu=True, grow_first=True)
+
+        # self.block4=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+        # self.block5=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+        # self.block6=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+        # self.block7=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+
+        self.block8 = Block(512, 512, 2, 1, start_with_relu=True, grow_first=True)
+        self.block9 = Block(512, 512, 2, 1, start_with_relu=True, grow_first=True)
+        self.block10 = Block(512, 512, 2, 1, start_with_relu=True, grow_first=True)
+        self.block11 = Block(512, 512, 2, 1, start_with_relu=True, grow_first=True)
+
+        # self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)
+
+        self.conv3 = SeparableConv2d(512, 256, 3, 1, 1)
+        self.bn3 = nn.BatchNorm2d(256)
+        # self.conv3 = SeparableConv2d(1024,1536,3,1,1)
+        # self.bn3 = nn.BatchNorm2d(1536)
+
+        # do relu here
+        # self.conv4 = SeparableConv2d(1536,2048,3,1,1)
+        # self.bn4 = nn.BatchNorm2d(2048)
+
+    def forward(self, input):
+        x = self.conv1(input)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.block3(x)
+        # x = self.block4(x)
+        # x = self.block5(x)
+        # x = self.block6(x)
+        # x = self.block7(x)
+        x = self.block8(x)
+        x = self.block9(x)
+        x = self.block10(x)
+        x = self.block11(x)
+        # x = self.block12(x)
+
+        x = self.conv3(x)
+        x = self.bn3(x)
+        # x = self.relu(x)
+
+        # x = self.conv4(x)
+        # x = self.bn4(x)
+
+        representation = self.relu2(x)
+
+        return representation
+
+
+class EncoderHalf(nn.Module):
+    def __init__(self):
+        super(EncoderHalf, self).__init__()
+        self.conv1 = nn.Conv2d(3, 24, 3, 2, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(24)
+        self.relu = nn.ReLU(inplace=True)
+        self.relu2 = nn.ReLU(inplace=False)
+
+        self.conv2 = nn.Conv2d(24, 48, 3, 1, 1, bias=False)
+        self.bn2 = nn.BatchNorm2d(48)
+        # do relu here
+
+        self.block1 = Block(48, 64, 2, 2, start_with_relu=False, grow_first=True)
+        self.block2 = Block(64, 128, 2, 2, start_with_relu=True, grow_first=True)
+        self.block3 = Block(128, 360, 2, 2, start_with_relu=True, grow_first=True)
+
+        # self.block4=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+        # self.block5=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+        # self.block6=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+        # self.block7=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+
+        self.block8 = Block(360, 360, 2, 1, start_with_relu=True, grow_first=True)
+        self.block9 = Block(360, 360, 2, 1, start_with_relu=True, grow_first=True)
+        self.block10 = Block(360, 360, 2, 1, start_with_relu=True, grow_first=True)
+        self.block11 = Block(360, 360, 2, 1, start_with_relu=True, grow_first=True)
+
+        # self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)
+
+        self.conv3 = SeparableConv2d(360, 256, 3, 1, 1)
+        self.bn3 = nn.BatchNorm2d(256)
+        # self.conv3 = SeparableConv2d(1024,1536,3,1,1)
+        # self.bn3 = nn.BatchNorm2d(1536)
+
+        # do relu here
+        # self.conv4 = SeparableConv2d(1536,2048,3,1,1)
+        # self.bn4 = nn.BatchNorm2d(2048)
+
+    def forward(self, input):
+        x = self.conv1(input)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.block3(x)
+        # x = self.block4(x)
+        # x = self.block5(x)
+        # x = self.block6(x)
+        # x = self.block7(x)
+        x = self.block8(x)
+        x = self.block9(x)
+        x = self.block10(x)
+        x = self.block11(x)
+        # x = self.block12(x)
+
+        x = self.conv3(x)
+        x = self.bn3(x)
+        # x = self.relu(x)
+
+        # x = self.conv4(x)
+        # x = self.bn4(x)
+
+        representation = self.relu2(x)
+
+        return representation
+
+
+class EncoderQuad(nn.Module):
+    def __init__(self):
+        super(EncoderQuad, self).__init__()
+        print('entering quad constructor')
+        self.conv1 = nn.Conv2d(3, 48, 3, 2, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(48)
+        self.relu = nn.ReLU(inplace=True)
+        self.relu2 = nn.ReLU(inplace=False)
+
+        self.conv2 = nn.Conv2d(48, 96, 3, 1, 1, bias=False)
+        self.bn2 = nn.BatchNorm2d(96)
+        # do relu here
+
+        self.block1 = Block(96, 192, 2, 2, start_with_relu=False, grow_first=True)
+        self.block2 = Block(192, 384, 2, 2, start_with_relu=True, grow_first=True)
+        self.block3 = Block(384, 1024, 2, 2, start_with_relu=True, grow_first=True)
+
+        # self.block4=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+        # self.block5=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+        # self.block6=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+        # self.block7=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+
+        self.block8 = Block(1024, 1024, 2, 1, start_with_relu=True, grow_first=True)
+        self.block9 = Block(1024, 1024, 2, 1, start_with_relu=True, grow_first=True)
+        self.block10 = Block(1024, 1024, 2, 1, start_with_relu=True, grow_first=True)
+        self.block11 = Block(1024, 1024, 2, 1, start_with_relu=True, grow_first=True)
+
+        # self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)
+
+        self.conv3 = SeparableConv2d(1024, 256, 3, 1, 1)
+        self.bn3 = nn.BatchNorm2d(256)
+        # self.conv3 = SeparableConv2d(1024,1536,3,1,1)
+        # self.bn3 = nn.BatchNorm2d(1536)
+
+        # do relu here
+        # self.conv4 = SeparableConv2d(1536,2048,3,1,1)
+        # self.bn4 = nn.BatchNorm2d(2048)
+
+    def forward(self, input):
+        x = self.conv1(input)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.block3(x)
+        # x = self.block4(x)
+        # x = self.block5(x)
+        # x = self.block6(x)
+        # x = self.block7(x)
+        x = self.block8(x)
+        x = self.block9(x)
+        x = self.block10(x)
+        x = self.block11(x)
+        # x = self.block12(x)
+
+        x = self.conv3(x)
+        x = self.bn3(x)
+        # x = self.relu(x)
+
+        # x = self.conv4(x)
+        # x = self.bn4(x)
+
+        representation = self.relu2(x)
+
+        return representation
+
+
+class EncoderDouble(nn.Module):
+    def __init__(self):
+        super(EncoderDouble, self).__init__()
+        print('entering double constructor')
+        self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(32)
+        self.relu = nn.ReLU(inplace=True)
+        self.relu2 = nn.ReLU(inplace=False)
+
+        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
+        self.bn2 = nn.BatchNorm2d(64)
+        # do relu here
+
+        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
+        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
+        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)
+
+        # self.block4=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+        # self.block5=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+        # self.block6=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+        # self.block7=Block(768,768,3,1,start_with_relu=True,grow_first=True)
+
+        self.block8 = Block(728, 728, 2, 1, start_with_relu=True, grow_first=True)
+        self.block9 = Block(728, 728, 2, 1, start_with_relu=True, grow_first=True)
+        self.block10 = Block(728, 728, 2, 1, start_with_relu=True, grow_first=True)
+        self.block11 = Block(728, 728, 2, 1, start_with_relu=True, grow_first=True)
+
+        # self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)
+
+        self.conv3 = SeparableConv2d(728, 256, 3, 1, 1)
+        self.bn3 = nn.BatchNorm2d(256)
+        # self.conv3 = SeparableConv2d(1024,1536,3,1,1)
+        # self.bn3 = nn.BatchNorm2d(1536)
+
+        # do relu here
+        # self.conv4 = SeparableConv2d(1536,2048,3,1,1)
+        # self.bn4 = nn.BatchNorm2d(2048)
+
+    def forward(self, input):
+        x = self.conv1(input)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.block3(x)
+        # x = self.block4(x)
+        # x = self.block5(x)
+        # x = self.block6(x)
+        # x = self.block7(x)
+        x = self.block8(x)
+        x = self.block9(x)
+        x = self.block10(x)
+        x = self.block11(x)
+        # x = self.block12(x)
+
+        x = self.conv3(x)
+        x = self.bn3(x)
+        # x = self.relu(x)
+
+        # x = self.conv4(x)
+        # x = self.bn4(x)
+
+        representation = self.relu2(x)
+
+        return representation
+
+
+def interpolate(inp, size):
+    t = inp.type()
+    inp = inp.float()
+    out = nn.functional.interpolate(inp, size=size, mode='bilinear', align_corners=False)
+    if out.type() != t:
+        out = out.half()
+    return out
+
+
+class Decoder(nn.Module):
+    def __init__(self, output_channels=32, num_classes=None):
+        super(Decoder, self).__init__()
+
+        self.output_channels = output_channels
+        self.num_classes = num_classes
+
+        if num_classes is not None:
+            self.fc = nn.Linear(256, num_classes)
+        # else:
+        #    self.fc = nn.Linear(256, 1000)
+        else:
+            self.relu = nn.ReLU(inplace=True)
+
+            self.conv_decode_res = SeparableConv2d(256, 16, 3, padding=1)
+            self.conv_decode_res2 = SeparableConv2d(256, 96, 3, padding=1)
+            self.bn_conv_decode_res = nn.BatchNorm2d(16)
+            self.bn_conv_decode_res2 = nn.BatchNorm2d(96)
+            self.upconv1 = nn.ConvTranspose2d(96, 96, 2, 2)
+            self.bn_upconv1 = nn.BatchNorm2d(96)
+            self.conv_decode1 = SeparableConv2d(96, 64, 3, padding=1)
+            self.bn_decode1 = nn.BatchNorm2d(64)
+            self.upconv2 = nn.ConvTranspose2d(64, 64, 2, 2)
+            self.bn_upconv2 = nn.BatchNorm2d(64)
+            self.conv_decode2 = SeparableConv2d(64, 64, 5, padding=2)
+            self.bn_decode2 = nn.BatchNorm2d(64)
+            self.upconv3 = nn.ConvTranspose2d(64, 32, 2, 2)
+            self.bn_upconv3 = nn.BatchNorm2d(32)
+            self.conv_decode3 = SeparableConv2d(32, 32, 5, padding=2)
+            self.bn_decode3 = nn.BatchNorm2d(32)
+            self.upconv4 = nn.ConvTranspose2d(32, 32, 2, 2)
+            self.bn_upconv4 = nn.BatchNorm2d(32)
+            self.conv_decode4 = SeparableConv2d(48, output_channels, 5, padding=2)
+
+    def forward(self, representation):
+        # batch_size=representation.shape[0]
+        if self.num_classes is None:
+            x2 = self.conv_decode_res(representation)
+            x2 = self.bn_conv_decode_res(x2)
+            x2 = interpolate(x2, size=(256, 256))
+            x = self.conv_decode_res2(representation)
+            x = self.bn_conv_decode_res2(x)
+            x = self.upconv1(x)
+            x = self.bn_upconv1(x)
+            x = self.relu(x)
+            x = self.conv_decode1(x)
+            x = self.bn_decode1(x)
+            x = self.relu(x)
+            x = self.upconv2(x)
+            x = self.bn_upconv2(x)
+            x = self.relu(x)
+            x = self.conv_decode2(x)
+
+            x = self.bn_decode2(x)
+            x = self.relu(x)
+            x = self.upconv3(x)
+            x = self.bn_upconv3(x)
+            x = self.relu(x)
+            x = self.conv_decode3(x)
+            x = self.bn_decode3(x)
+            x = self.relu(x)
+            x = self.upconv4(x)
+            x = self.bn_upconv4(x)
+            x = torch.cat([x, x2], 1)
+            # print(x.shape,self.static.shape)
+            # x = torch.cat([x,x2,input,self.static.expand(batch_size,-1,-1,-1)],1)
+            x = self.relu(x)
+            x = self.conv_decode4(x)
+
+            # z = x[:,19:22,:,:].clone()
+            # y = (z).norm(2,1,True).clamp(min=1e-12)
+            # print(y.shape,x[:,21:24,:,:].shape)
+            # x[:,19:22,:,:]=z/y
+
+        else:
+            # print(representation.shape)
+            x = F.adaptive_avg_pool2d(representation, (1, 1))
+            x = x.view(x.size(0), -1)
+            # print(x.shape)
+            x = self.fc(x)
+            # print(x.shape)
+        return x
+
+
+class Xception(BaseModel):
+    """
+    Xception optimized for the ImageNet dataset, as specified in
+    https://arxiv.org/pdf/1610.02357.pdf
+    """
+
+    def __init__(self, tasks=None, num_classes=None, ozan=False, half=False):
+        """ Constructor
+        Args:
+            num_classes: number of classes
+        """
+        super(Xception, self).__init__()
+        print('half is', half)
+        if half == 'Quad':
+            print('running quad code')
+            self.encoder = EncoderQuad()
+        elif half == 'Double':
+            self.encoder = EncoderDouble()
+        elif half:
+            self.encoder = EncoderHalf()
+        else:
+            self.encoder = Encoder()
+        self.tasks = tasks
+        self.ozan = ozan
+        self.task_to_decoder = {}
+        self.task_to_output_channels = {
+            'segment_semantic': 18,
+            'depth_zbuffer': 1,
+            'normal': 3,
+            'normal2': 3,
+            'edge_occlusion': 1,
+            'reshading': 1,
+            'keypoints2d': 1,
+            'edge_texture': 1,
+            'principal_curvature': 2,
+            'rgb': 3,
+        }
+        if tasks is not None:
+            for task in tasks:
+                output_channels = self.task_to_output_channels[task]
+                decoder = Decoder(output_channels, num_classes)
+                self.task_to_decoder[task] = decoder
+        else:
+            self.task_to_decoder['classification'] = Decoder(output_channels=0, num_classes=1000)
+
+        self.decoders = nn.ModuleList(self.task_to_decoder.values())
+
+        # ------- init weights --------
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+        # -----------------------------
+
+    count = 0
+
+    def input_per_task_losses(self, losses):
+        # if GradNormRepFunction.inital_task_losses is None:
+        #     GradNormRepFunction.inital_task_losses=losses
+        #     GradNormRepFunction.current_weights=[1 for i in losses]
+        Xception.count += 1
+        if Xception.count < 200:
+            GradNormRepFunction.inital_task_losses = losses
+            GradNormRepFunction.current_weights = [1 for i in losses]
+        elif Xception.count % 20 == 0:
+            with open("gradnorm_weights.txt", "a") as myfile:
+                myfile.write(str(Xception.count) + ': ' + str(GradNormRepFunction.current_weights) + '\n')
+        GradNormRepFunction.current_task_losses = losses
+
+    def forward(self, input):
+        rep = self.encoder(input)
+
+        if self.tasks is None:
+            return self.decoders[0](rep)
+
+        outputs = {'rep': rep}
+        if self.ozan == 'gradnorm':
+            GradNormRepFunction.n = len(self.decoders)
+            rep = gradnorm_rep_function(rep)
+            for i, (task, decoder) in enumerate(zip(self.task_to_decoder.keys(), self.decoders)):
+                outputs[task] = decoder(rep[i])
+        elif self.ozan:
+            OzanRepFunction.n = len(self.decoders)
+            rep = ozan_rep_function(rep)
+            for i, (task, decoder) in enumerate(zip(self.task_to_decoder.keys(), self.decoders)):
+                outputs[task] = decoder(rep[i])
+        else:
+            TrevorRepFunction.n = len(self.decoders)
+            rep = trevor_rep_function(rep)
+            for i, (task, decoder) in enumerate(zip(self.task_to_decoder.keys(), self.decoders)):
+                outputs[task] = decoder(rep)
+            # Original loss
+            # for i, (task, decoder) in enumerate(zip(self.task_to_decoder.keys(), self.decoders)):
+            #     outputs[task] = decoder(rep)
+
+        return outputs
+
+
+def xception(pretrained=False, **kwargs):
+    """
+    Construct Xception.
+    """
+    model = Xception(**kwargs)
+
+    if pretrained:
+        # state_dict = model_zoo.load_url(model_urls['xception_taskonomy'])
+        # for name,weight in state_dict.items():
+        #     if 'pointwise' in name:
+        #         state_dict[name]=weight.unsqueeze(-1).unsqueeze(-1)
+        #     if 'conv1' in name and len(weight.shape)!=4:
+        #         state_dict[name]=weight.unsqueeze(1)
+        # model.load_state_dict(state_dict)
+        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
+        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
+        # if num_classes !=1000:
+        #     model.fc = nn.Linear(2048, num_classes)
+        # import torch
+        # print("writing new state dict")
+        # torch.save(model.state_dict(),"xception.pth.tar")
+        # print("done")
+        # import sys
+        # sys.exit(1)
+
+    return model
+
+
+def xception_ozan(pretrained=False, **kwargs):
+    """
+    Construct Xception.
+    """
+
+    model = Xception(ozan=True, **kwargs)
+
+    if pretrained:
+        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
+        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
+
+    return model
+
+
+def xception_gradnorm(pretrained=False, **kwargs):
+    """
+    Construct Xception.
+    """
+
+    model = Xception(ozan='gradnorm', **kwargs)
+
+    if pretrained:
+        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
+        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
+
+    return model
+
+def xception_half_gradnorm(pretrained=False, **kwargs):
+    """
+    Construct Xception.
+    """
+
+    model = Xception(half=True, ozan='gradnorm', **kwargs)
+
+    return model
+
+def xception_half(pretrained=False, **kwargs):
+    """
+    Construct Xception.
+    """
+    # try:
+    #     num_classes = kwargs['num_classes']
+    # except:
+    #     num_classes=1000
+    # if pretrained:
+    #     kwargs['num_classes']=1000
+    model = Xception(half=True, **kwargs)
+
+    if pretrained:
+        # state_dict = model_zoo.load_url(model_urls['xception_taskonomy'])
+        # for name,weight in state_dict.items():
+        #     if 'pointwise' in name:
+        #         state_dict[name]=weight.unsqueeze(-1).unsqueeze(-1)
+        #     if 'conv1' in name and len(weight.shape)!=4:
+        #         state_dict[name]=weight.unsqueeze(1)
+        # model.load_state_dict(state_dict)
+        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
+        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
+        # if num_classes !=1000:
+        #     model.fc = nn.Linear(2048, num_classes)
+        # import torch
+        # print("writing new state dict")
+        # torch.save(model.state_dict(),"xception.pth.tar")
+        # print("done")
+        # import sys
+        # sys.exit(1)
+
+    return model
+
+
+def xception_quad(pretrained=False, **kwargs):
+    """
+    Construct Xception.
+    """
+    # try:
+    #     num_classes = kwargs['num_classes']
+    # except:
+    #     num_classes=1000
+    # if pretrained:
+    #     kwargs['num_classes']=1000
+    print('got quad')
+    model = Xception(half='Quad', **kwargs)
+
+    if pretrained:
+        # state_dict = model_zoo.load_url(model_urls['xception_taskonomy'])
+        # for name,weight in state_dict.items():
+        #     if 'pointwise' in name:
+        #         state_dict[name]=weight.unsqueeze(-1).unsqueeze(-1)
+        #     if 'conv1' in name and len(weight.shape)!=4:
+        #         state_dict[name]=weight.unsqueeze(1)
+        # model.load_state_dict(state_dict)
+        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
+        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
+        # if num_classes !=1000:
+        #     model.fc = nn.Linear(2048, num_classes)
+        # import torch
+        # print("writing new state dict")
+        # torch.save(model.state_dict(),"xception.pth.tar")
+        # print("done")
+        # import sys
+        # sys.exit(1)
+
+    return model
+
+
+def xception_double(pretrained=False, **kwargs):
+    """
+    Construct Xception.
+    """
+    # try:
+    #     num_classes = kwargs['num_classes']
+    # except:
+    #     num_classes=1000
+    # if pretrained:
+    #     kwargs['num_classes']=1000
+    print('got double')
+    model = Xception(half='Double', **kwargs)
+
+    if pretrained:
+        # state_dict = model_zoo.load_url(model_urls['xception_taskonomy'])
+        # for name,weight in state_dict.items():
+        #     if 'pointwise' in name:
+        #         state_dict[name]=weight.unsqueeze(-1).unsqueeze(-1)
+        #     if 'conv1' in name and len(weight.shape)!=4:
+        #         state_dict[name]=weight.unsqueeze(1)
+        # model.load_state_dict(state_dict)
+        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
+        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
+        # if num_classes !=1000:
+        #     model.fc = nn.Linear(2048, num_classes)
+        # import torch
+        # print("writing new state dict")
+        # torch.save(model.state_dict(),"xception.pth.tar")
+        # print("done")
+        # import sys
+        # sys.exit(1)
+
+    return model
+
+
+def xception_quad_ozan(pretrained=False, **kwargs):
+    """
+    Construct Xception.
+    """
+    # try:
+    #     num_classes = kwargs['num_classes']
+    # except:
+    #     num_classes=1000
+    # if pretrained:
+    #     kwargs['num_classes']=1000
+    print('got quad ozan')
+    model = Xception(ozan=True, half='Quad', **kwargs)
+
+    if pretrained:
+        # state_dict = model_zoo.load_url(model_urls['xception_taskonomy'])
+        # for name,weight in state_dict.items():
+        #     if 'pointwise' in name:
+        #         state_dict[name]=weight.unsqueeze(-1).unsqueeze(-1)
+        #     if 'conv1' in name and len(weight.shape)!=4:
+        #         state_dict[name]=weight.unsqueeze(1)
+        # model.load_state_dict(state_dict)
+        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
+        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
+        # if num_classes !=1000:
+        #     model.fc = nn.Linear(2048, num_classes)
+        # import torch
+        # print("writing new state dict")
+        # torch.save(model.state_dict(),"xception.pth.tar")
+        # print("done")
+        # import sys
+        # sys.exit(1)
+
+    return model
+
+
+def xception_double_ozan(pretrained=False, **kwargs):
+    """
+    Construct Xception.
+    """
+    # try:
+    #     num_classes = kwargs['num_classes']
+    # except:
+    #     num_classes=1000
+    # if pretrained:
+    #     kwargs['num_classes']=1000
+    print('got double')
+    model = Xception(ozan=True, half='Double', **kwargs)
+
+    if pretrained:
+        # state_dict = model_zoo.load_url(model_urls['xception_taskonomy'])
+        # for name,weight in state_dict.items():
+        #     if 'pointwise' in name:
+        #         state_dict[name]=weight.unsqueeze(-1).unsqueeze(-1)
+        #     if 'conv1' in name and len(weight.shape)!=4:
+        #         state_dict[name]=weight.unsqueeze(1)
+        # model.load_state_dict(state_dict)
+        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
+        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
+        # if num_classes !=1000:
+        #     model.fc = nn.Linear(2048, num_classes)
+        # import torch
+        # print("writing new state dict")
+        # torch.save(model.state_dict(),"xception.pth.tar")
+        # print("done")
+        # import sys
+        # sys.exit(1)
+
+    return model
+
+
+def xception_half_ozan(pretrained=False, **kwargs):
+    """
+    Construct Xception.
+    """
+
+    model = Xception(ozan=True, half=True, **kwargs)
+
+    if pretrained:
+        # model.load_state_dict(torch.load('xception_taskonomy_small_imagenet_pretrained.pth.tar'))
+        model.encoder.load_state_dict(torch.load('xception_taskonomy_small2.encoder.pth.tar'))
+
+    return model

+ 207 - 0
applications/mas/network_selection.py

@@ -0,0 +1,207 @@
+import copy
+
+import numpy as np
+from sympy import symbols, Eq, solve, Rational
+
+
+def gen_task_combinations(affinity, tasks, rtn, index, path, path_dict, task_overlap):
+    if index >= len(tasks):
+        return
+
+    for i in range(index, len(tasks)):
+        cur_task = tasks[i]
+        new_path = path
+        new_dict = {k: v for k, v in path_dict.items()}
+
+        # Building from a tree with two or more tasks...
+        if new_path:
+            new_dict[cur_task] = 0.
+            for prev_task in path_dict:
+                new_dict[prev_task] += affinity[prev_task][cur_task]
+                new_dict[cur_task] += affinity[cur_task][prev_task]
+            new_path = '{}|{}'.format(new_path, cur_task)
+            rtn[new_path] = new_dict
+        else:  # First element in a new-formed tree
+            new_dict[cur_task] = 0.
+            new_path = cur_task
+
+        gen_task_combinations(affinity, tasks, rtn, i + 1, new_path, new_dict, task_overlap)
+
+        if '|' not in new_path:
+            if task_overlap:
+                new_dict[cur_task] = -1e6
+            else:
+                new_dict[cur_task] = average_of_self_to_others_and_others_to_self(cur_task, affinity)
+
+            rtn[new_path] = new_dict
+
+
+def average_of_self_to_others(cur_task, affinity):
+    scores = [score for task, score in affinity[cur_task].items() if task != cur_task]
+    return sum(scores) / len(scores)
+
+
+def average_of_others_to_self(cur_task, affinity):
+    scores = [score for source_task, a in affinity.items() for target_task, score in a.items()
+              if source_task != cur_task and target_task == cur_task]
+    return sum(scores) / len(scores)
+
+
+def average_of_self_to_others_and_others_to_self(cur_task, affinity):
+    scores1 = [score for task, score in affinity[cur_task].items() if task != cur_task]
+    scores2 = [score for source_task, a in affinity.items() for target_task, score in a.items()
+               if source_task != cur_task and target_task == cur_task]
+    return (sum(scores1) + sum(scores2)) / (len(scores1) + len(scores2))
+
+
+def select_groups(affinity, rtn_tup, index, cur_group, best_group, best_val, splits, task_overlap=True):
+    # Check if this group covers all tasks.
+    num_tasks = len(affinity.keys())
+    if task_overlap:
+        task_set = set()
+        for group in cur_group:
+            for task in group.split('|'): task_set.add(task)
+    else:
+        task_set = list()
+        for group in cur_group:
+            for task in group.split('|'):
+                if task in task_set:
+                    return
+                else:
+                    task_set.append(task)
+    if len(task_set) == num_tasks:
+        best_tasks = {task: -1e6 for task in task_set}
+
+        # Compute the per-task best scores for each task and average them together.
+        for group in cur_group:
+            for task in cur_group[group]:
+                best_tasks[task] = max(best_tasks[task], cur_group[group][task])
+        group_avg = np.mean(list(best_tasks.values()))
+
+        # Compare with the best grouping seen thus far.
+        if group_avg > best_val[0]:
+            # print(cur_group)
+            if task_overlap or no_task_overlap(cur_group, num_tasks):
+                best_val[0] = group_avg
+                best_group.clear()
+                for entry in cur_group:
+                    best_group[entry] = cur_group[entry]
+
+    # Base case.
+    if len(cur_group.keys()) == splits:
+        return
+
+    # Back to combinatorics
+    for i in range(index, len(rtn_tup)):
+        selected_group, selected_dict = rtn_tup[i]
+
+        new_group = {k: v for k, v in cur_group.items()}
+        new_group[selected_group] = selected_dict
+
+        if len(new_group.keys()) <= splits:
+            select_groups(affinity, rtn_tup, i + 1, new_group, best_group, best_val, splits, task_overlap)
+
+
+def task_grouping(affinity, task_overlap=True, split=3):
+    tasks = list(affinity.keys())
+    rtn = {}
+
+    gen_task_combinations(affinity, tasks=tasks, rtn=rtn, index=0, path='', path_dict={}, task_overlap=task_overlap)
+
+    # Normalize by the number of times the accuracy of any given element has been summed.
+    # i.e. (a,b,c) => [acc(a|b) + acc(a|c)]/2 + [acc(b|a) + acc(b|c)]/2 + [acc(c|a) + acc(c|b)]/2
+    for group in rtn:
+        if '|' in group:
+            for task in rtn[group]:
+                rtn[group][task] /= (len(group.split('|')) - 1)
+
+    assert (len(rtn.keys()) == 2 ** len(affinity.keys()) - 1)
+    rtn_tup = [(key, val) for key, val in rtn.items()]
+
+    # if not task_overlap:
+    #     rtn_tup = calculate_self_affinity(affinity, rtn_tup)
+    selected_group = {}
+    selected_val = [-100000000]
+    select_groups(affinity, rtn_tup, index=0, cur_group={}, best_group=selected_group, best_val=selected_val,
+                  splits=split, task_overlap=task_overlap)
+    return list(selected_group.keys())
+
+
+def rtn_tup_to_dict(rtn_tup):
+    d = {}
+    for tup in rtn_tup:
+        d[tup[0]] = tup[1]
+    return d
+
+
+def rtn_dict_to_tup(rtn_dict):
+    rtn_tup = []
+    for key, value in rtn_dict.items():
+        rtn_tup.append((key, value))
+    return rtn_tup
+
+
+def calculate_self_affinity(affinity, rtn_tup):
+    rtn_dict = rtn_tup_to_dict(rtn_tup)
+
+    task_names = list(affinity.keys())
+    tasks = symbols(" ".join(task_names))
+    for i, t in enumerate(task_names):
+        rtn_dict[t] = tasks[i]
+
+    equations = []
+    for i, task in enumerate(task_names):
+        task_combs = [comb for comb in rtn_dict.keys() if task in comb]
+        count = len(task_combs) - 1
+        eq = Rational(0)
+        name1 = task + "|"
+        name2 = "|" + task
+        for comb in task_combs:
+            if comb == task:
+                eq -= count * rtn_dict[comb]
+                continue
+            sub_comb = comb.replace(name1, "") if name1 in comb else comb.replace(name2, "")
+            sub = rtn_dict[sub_comb] if "|" not in sub_comb else sum(rtn_dict[sub_comb].values())
+            eq += sum(rtn_dict[comb].values()) - sub
+        equations.append(Eq(eq, 0))
+    sol = solve(equations, tasks)
+    for i, t in enumerate(task_names):
+        rtn_dict[t] = {t: sol[tasks[i]]}
+
+    rtn_tup = rtn_dict_to_tup(rtn_dict)
+    return rtn_tup
+
+
+def no_task_overlap(group, num_tasks):
+    task_set = list()
+    for combination in group.keys():
+        for task in combination.split("|"):
+            if task not in task_set:
+                task_set.append(task)
+            else:
+                return False
+    return len(task_set) == num_tasks
+
+
+def average_task_affinity_among_clients(affinities):
+    result = copy.deepcopy(affinities[0])
+    for task, affinity in result.items():
+        for target_task, score in affinity.items():
+            total = score
+            for a in affinities[1:]:
+                total += a[task][target_task]
+            result[task][target_task] = total / len(affinities)
+    return result
+
+
+def run(affinities):
+    results = []
+    averaged_affinity = average_task_affinity_among_clients(affinities)
+    groups = task_grouping(averaged_affinity, task_overlap=True)
+    results.append(groups)
+    for i, a in enumerate(affinities):
+        print("client", i)
+        groups = task_grouping(a, task_overlap=True)
+        results.append(groups)
+    print(results)
+    return results

+ 80 - 0
applications/mas/scripts/check_folders.py

@@ -0,0 +1,80 @@
+import argparse
+import os
+
+folders = [
+    "allensville",
+    "beechwood",
+    "benevolence",
+    "coffeen",
+    "collierville",
+    "corozal",
+    "cosmos",
+    "darden",
+    "forkland",
+    "hanson",
+    "hiteman",
+    "ihlen",
+    "klickitat",
+    "lakeville",
+    "leonardo",
+    "lindenwood",
+    "markleeville",
+    "marstons",
+    "mcdade",
+    "merom",
+    "mifflinburg",
+    "muleshoe",
+    "newfields",
+    "noxapater",
+    "onaga",
+    "pinesdale",
+    "pomaria",
+    "ranchester",
+    "shelbyville",
+    "stockman",
+    "tolstoy",
+    "uvalda",
+]
+
+TASKS = {
+    's': 'segment_semantic',
+    'd': 'depth_zbuffer',
+    'n': 'normal',
+    'N': 'normal2',
+    'k': 'keypoints2d',
+    'e': 'edge_occlusion',
+    'r': 'reshading',
+    't': 'edge_texture',
+    'a': 'rgb',
+    'c': 'principal_curvature'
+}
+
+
+def parse_tasks(task_str):
+    tasks = []
+    for char in task_str:
+        tasks.append(TASKS[char])
+    return tasks
+
+
+def run():
+    parser = argparse.ArgumentParser(description='Extract')
+    parser.add_argument("--dir", type=str)
+    parser.add_argument('--tasks', type=str)
+
+    args = parser.parse_args()
+
+    tasks = parse_tasks(args.tasks)
+
+    for f in folders:
+        for t in tasks:
+            p = os.path.join(args.dir, t, f)
+            try:
+                print(f"{t}-{f}: {len(os.listdir(p))}")
+            except Exception as e:
+                print(e)
+        print()
+
+
+if __name__ == '__main__':
+    run()

+ 38 - 0
applications/mas/scripts/extract_data.py

@@ -0,0 +1,38 @@
+import argparse
+import os
+import shutil
+import tarfile
+
+
+def run():
+    parser = argparse.ArgumentParser(description='Extract')
+    parser.add_argument("--source", type=str)
+    parser.add_argument('--target', type=str)
+    parser.add_argument('--task', type=str)
+
+    args = parser.parse_args()
+
+    files = os.listdir(args.source)
+    for file in files:
+        if args.task in file:
+            print(f"Processing {file}")
+            try:
+                source_path = os.path.join(args.source, file)
+                target_path = os.path.join(args.target, args.task)
+                file_obj = tarfile.open(source_path, "r")
+                file_obj.extractall(target_path)
+                file_obj.close()
+                old_name = os.path.join(target_path, args.task)
+                place = file.replace(args.task, "").replace("_.tar", "")
+                new_name = os.path.join(target_path, place)
+                shutil.move(old_name, new_name)
+                print(f"Extracted {file}")
+            except Exception as e:
+                print()
+                print(f"Failed to extract {file}")
+                print(e)
+                print()
+
+
+if __name__ == '__main__':
+    run()

+ 69 - 0
applications/mas/scripts/rename_segment_semantic.py

@@ -0,0 +1,69 @@
+import argparse
+import os
+import shutil
+
+folders = [
+    "allensville",
+    "beechwood",
+    "benevolence",
+    "coffeen",
+    "collierville",
+    "corozal",
+    "cosmos",
+    "darden",
+    "forkland",
+    "hanson",
+    "hiteman",
+    "ihlen",
+    "klickitat",
+    "lakeville",
+    "leonardo",
+    "lindenwood",
+    "markleeville",
+    "marstons",
+    "mcdade",
+    "merom",
+    "mifflinburg",
+    "muleshoe",
+    "newfields",
+    "noxapater",
+    "onaga",
+    "pinesdale",
+    "pomaria",
+    "ranchester",
+    "shelbyville",
+    "stockman",
+    "tolstoy",
+    "uvalda",
+]
+
+TASKS = {
+    's': 'segment_semantic',
+}
+
+
+def parse_tasks(task_str):
+    tasks = []
+    for char in task_str:
+        tasks.append(TASKS[char])
+    return tasks
+
+
+def run():
+    parser = argparse.ArgumentParser(description='Extract')
+    parser.add_argument("--dir", type=str)
+
+    args = parser.parse_args()
+
+    for f in folders:
+        p = os.path.join(args.dir, "segment_semantic", f)
+        files = os.listdir(p)
+        for file in files:
+            if "segmentsemantic" in file:
+                old_file = os.path.join(p, file)
+                new_file = old_file.replace("segmentsemantic", "segment_semantic")
+                shutil.move(old_file, new_file)
+
+
+if __name__ == '__main__':
+    run()

+ 168 - 0
applications/mas/scripts/run.sh

@@ -0,0 +1,168 @@
+mkdir -p log
+mkdir -p log/mas
+now=$(date +"%Y%m%d_%H%M%S")
+
+root_dir=/mnt/lustre/$(whoami)
+project_dir=$root_dir/easyfl/applications/mas
+data_dir=$root_dir/datasets/taskonomy_datasets
+client_file=$project_dir/clients.txt
+
+export PYTHONPATH=$PYTHONPATH:${pwd}
+
+while [[ "$#" -gt 0 ]]; do
+    case $1 in
+        -p) partition="$2"; shift ;;
+        -t) tasks="$2"; shift ;;
+        -a) arch="$2"; shift ;;
+        -e) local_epoch="$2"; shift ;;
+        -k) clients_per_round="$2"; shift ;;
+        -b) batch_size="$2"; shift ;;
+        -r) rounds="$2"; shift ;;
+        -lr) lr="$2"; shift ;;
+        -lrt) lr_type="$2"; shift ;;
+        -te) test_every="$2"; shift ;;
+        -se) save_model_every="$2"; shift ;;
+        -gpus) gpus="$2"; shift ;;
+        -count) run_count="$2"; shift ;;
+        -port) dist_port="$2"; shift ;;
+        -tag) tag="$2"; shift ;;
+        -tag_step) tag_step="$2"; shift ;;
+        -what) what="$2"; shift ;;
+        -client_id) client_id="$2"; shift ;;
+        -agg_strategy) agg_strategy="$2"; shift ;;
+        -pretrained) pretrained="$2"; shift ;;
+        -pt) pretrained_tasks="$2"; shift ;;
+        -decoder) decoder="$2"; shift ;;
+        -half) half="$2"; shift ;;
+        *) echo "Unknown parameter passed: $1"; exit 1 ;;
+    esac
+    shift
+done
+
+if [ -z "${partition}" ]
+  then
+    partition=partition
+fi
+
+if [ -z "${tasks}" ]
+  then
+    tasks=""
+fi
+
+if [ -z "${arch}" ]
+  then
+    arch=xception # options: xception, resnet18
+fi
+
+if [ -z "${local_epoch}" ]
+  then
+    local_epoch=5
+fi
+
+if [ -z "${clients_per_round}" ]
+  then
+    clients_per_round=5
+fi
+
+if [ -z "${batch_size}" ]
+  then
+    batch_size=64
+fi
+
+if [ -z "${lr}" ]
+  then
+    lr=0.1
+fi
+
+if [ -z "${lr_type}" ]
+  then
+    lr_type=poly
+fi
+
+if [ -z "${rounds}" ]
+  then
+    rounds=100
+fi
+
+if [ -z "${test_every}" ]
+  then
+    test_every=1
+fi
+
+if [ -z "${save_model_every}" ]
+  then
+    save_model_every=1
+fi
+
+if [ -z "${gpus}" ]
+  then
+    gpus=1
+fi
+
+if [ -z "${dist_port}" ]
+  then
+    dist_port=23344
+fi
+
+# Whether use task affinity grouping (lookahead)
+if [ -z "${tag}" ]
+  then
+    tag='y'
+fi
+
+# Lookahead step
+if [ -z "${tag_step}" ]
+  then
+    tag_step=10
+fi
+
+if [ -z "${run_count}" ]
+  then
+    run_count=0
+fi
+
+if [ -z "${client_id}" ]
+  then
+    client_id='NA'
+fi
+
+if [ -z "${agg_strategy}" ]
+  then
+    agg_strategy='FedAvg'
+fi
+
+if [ -z "${pretrained_tasks}" ]
+  then
+    pretrained_tasks='sdnkt'
+fi
+
+use_pretrained='y'
+if [ -z "${pretrained}" ]
+  then
+    pretrained='n'
+    use_pretrained='n'
+    pretrained_tasks='n'
+fi
+
+if [ -z "${decoder}" ]
+  then
+    decoder='y'
+fi
+
+if [ -z "${half}" ]
+  then
+    half='n'
+fi
+
+job_name=mas-${tasks}-${arch}-b${batch_size}-${lr_type}lr${lr}-${agg_strategy}-tag-${tag}-${tag_step}-e${local_epoch}-n${clients_per_round}-r${rounds}-te${test_every}-se${save_model_every}-pretrained-${use_pretrained}-${pretrained_tasks}-${what}-${run_count}
+echo ${job_name}
+
+srun -u --partition=${partition} --job-name=${job_name} \
+    -n${gpus} --gres=gpu:${gpus} --ntasks-per-node=${gpus} \
+    python ${project_dir}/main.py --data_dir ${data_dir} --arch ${arch} --client_file ${client_file} \
+      --task_id ${job_name} --tasks ${tasks} --rotate_loss --batch_size ${batch_size} --lr ${lr} --lr_type ${lr_type} \
+      --local_epoch ${local_epoch} --clients_per_round ${clients_per_round} --rounds ${rounds} \
+      --test_every ${test_every} --save_model_every ${save_model_every} --random_selection --lookahead ${tag} --lookahead_step ${tag_step} \
+      --dist_port ${dist_port} --run_count ${run_count} --load_decoder ${decoder} --half ${half} \
+      --aggregation_strategy ${agg_strategy} --pretrained ${pretrained} --pretrained_tasks ${pretrained_tasks} \
+      --client_id ${client_id} 2>&1 | tee log/mas/${job_name}.log &

+ 168 - 0
applications/mas/scripts/run_pretrained.sh

@@ -0,0 +1,168 @@
+mkdir -p log
+mkdir -p log/mas
+now=$(date +"%Y%m%d_%H%M%S")
+
+root_dir=/mnt/lustre/$(whoami)
+project_dir=$root_dir/easyfl/applications/mas
+data_dir=$root_dir/datasets/taskonomy_datasets
+client_file=$project_dir/clients.txt
+
+export PYTHONPATH=$PYTHONPATH:${pwd}
+
+while [[ "$#" -gt 0 ]]; do
+    case $1 in
+        -p) partition="$2"; shift ;;
+        -t) tasks="$2"; shift ;;
+        -a) arch="$2"; shift ;;
+        -e) local_epoch="$2"; shift ;;
+        -k) clients_per_round="$2"; shift ;;
+        -b) batch_size="$2"; shift ;;
+        -r) rounds="$2"; shift ;;
+        -lr) lr="$2"; shift ;;
+        -lrt) lr_type="$2"; shift ;;
+        -te) test_every="$2"; shift ;;
+        -se) save_model_every="$2"; shift ;;
+        -gpus) gpus="$2"; shift ;;
+        -count) run_count="$2"; shift ;;
+        -port) dist_port="$2"; shift ;;
+        -tag) tag="$2"; shift ;;
+        -tag_step) tag_step="$2"; shift ;;
+        -what) what="$2"; shift ;;
+        -client_id) client_id="$2"; shift ;;
+        -agg_strategy) agg_strategy="$2"; shift ;;
+        -pretrained) pretrained="$2"; shift ;;
+        -pt) pretrained_tasks="$2"; shift ;;
+        -decoder) decoder="$2"; shift ;;
+        -half) half="$2"; shift ;;
+        *) echo "Unknown parameter passed: $1"; exit 1 ;;
+    esac
+    shift
+done
+
+if [ -z "${partition}" ]
+  then
+    partition=Sensetime
+fi
+
+if [ -z "${tasks}" ]
+  then
+    tasks=""
+fi
+
+if [ -z "${arch}" ]
+  then
+    arch=xception # options: xception, resnet18
+fi
+
+if [ -z "${local_epoch}" ]
+  then
+    local_epoch=5
+fi
+
+if [ -z "${clients_per_round}" ]
+  then
+    clients_per_round=5
+fi
+
+if [ -z "${batch_size}" ]
+  then
+    batch_size=64
+fi
+
+if [ -z "${lr}" ]
+  then
+    lr=0.1
+fi
+
+if [ -z "${lr_type}" ]
+  then
+    lr_type=poly
+fi
+
+if [ -z "${rounds}" ]
+  then
+    rounds=100
+fi
+
+if [ -z "${test_every}" ]
+  then
+    test_every=1
+fi
+
+if [ -z "${save_model_every}" ]
+  then
+    save_model_every=1
+fi
+
+if [ -z "${gpus}" ]
+  then
+    gpus=1
+fi
+
+if [ -z "${dist_port}" ]
+  then
+    dist_port=23344
+fi
+
+# Whether use task affinity grouping (lookahead)
+if [ -z "${tag}" ]
+  then
+    tag='y'
+fi
+
+# Lookahead step
+if [ -z "${tag_step}" ]
+  then
+    tag_step=10
+fi
+
+if [ -z "${run_count}" ]
+  then
+    run_count=0
+fi
+
+if [ -z "${client_id}" ]
+  then
+    client_id='NA'
+fi
+
+if [ -z "${agg_strategy}" ]
+  then
+    agg_strategy='FedAvg'
+fi
+
+if [ -z "${pretrained_tasks}" ]
+  then
+    pretrained_tasks='sdnkt'
+fi
+
+use_pretrained='y'
+if [ -z "${pretrained}" ]
+  then
+    pretrained='n'
+    use_pretrained='n'
+    pretrained_tasks='n'
+fi
+
+if [ -z "${decoder}" ]
+  then
+    decoder='y'
+fi
+
+if [ -z "${half}" ]
+  then
+    half='n'
+fi
+
+job_name=mtfl-${tasks}-${arch}-b${batch_size}-${lr_type}lr${lr}-${agg_strategy}-tag-${tag}-${tag_step}-e${local_epoch}-n${clients_per_round}-r${rounds}-te${test_every}-se${save_model_every}-pretrained-${use_pretrained}-${pretrained_tasks}-${what}-${run_count}
+echo ${job_name}
+
+srun -u --partition=${partition} --job-name=${job_name} \
+    -n${gpus} --gres=gpu:${gpus} --ntasks-per-node=${gpus} \
+    python ${project_dir}/main.py --data_dir ${data_dir} --arch ${arch} --client_file ${client_file} \
+      --task_id ${job_name} --tasks ${tasks} --rotate_loss --batch_size ${batch_size} --lr ${lr} --lr_type ${lr_type} \
+      --local_epoch ${local_epoch} --clients_per_round ${clients_per_round} --rounds ${rounds} \
+      --test_every ${test_every} --save_model_every ${save_model_every} --random_selection --lookahead ${tag} --lookahead_step ${tag_step} \
+      --dist_port ${dist_port} --run_count ${run_count} --load_decoder ${decoder} --half ${half} \
+      --aggregation_strategy ${agg_strategy} --pretrained ${pretrained} --pretrained_tasks ${pretrained_tasks} \
+      --client_id ${client_id} 2>&1 | tee log/mas/${pretrained_tasks}/${job_name}.log &

+ 195 - 0
applications/mas/server.py

@@ -0,0 +1,195 @@
+import copy
+import logging
+import os
+import shutil
+import time
+from collections import defaultdict
+
+import torch
+
+from dataset import DataPrefetcher
+from losses import get_losses
+from utils import AverageMeter
+from easyfl.distributed.distributed import CPU
+from easyfl.server.base import BaseServer, MODEL, DATA_SIZE
+from easyfl.tracking import metric
+
+logger = logging.getLogger(__name__)
+
+
+class MASServer(BaseServer):
+    def __init__(self, conf, test_data=None, val_data=None, is_remote=False, local_port=22999):
+        super(MASServer, self).__init__(conf, test_data, val_data, is_remote, local_port)
+        self.train_loader = None
+        self.test_loader = None
+
+        self._progress_table = []
+        self._stats = []
+        self._loss_history = []
+        self._current_loss = 9e9
+        self._best_loss = 9e9
+        self._best_model = None
+        self._client_models = []
+
+    def aggregation(self):
+        uploaded_content = self.get_client_uploads()
+        models = list(uploaded_content[MODEL].values())
+        weights = list(uploaded_content[DATA_SIZE].values())
+
+        # Cache client models for saving
+        self._client_models = [copy.deepcopy(m).cpu() for m in models]
+
+        # Aggregation
+        model = self.aggregate(models, weights)
+
+        self.set_model(model, load_dict=True)
+
+    def test_in_server(self, device=CPU):
+        # Validation
+        val_loader = self.val_data.loader(
+            batch_size=max(self.conf.server.batch_size // 2, 1),
+            shuffle=False,
+            seed=self.conf.seed)
+
+        test_results, stats, progress = self.test_fn(val_loader, self._model, device)
+        self._current_loss = float(stats['Loss'])
+        self._stats.append(stats)
+        self._loss_history.append(self._current_loss)
+        self._progress_table.append(progress)
+        logger.info(f"Validation statistics: {stats}")
+
+        # Test
+        if self._current_round == self.conf.server.rounds - 1:
+            test_loader = self.test_data.loader(
+                batch_size=max(self.conf.server.batch_size // 2, 1),
+                shuffle=False,
+                seed=self.conf.seed)
+            _, stats, progress_table = self.test_fn(test_loader, self._model, device)
+            logger.info(f"Testing statistics of last round: {stats}")
+
+            if self._current_loss <= self._best_loss:
+                logger.info(f"Last round {self._current_round} is the best round")
+            else:
+                _, stats, progress_table = self.test_fn(test_loader, self._best_model, device)
+                logger.info(f"Testing statistics of best model: {stats}")
+
+        return test_results
+
+    def test_fn(self, loader, model, device=CPU):
+        model.eval()
+        model.to(device)
+
+        criteria = get_losses(self.conf.client.task_str, self.conf.client.rotate_loss, self.conf.client.task_weights)
+
+        average_meters = defaultdict(AverageMeter)
+        epoch_start_time = time.time()
+        batch_num = 0
+        num_data_points = len(loader)
+
+        prefetcher = DataPrefetcher(loader, device)
+        # torch.cuda.empty_cache()
+
+        with torch.no_grad():
+            for i in range(len(loader)):
+                input, target = prefetcher.next()
+
+                if batch_num == 0:
+                    epoch_start_time2 = time.time()
+
+                output = model(input)
+
+                loss_dict = {}
+                for c_name, criterion_fn in criteria.items():
+                    loss_dict[c_name] = criterion_fn(output, target)
+
+                batch_num = i + 1
+
+                for name, value in loss_dict.items():
+                    try:
+                        average_meters[name].update(value.data)
+                    except:
+                        average_meters[name].update(value)
+                eta = ((time.time() - epoch_start_time2) / (batch_num + .2)) * (len(loader) - batch_num)
+                to_print = {
+                    f'#/{num_data_points}': '{0}'.format(batch_num),
+                    'eta': '{0}'.format(time.strftime("%H:%M:%S", time.gmtime(int(eta))))
+                }
+                for name in criteria.keys():
+                    meter = average_meters[name]
+                    to_print[name] = '{meter.avg:.4f}'.format(meter=meter)
+
+
+        epoch_time = time.time() - epoch_start_time
+
+        stats = {'batches': len(loader), 'epoch_time': epoch_time}
+
+        for name in criteria.keys():
+            meter = average_meters[name]
+            stats[name] = meter.avg
+
+        to_print['eta'] = '{0}'.format(time.strftime("%H:%M:%S", time.gmtime(int(epoch_time))))
+        torch.cuda.empty_cache()
+
+        test_results = {
+            metric.TEST_ACCURACY: 0,
+            metric.TEST_LOSS: float(stats['Loss']),
+        }
+
+        return test_results, stats, [to_print]
+
+    def save_model(self):
+        if self._do_every(self.conf.server.save_model_every, self._current_round, self.conf.server.rounds) and \
+                self.is_primary_server():
+            save_path = self.conf.server.save_model_path
+            if save_path == "":
+                save_path = os.path.join(os.getcwd(), "saved_models", "mas", self.conf.task_id)
+            os.makedirs(save_path, exist_ok=True)
+            if self.conf.server.save_model_every == 1:
+                save_filename = f"{self.conf.task_id}_checkpoint.pth.tar"
+            else:
+                save_filename = f"{self.conf.task_id}_r_{self._current_round}_checkpoint.pth.tar"
+            # save_path = os.path.join(save_path, f"{self.conf.task_id}_r_{self._current_round}_checkpoint.pth.tar")
+
+            is_best = self._current_loss < self._best_loss
+            self._best_loss = min(self._current_loss, self._best_loss)
+
+            try:
+                checkpoint = {
+                    'round': self._current_round,
+                    'info': {'machine': self.conf.distributed.init_method, 'GPUS': self.conf.gpu},
+                    'args': self.conf,
+                    'arch': self.conf.arch,
+                    'state_dict': self._model.cpu().state_dict(),
+                    'best_loss': self._best_loss,
+                    'progress_table': self._progress_table,
+                    'stats': self._stats,
+                    'loss_history': self._loss_history,
+                    'code_archive': self.get_code_archive(),
+                    'client_models': [m.cpu().state_dict() for m in self._client_models]
+                }
+                self.save_checkpoint(checkpoint, False, save_path, save_filename)
+
+                if is_best:
+                    logger.info(f"Best validation loss at round {self._current_round}: {self._best_loss}")
+                    self._best_model = copy.deepcopy(self._model)
+                    self.save_checkpoint(None, True, save_path, save_filename)
+                self.print_("Checkpoint saved at {}".format(save_path))
+            except:
+                self.print_('Save checkpoint failed...')
+
+
+    def save_checkpoint(self, state, is_best, directory='', filename='checkpoint.pth.tar'):
+        path = os.path.join(directory, filename)
+        if is_best:
+            best_path = os.path.join(directory, f"best_{self.conf.task_id}_checkpoint.pth.tar")
+            shutil.copyfile(path, best_path)
+        else:
+            torch.save(state, path)
+
+    def get_code_archive(self):
+        file_contents = {}
+        for i in os.listdir('.'):
+            if i[-3:] == '.py':
+                with open(i, 'r') as file:
+                    file_contents[i] = file.read()
+        return file_contents

+ 168 - 0
applications/mas/split.py

@@ -0,0 +1,168 @@
+import argparse
+import collections
+import copy
+import functools
+import json
+import operator
+import re
+from pprint import pprint
+
+import matplotlib
+import matplotlib.pyplot as plt
+matplotlib.rcParams['text.usetex'] = True
+plt.rcParams.update({'font.size': 30})
+import numpy as np
+
+import network_selection
+
+
+MAPPING = {
+    'ss_l': 's',
+    'depth_l': 'd',
+    'norm_l': 'n',
+    'key_l': 'k',
+    'edge2d_l': 't',
+    'edge_l': 'e',
+    'shade_l': 'r',
+    'rgb_l': 'a',
+    'pc_l': 'c',
+}
+
+COLOR_MAP = {
+    'ss_l': 'tab:blue',
+    'depth_l': 'tab:orange',
+    'norm_l': 'tab:green',
+    'key_l': 'tab:red',
+    'edge2d_l': 'tab:purple',
+    'edge_l': 'tab:brown',
+    'shade_l': 'tab:pink',
+    'rgb_l': 'tab:gray',
+    'pc_l': 'tab:olive',
+}
+
+
+class Affinity:
+    def __init__(self, args=None):
+        self.affinities = {}
+        self.args = args
+        self.task_overlap = args.task_overlap
+        self.split = args.split
+
+    def add(self, round_id, client_id, affinity):
+        if self.args.preprocess:
+            affinity = self.preprocess_affinity(affinity)
+
+        for scores in affinity.values():
+            if isinstance(scores, list) and scores[0]['ss_l'] == 0.0:
+                return
+            else:
+                break
+
+        if round_id not in self.affinities:
+            self.affinities[round_id] = {client_id: affinity}
+        else:
+            self.affinities[round_id][client_id] = affinity
+
+    def get_round_affinities(self, round_id):
+        return list(self.affinities[round_id].values())
+
+    def average_affinities(self, affinities):
+        result = copy.deepcopy(affinities[0])
+        for task, affinity in result.items():
+            for target_task, score in affinity.items():
+                total = score
+                for a in affinities[1:]:
+                    total += a[task][target_task]
+                result[task][target_task] = total / len(affinities)
+        return result
+
+    def average_affinity_of_clients(self, max_round=100):
+        affinities = {}
+        for round_id, affinity in self.affinities.items():
+            if round_id >= max_round:
+                continue
+            result = self.average_affinities(list(affinity.values()))
+            affinities[round_id] = result
+        return affinities
+
+    def average_affinity_of_rounds(self, max_round=100):
+        affinities = self.average_affinity_of_clients(max_round)
+        return self.average_affinities(list(affinities.values()))
+
+    def preprocess_affinity(self, affinity):
+        for task, scores in affinity.items():
+            result = dict(functools.reduce(operator.add, map(collections.Counter, scores)))
+            affinity[task] = result
+        return affinity
+
+    def network_selection(self, rounds, specific_round=False):
+        results = {}
+        # Network selection of specific round
+        if specific_round:
+            for round_id in rounds:
+                round_affinities = self.get_round_affinities(round_id)
+                # Network selection of average
+                averaged_affinity = self.average_affinities(round_affinities)
+                result = network_selection.task_grouping(averaged_affinity, task_overlap=self.task_overlap,
+                                                         split=self.split)
+                results[round_id] = {"average": result}
+                # pprint(averaged_affinity)
+                if not self.args.average_only:
+                    for client, a in self.affinities[round_id].items():
+                        result = network_selection.task_grouping(a, task_overlap=self.task_overlap, split=self.split)
+                        results[round_id][client] = result
+        # Average task affinity of all rounds
+        for round_id in rounds:
+            affinities = self.average_affinity_of_rounds(round_id)
+            results[f"average_{round_id}"] = network_selection.task_grouping(affinities, task_overlap=self.task_overlap,
+                                                                             split=self.split)
+        # Convert string formats from loss to single letter
+        return results
+
+
+def extract_task_affinity(line):
+    r = re.search(r'[\d\w\-\[\]\:\ ,]* Round (\d+) - Client (\w+) transference: (\{[\{\}\[\]\'\-\_\d\w\: .,]*\}\n)',
+                  line)
+
+    if not r:
+        return
+    return r.groups()
+
+
+def run(args):
+    A = Affinity(args)
+    with open(args.filename, 'r') as f:
+        for line in f:
+            data = extract_task_affinity(line)
+            if not data:
+                continue
+            round_id, client_id, affinity = data
+            round_id = int(round_id)
+            affinity = affinity.replace("'", "\"")
+            affinity = json.loads(affinity)
+            A.add(round_id, client_id, affinity)
+        else:
+            results = A.network_selection(args.rounds)
+            results_str = json.dumps(results)
+            for loss_name, char in MAPPING.items():
+                results_str = results_str.replace(loss_name, char)
+            results = json.loads(results_str)
+            pprint(results)
+
+
+def construct_analyze_parser(parser):
+    parser.add_argument('-f', '--filename', type=str, metavar='PATH', default="./train.log")
+    parser.add_argument('-s', '--split', type=int, default=3)
+    parser.add_argument('-o', '--task_overlap', action='store_true')
+    parser.add_argument('-p', '--preprocess', action='store_true')
+    parser.add_argument('-a', '--average_only', action='store_true')
+    parser.add_argument('-r', '--rounds', nargs="*", default=[10], type=int)
+    return parser
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Split')
+    parser = construct_analyze_parser(parser)
+    args = parser.parse_args()
+    print("args:", args)
+    run(args)

+ 288 - 0
applications/mas/taskonomy-tiny-data.txt

@@ -0,0 +1,288 @@
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/allensville_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/beechwood_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/benevolence_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/coffeen_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/collierville_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/corozal_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/cosmos_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/darden_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/forkland_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/hanson_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/hiteman_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/ihlen_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/klickitat_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/lakeville_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/leonardo_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/lindenwood_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/markleeville_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/marstons_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/mcdade_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/merom_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/mifflinburg_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/muleshoe_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/newfields_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/noxapater_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/onaga_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/pinesdale_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/pomaria_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/ranchester_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/shelbyville_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/stockman_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/tolstoy_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/depth_zbuffer/uvalda_depth_zbuffer.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/allensville_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/beechwood_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/benevolence_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/coffeen_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/collierville_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/corozal_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/cosmos_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/darden_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/forkland_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/hanson_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/hiteman_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/ihlen_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/klickitat_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/lakeville_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/leonardo_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/lindenwood_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/markleeville_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/marstons_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/mcdade_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/merom_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/mifflinburg_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/muleshoe_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/newfields_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/noxapater_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/onaga_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/pinesdale_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/pomaria_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/ranchester_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/shelbyville_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/stockman_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/tolstoy_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_occlusion/uvalda_edge_occlusion.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/allensville_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/beechwood_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/benevolence_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/coffeen_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/collierville_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/corozal_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/cosmos_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/darden_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/forkland_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/hanson_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/hiteman_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/ihlen_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/klickitat_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/lakeville_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/leonardo_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/lindenwood_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/markleeville_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/marstons_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/mcdade_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/merom_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/mifflinburg_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/muleshoe_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/newfields_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/noxapater_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/onaga_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/pinesdale_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/pomaria_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/ranchester_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/shelbyville_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/stockman_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/tolstoy_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/edge_texture/uvalda_edge_texture.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/allensville_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/beechwood_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/benevolence_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/coffeen_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/collierville_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/corozal_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/cosmos_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/darden_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/forkland_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/hanson_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/hiteman_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/ihlen_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/klickitat_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/lakeville_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/leonardo_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/lindenwood_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/markleeville_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/marstons_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/mcdade_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/merom_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/mifflinburg_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/muleshoe_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/newfields_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/noxapater_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/onaga_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/pinesdale_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/pomaria_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/ranchester_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/shelbyville_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/stockman_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/tolstoy_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/keypoints2d/uvalda_keypoints2d.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/allensville_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/beechwood_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/benevolence_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/coffeen_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/collierville_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/corozal_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/cosmos_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/darden_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/forkland_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/hanson_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/hiteman_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/ihlen_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/klickitat_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/lakeville_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/leonardo_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/lindenwood_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/markleeville_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/marstons_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/mcdade_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/merom_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/mifflinburg_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/muleshoe_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/newfields_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/noxapater_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/onaga_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/pinesdale_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/pomaria_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/ranchester_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/shelbyville_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/stockman_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/tolstoy_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/normal/uvalda_normal.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/allensville_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/beechwood_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/benevolence_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/coffeen_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/collierville_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/corozal_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/cosmos_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/darden_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/forkland_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/hanson_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/hiteman_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/ihlen_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/klickitat_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/lakeville_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/leonardo_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/lindenwood_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/markleeville_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/marstons_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/mcdade_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/merom_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/mifflinburg_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/muleshoe_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/newfields_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/noxapater_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/onaga_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/pinesdale_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/pomaria_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/ranchester_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/shelbyville_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/stockman_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/tolstoy_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/principal_curvature/uvalda_principal_curvature.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/allensville_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/beechwood_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/benevolence_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/coffeen_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/collierville_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/corozal_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/cosmos_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/darden_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/forkland_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/hanson_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/hiteman_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/ihlen_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/klickitat_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/lakeville_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/leonardo_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/lindenwood_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/markleeville_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/marstons_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/mcdade_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/merom_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/mifflinburg_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/muleshoe_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/newfields_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/noxapater_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/onaga_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/pinesdale_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/pomaria_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/ranchester_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/shelbyville_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/stockman_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/tolstoy_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/reshading/uvalda_reshading.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/allensville_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/beechwood_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/benevolence_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/coffeen_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/collierville_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/corozal_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/cosmos_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/darden_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/forkland_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/hanson_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/hiteman_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/ihlen_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/klickitat_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/lakeville_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/leonardo_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/lindenwood_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/markleeville_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/marstons_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/mcdade_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/merom_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/mifflinburg_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/muleshoe_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/newfields_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/noxapater_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/onaga_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/pinesdale_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/pomaria_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/ranchester_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/shelbyville_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/stockman_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/tolstoy_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/rgb/uvalda_rgb.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/allensville_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/beechwood_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/benevolence_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/coffeen_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/collierville_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/corozal_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/cosmos_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/darden_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/forkland_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/hanson_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/hiteman_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/ihlen_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/klickitat_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/lakeville_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/leonardo_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/lindenwood_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/markleeville_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/marstons_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/mcdade_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/merom_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/mifflinburg_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/muleshoe_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/newfields_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/noxapater_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/onaga_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/pinesdale_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/pomaria_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/ranchester_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/shelbyville_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/stockman_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/tolstoy_segment_semantic.tar
+http://downloads.cs.stanford.edu/downloads/taskonomy_data/segment_semantic/uvalda_segment_semantic.tar

+ 364 - 0
applications/mas/trainer.py

@@ -0,0 +1,364 @@
+import copy
+import logging
+import time
+from collections import defaultdict
+
+import scipy.stats
+import torch
+
+from utils import AverageMeter
+
+from easyfl.distributed.distributed import CPU
+
+logger = logging.getLogger(__name__)
+
+LR_POLY = "poly"
+LR_CUSTOM = "custom"
+
+
+class Trainer:
+    def __init__(self, cid, conf, train_loader, model, optimizer, criteria, device=CPU, checkpoint=None):
+        self.cid = cid
+        self.conf = conf
+        self.train_loader = train_loader
+        self.model = model
+        self.optimizer = optimizer
+        self.criteria = criteria
+        self.loss_keys = list(self.criteria.keys())[1:]
+        self.device = device
+        # self.args = args
+
+        self.progress_table = []
+        # self.best_loss = 9e9
+        self.stats = []
+        self.start_epoch = 0
+        self.loss_history = []
+        self.encoder_trainable = None
+        # self.code_archive = self.get_code_archive()
+        if checkpoint:
+            if 'progress_table' in checkpoint:
+                self.progress_table = checkpoint['progress_table']
+            if 'epoch' in checkpoint:
+                self.start_epoch = checkpoint['epoch'] + 1
+            # if 'best_loss' in checkpoint:
+            #     self.best_loss = checkpoint['best_loss']
+            if 'stats' in checkpoint:
+                self.stats = checkpoint['stats']
+            if 'loss_history' in checkpoint:
+                self.loss_history = checkpoint['loss_history']
+
+        self.lr0 = self.conf.optimizer.lr
+        self.lr = self.lr0
+
+        self.ticks = 0
+        self.last_tick = 0
+        # self.loss_tracking_window = self.conf.loss_tracking_window_initial
+
+        # estimated loss tracking window for each client, based on their dataset size, compared with original implementation.
+        if self.conf.optimizer.lr_type == LR_CUSTOM:
+            self.loss_tracking_window = len(train_loader) * self.conf.batch_size / 8
+            self.maximum_loss_tracking_window = len(train_loader) * self.conf.batch_size / 2
+            logger.info(
+                f"Client {self.cid}: loss_tracking_window: {self.loss_tracking_window}, maximum_loss_tracking_window: {self.maximum_loss_tracking_window}")
+
+    def train(self):
+        self.encoder_trainable = [
+            p for p in self.model.encoder.parameters() if p.requires_grad
+        ]
+
+        transference = {combined_task: [] for combined_task in self.loss_keys}
+        for self.epoch in range(self.start_epoch, self.conf.local_epoch):
+            current_learning_rate = get_average_learning_rate(self.optimizer)
+            # Stop training when learning rate is smaller than minimum learning rate
+            if current_learning_rate < self.conf.minimum_learning_rate:
+                logger.info(f"Client {self.cid} stop local training because lr too small, lr: {current_learning_rate}.")
+                break
+            # Train for one epoch
+            train_string, train_stats, epoch_transference = self.train_epoch()
+            self.progress_table.append(train_string)
+            self.stats.append(train_stats)
+
+            for combined_task in self.loss_keys:
+                transference[combined_task].append(epoch_transference[combined_task])
+
+            # # evaluate on validation set
+            # progress_string = train_string
+            # loss, progress_string, val_stats = self.validate(progress_string)
+            #
+            # self.progress_table.append(progress_string)
+            # self.stats.append((train_stats, val_stats))
+        # Clean up to save memory
+        del self.encoder_trainable
+        self.encoder_trainable = None
+        return transference
+
+    def train_epoch(self):
+        average_meters = defaultdict(AverageMeter)
+        display_values = []
+        for name, func in self.criteria.items():
+            display_values.append(name)
+
+        # Switch to train mode
+        self.model.train()
+
+        epoch_start_time = time.time()
+        epoch_start_time2 = time.time()
+
+        batch_num = 0
+        num_data_points = len(self.train_loader) // self.conf.virtual_batch_multiplier
+        if num_data_points > 10000:
+            num_data_points = num_data_points // 5
+
+        starting_learning_rate = get_average_learning_rate(self.optimizer)
+
+        # Initialize task affinity dictionary
+        epoch_transference = {}
+        for combined_task in self.loss_keys:
+            epoch_transference[combined_task] = {}
+            for recipient_task in self.loss_keys:
+                epoch_transference[combined_task][recipient_task] = 0.
+
+        for i, (input, target) in enumerate(self.train_loader):
+            input = input.to(self.device)
+            for n, t in target.items():
+                target[n] = t.to(self.device)
+
+            # self.percent = batch_num / num_data_points
+            if i == 0:
+                epoch_start_time2 = time.time()
+
+            loss_dict = None
+            loss = 0
+            
+            self.optimizer.zero_grad()
+
+            _train_batch_lookahead = self.conf.lookahead == 'y' and i % self.conf.lookahead_step == 0
+
+            # Accumulate gradients over multiple runs of input
+            for _ in range(self.conf.virtual_batch_multiplier):
+                data_start = time.time()
+                average_meters['data_time'].update(time.time() - data_start)
+                # lookahead step 10
+                if _train_batch_lookahead:
+                    loss_dict2, loss2, batch_transference = self.train_batch_lookahead(input, target)
+                else:
+                    loss_dict2, loss2, batch_transference = self.train_batch(input, target)
+                loss += loss2
+                if loss_dict is None:
+                    loss_dict = loss_dict2
+                else:
+                    for key, value in loss_dict2.items():
+                        loss_dict[key] += value
+
+            if _train_batch_lookahead:
+                for combined_task in self.loss_keys:
+                    for recipient_task in self.loss_keys:
+                        epoch_transference[combined_task][recipient_task] += (
+                                batch_transference[combined_task][recipient_task] / (len(self.train_loader) / self.conf.lookahead_step))
+
+            # divide by the number of accumulations
+            loss /= self.conf.virtual_batch_multiplier
+            for key, value in loss_dict.items():
+                loss_dict[key] = value / self.conf.virtual_batch_multiplier
+
+            # do the weight updates and set gradients back to zero
+            self.optimizer.step()
+
+            self.loss_history.append(float(loss))
+            ttest_p, z_diff = self.learning_rate_schedule()
+
+            for name, value in loss_dict.items():
+                try:
+                    average_meters[name].update(value.data)
+                except:
+                    average_meters[name].update(value)
+
+            elapsed_time_for_epoch = (time.time() - epoch_start_time2)
+            eta = (elapsed_time_for_epoch / (batch_num + .2)) * (num_data_points - batch_num)
+            if eta >= 24 * 3600:
+                eta = 24 * 3600 - 1
+
+            batch_num += 1
+
+            current_learning_rate = get_average_learning_rate(self.optimizer)
+            to_print = {
+                'ep': f'{self.epoch}:',
+                f'#/{num_data_points}': f'{batch_num}',
+                'lr': '{0:0.3g}-{1:0.3g}'.format(starting_learning_rate, current_learning_rate),
+                'eta': '{0}'.format(time.strftime("%H:%M:%S", time.gmtime(int(eta)))),
+                'd%': '{0:0.2g}'.format(100 * average_meters['data_time'].sum / elapsed_time_for_epoch)
+            }
+            for name in display_values:
+                meter = average_meters[name]
+                to_print[name] = '{meter.avg:.4f}'.format(meter=meter)
+            if batch_num < num_data_points - 1:
+                to_print['ETA'] = '{0}'.format(
+                    time.strftime("%H:%M:%S", time.gmtime(int(eta + elapsed_time_for_epoch))))
+                to_print['ttest'] = '{0:0.3g},{1:0.3g}'.format(z_diff, ttest_p)
+
+
+        epoch_time = time.time() - epoch_start_time
+        stats = {
+            'batches': num_data_points,
+            'learning_rate': current_learning_rate,
+            'epoch_time': epoch_time,
+        }
+        for name in display_values:
+            meter = average_meters[name]
+            stats[name] = meter.avg
+
+        to_print['eta'] = '{0}'.format(time.strftime("%H:%M:%S", time.gmtime(int(epoch_time))))
+
+        logger.info(f"Client {self.cid} training statistics: {stats}")
+        return [to_print], stats, epoch_transference
+
+    def train_batch(self, x, target):
+        loss_dict = {}
+        x = x.float()
+        output = self.model(x)
+        first_loss = None
+        for c_name, criterion_fn in self.criteria.items():
+            if first_loss is None:
+                first_loss = c_name
+            loss_dict[c_name] = criterion_fn(output, target)
+
+        loss = loss_dict[first_loss].clone()
+        loss = loss / self.conf.virtual_batch_multiplier
+
+        if self.conf.fp16:
+            from apex import amp
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss.backward()
+
+        return loss_dict, loss, {}
+
+    def train_batch_lookahead(self, x, target):
+        loss_dict = {}
+        x = x.float()
+        output = self.model(x)
+        first_loss = None
+        for c_name, criterion_fun in self.criteria.items():
+            if first_loss is None:
+                first_loss = c_name
+            loss_dict[c_name] = criterion_fun(output, target)
+
+        loss = loss_dict[first_loss].clone()
+
+        transference = {}
+        for combined_task in self.loss_keys:
+            transference[combined_task] = {}
+        if self.conf.fp16:
+            from apex import amp
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            for combined_task in self.loss_keys:
+                preds = self.lookahead(x, loss_dict[combined_task])
+                first_loss = None
+                for c_name, criterion_fun in self.criteria.items():
+                    if first_loss is None:
+                        first_loss = c_name
+                    transference[combined_task][c_name] = (
+                            (1.0 - (criterion_fun(preds, target) / loss_dict[c_name])) /
+                            self.optimizer.state_dict()['param_groups'][0]['lr']
+                    ).detach().cpu().numpy()
+            self.optimizer.zero_grad()
+            loss.backward()
+
+        # Want to invert the dictionary so it's source_task => gradients on source task.
+        rev_transference = {source: {} for source in transference}
+        for grad_task in transference:
+            for source in transference[grad_task]:
+                if 'Loss' in source:
+                    continue
+                rev_transference[source][grad_task] = transference[grad_task][
+                    source]
+        return loss_dict, loss, copy.deepcopy(rev_transference)
+
+    def lookahead(self, x, loss):
+        self.optimizer.zero_grad()
+        shared_params = self.encoder_trainable
+        init_weights = [param.data for param in shared_params]
+        grads = torch.autograd.grad(loss, shared_params, retain_graph=True)
+
+        # Compute updated params for the forward pass: SGD w/ 0.9 momentum + 1e-4 weight decay.
+        opt_state = self.optimizer.state_dict()['param_groups'][0]
+        weight_decay = opt_state['weight_decay']
+
+        for param, g, param_id in zip(shared_params, grads, opt_state['params']):
+            grad = g.clone()
+            grad += param * weight_decay
+            if 'momentum_buffer' not in opt_state:
+                mom_buf = grad
+            else:
+                mom_buf = opt_state['momentum_buffer']
+                mom_buf = mom_buf * opt_state['momentum'] + grad
+            param.data = param.data - opt_state['lr'] * mom_buf
+
+            grad = grad.cpu()
+            del grad
+
+        with torch.no_grad():
+            output = self.model(x)
+
+        for param, init_weight in zip(shared_params, init_weights):
+            param.data = init_weight
+        return output
+
+    def learning_rate_schedule(self):
+        # don't process learning rate if the schedule type is poly, which adjusted before training.
+        if self.conf.optimizer.lr_type == LR_POLY:
+            return 0, 0
+
+        # don't reduce learning rate until the second epoch has ended.
+        if self.epoch < 2:
+            return 0, 0
+
+        ttest_p = 0
+        z_diff = 0
+
+        wind = self.loss_tracking_window // (self.conf.batch_size * self.conf.virtual_batch_multiplier)
+        if len(self.loss_history) - self.last_tick > wind:
+            a = self.loss_history[-wind:-wind * 5 // 8]
+            b = self.loss_history[-wind * 3 // 8:]
+            # remove outliers
+            a = sorted(a)
+            b = sorted(b)
+            a = a[int(len(a) * .05):int(len(a) * .95)]
+            b = b[int(len(b) * .05):int(len(b) * .95)]
+            length_ = min(len(a), len(b))
+            a = a[:length_]
+            b = b[:length_]
+            z_diff, ttest_p = scipy.stats.ttest_rel(a, b, nan_policy='omit')
+
+            if z_diff < 0 or ttest_p > .99:
+                self.ticks += 1
+                self.last_tick = len(self.loss_history)
+                self.adjust_learning_rate()
+                self.loss_tracking_window = min(self.maximum_loss_tracking_window, self.loss_tracking_window * 2)
+        return ttest_p, z_diff
+
+    def adjust_learning_rate(self):
+        self.lr = self.lr0 * (0.50 ** self.ticks)
+        self.set_learning_rate(self.lr)
+
+    def set_learning_rate(self, lr):
+        for param_group in self.optimizer.param_groups:
+            param_group['lr'] = lr
+
+    def update(self, model, optimizer, device):
+        self.model = model
+        self.optimizer = optimizer
+        self.device = device
+
+
+def get_average_learning_rate(optimizer):
+    try:
+        return optimizer.learning_rate
+    except:
+        s = 0
+        for param_group in optimizer.param_groups:
+            s += param_group['lr']
+        return s / len(optimizer.param_groups)

+ 97 - 0
applications/mas/utils.py

@@ -0,0 +1,97 @@
+import logging
+from collections import defaultdict
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.std = 0
+        self.sum = 0
+        self.sumsq = 0
+        self.count = 0
+        self.lst = []
+
+    def update(self, val, n=1):
+        self.val = float(val)
+        self.sum += float(val) * n
+        # self.sumsq += float(val)**2
+        self.count += n
+        self.avg = self.sum / self.count
+        self.lst.append(self.val)
+        self.std = np.std(self.lst)
+
+
+class ProgressTable:
+    def __init__(self, table_list):
+        if len(table_list) == 0:
+            print()
+            return
+        self.lens = defaultdict(int)
+        self.table_list = table_list
+        self.construct(table_list)
+
+    def construct(self, table_list):
+        self.lens = defaultdict(int)
+        self.table_list = table_list
+        for i in table_list:
+            for ii, to_print in enumerate(i):
+                for title, val in to_print.items():
+                    self.lens[(title, ii)] = max(self.lens[(title, ii)], max(len(title), len(val)))
+
+    def print_table_header(self):
+        for ii, to_print in enumerate(self.table_list[0]):
+            for title, val in to_print.items():
+                print('{0:^{1}}'.format(title, self.lens[(title, ii)]), end=" ")
+
+    def print_table_content(self):
+        for i in self.table_list:
+            print()
+            for ii, to_print in enumerate(i):
+                for title, val in to_print.items():
+                    print('{0:^{1}}'.format(val, self.lens[(title, ii)]), end=" ", flush=True)
+
+    def print_all_table(self):
+        self.print_table_header()
+        self.print_table_content()
+
+    def print_table(self, header_condition, content_condition):
+        if header_condition:
+            self.print_table_header()
+        if content_condition:
+            self.print_table_content()
+
+    def update_table_list(self, table_list):
+        self.construct(table_list)
+
+
+def print_table(table_list):
+    if len(table_list) == 0:
+        print()
+        return
+
+    lens = defaultdict(int)
+    for i in table_list:
+        for ii, to_print in enumerate(i):
+            for title, val in to_print.items():
+                lens[(title, ii)] = max(lens[(title, ii)], max(len(title), len(val)))
+
+    # printed_table_list_header = []
+    for ii, to_print in enumerate(table_list[0]):
+        for title, val in to_print.items():
+            print('{0:^{1}}'.format(title, lens[(title, ii)]), end=" ")
+    for i in table_list:
+        print()
+        for ii, to_print in enumerate(i):
+            for title, val in to_print.items():
+                print('{0:^{1}}'.format(val, lens[(title, ii)]), end=" ", flush=True)
+    print()

+ 1 - 0
docs/en/projects.md

@@ -6,6 +6,7 @@ We have been doing research on federated learning for several years and publishe
 
 We have released the following implementations of federated learning applications:
 
+- Federated Multiple Task Learning: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/mas) for [MAS: Towards Resource-Efficient Federated Multiple-Task Learning](https://arxiv.org/abs/2307.11285) (_ICCV'2023_)
 - FedReID: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/fedreid) for [Performance Optimization for Federated Person Re-identification via Benchmark Analysis](https://dl.acm.org/doi/10.1145/3394171.3413814) (_ACMMM'2020_).
 - FedSSL: [[code]](https://github.com/EasyFL-AI/EasyFL/tree/master/applications/fedssl) for two papers: [Divergence-aware Federated Self-Supervised Learning](https://openreview.net/forum?id=oVE1z8NlNe) (_ICLR'2022_)  and [Collaborative Unsupervised Visual Representation Learning From Decentralized Data](https://openaccess.thecvf.com/content/ICCV2021/html/Zhuang_Collaborative_Unsupervised_Visual_Representation_Learning_From_Decentralized_Data_ICCV_2021_paper.html) (_ICCV'2021_)