# Using FATE-interfaces

In this tutorial, we will demonstrate how to use the trainer user interfaces to return formatted prediction results, evaluate the performance of your model, save your model, and display loss curves and performance scores on the fateboard. These interfaces integrate your trainer with the FATE framework and make it easier to work with.

We will continue developing the toy FedProx trainer from the previous tutorial.

## The Toy Implementation of FedProx

[In the last tutorial](./Homo-NN-Customize-Trainer.ipynb), we gave a concrete example: a toy implementation of the FedProx algorithm from https://arxiv.org/abs/1812.06127. FedProx training differs slightly from the standard FedAvg algorithm. When computing the loss, it adds a proximal term that penalizes the distance between the current local model and the global model. The code is here:

In [2]:
from pipeline.component.nn import save_to_fate

In [4]:
%%save_to_fate trainer fedprox.py
import copy
import torch as t
from federatedml.nn.homo.trainer.trainer_base import TrainerBase
from torch.utils.data import DataLoader
# We need to use aggregator client&server class for federation
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
# We use LOGGER to output logs
from federatedml.util import LOGGER


class ToyFedProxTrainer(TrainerBase):

    def __init__(self, epochs, batch_size, u):
        super(ToyFedProxTrainer, self).__init__()
        # trainer parameters
        self.epochs = epochs
        self.batch_size = batch_size
        self.u = u

    # Given two models, compute the proximal term: the squared L2 distance between their parameters
    def _proximal_term(self, model_a, model_b):
        diff_ = 0
        for p1, p2 in zip(model_a.parameters(), model_b.parameters()):
            diff_ += ((p1 - p2.detach()) ** 2).sum()
        return diff_

    # implement the train function; this function is called on the client side
    # and contains the local training process and the federation part
    def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):

        sample_num = len(train_set)
        aggregator = None
        if self.fed_mode:
            aggregator = SecureAggregatorClient(True, aggregate_weight=sample_num,
                                                communicate_match_suffix='fedprox')  # initialize aggregator

        # set dataloader
        dl = DataLoader(train_set, batch_size=self.batch_size, num_workers=4)

        for epoch in range(self.epochs):

            # the local training process
            LOGGER.debug('running epoch {}'.format(epoch))
            global_model = copy.deepcopy(self.model)
            loss_sum = 0

            # batch training process
            for batch_data, label in dl:
                optimizer.zero_grad()
                pred = self.model(batch_data)
                loss_term_a = loss(pred, label)
                loss_term_b = self._proximal_term(self.model, global_model)
                loss_ = loss_term_a + (self.u / 2) * loss_term_b
                loss_.backward()
                loss_sum += float(loss_.detach().numpy())
                optimizer.step()

            # print loss
            LOGGER.debug('epoch loss is {}'.format(loss_sum))

            # the aggregation process
            if aggregator is not None:
                self.model = aggregator.model_aggregation(self.model)
                converge_status = aggregator.loss_aggregation(loss_sum)

    # implement the aggregation function; this function is called on the server side
    def server_aggregate_procedure(self, extra_data={}):

        # initialize aggregator
        if self.fed_mode:
            aggregator = SecureAggregatorServer(communicate_match_suffix='fedprox')

        # the aggregation process is simple: every epoch the server aggregates the model and loss once
        for i in range(self.epochs):
            aggregator.model_aggregation()
            merge_loss, _ = aggregator.loss_aggregation()
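For reference, the FedProx paper adds the proximal penalty $\frac{\mu}{2}\lVert w - w^t\rVert^2$ to the local loss. Here $w$ is the local model, $w^t$ is the global model held at the start of the round, and the trainer parameter `u` plays the role of $\mu$. On the server side, each client registers `aggregate_weight=sample_num`, in the spirit of FedAvg's sample-weighted averaging. The plain-Python sketch below shows both pieces of arithmetic on flat lists of floats. The function names are hypothetical. The sketch leaves out PyTorch parameters, and it also leaves out the secure-aggregation protocol and communication that FATE's aggregators handle.

```python
from typing import List, Sequence


def proximal_term(local: Sequence[float], global_: Sequence[float]) -> float:
    """Squared L2 distance ||w - w_global||^2 between two flat parameter vectors."""
    if len(local) != len(global_):
        raise ValueError("parameter vectors must have the same length")
    return sum((w - g) ** 2 for w, g in zip(local, global_))


def fedprox_loss(task_loss: float, local: Sequence[float],
                 global_: Sequence[float], mu: float) -> float:
    """Local FedProx objective: task loss + (mu / 2) * ||w - w_global||^2."""
    return task_loss + (mu / 2) * proximal_term(local, global_)


def weighted_average(client_params: Sequence[Sequence[float]],
                     weights: Sequence[float]) -> List[float]:
    """Average client parameter vectors, weighting each client (e.g. by its sample count)."""
    if not client_params or len(client_params) != len(weights):
        raise ValueError("need one weight per client")
    total = float(sum(weights))
    dim = len(client_params[0])
    return [
        sum(w * params[i] for params, w in zip(client_params, weights)) / total
        for i in range(dim)
    ]


print(fedprox_loss(1.0, [1.0, 2.0], [0.0, 0.0], mu=0.5))       # 2.25
print(weighted_average([[1.0, 2.0], [3.0, 4.0]], [100, 300]))   # [2.5, 3.5]
```

The gradient of the penalty with respect to each weight is $\mu(w - w^t)$. It therefore pulls local weights back toward the global model in proportion to how far they have drifted.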


## User interfaces

Now we introduce the user interfaces offered by the TrainerBase class. We will use these functions to improve our trainer.

### format_predict_result

This function organizes your prediction results and returns a StdReturnFormat object that wraps them. Call it at the end of your predict function to return the results in a standardized format. The FATE framework can parse this format and display it on the fateboard, and downstream components, such as the evaluation component, can also consume the prediction results.

This function takes four arguments:
- sample_ids: a list of sample IDs
- predict_result: a tensor of prediction scores
- true_label: a tensor of true labels
- task_type: the type of task being performed. The default is 'auto', which infers the task type automatically. The other options are 'binary', 'multi', and 'regression'. Currently, the FATE dashboard only supports displaying binary/multi-class classification and regression tasks.
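FATE's own inference logic for 'auto' is not shown in this tutorial. The heuristic below is a hypothetical illustration of how a task type could be inferred from scores and labels, using plain Python lists. Several score columns suggest multi-class classification. A single score column with 0/1 labels suggests binary classification. Anything else is treated as regression. Do not read it as FATE's actual implementation.

```python
from typing import Sequence


def infer_task_type(scores: Sequence[Sequence[float]], labels: Sequence[float]) -> str:
    """Hypothetical task-type heuristic; NOT FATE's actual inference code."""
    n_cols = len(scores[0])
    if n_cols > 1:
        # one score column per class, e.g. softmax outputs
        return 'multi'
    if set(labels) <= {0, 1}:
        # a single probability column with 0/1 labels
        return 'binary'
    return 'regression'


print(infer_task_type([[0.2], [0.9]], [0, 1]))  # binary
```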

**We will implement a predict function in the FedProx trainer later.**

In [5]:
import torch as t 
from typing import List

def format_predict_result(self, sample_ids: List, predict_result: t.Tensor,
 true_label: t.Tensor, task_type: str = None):
 ...

### callback_metric & callback_loss

As the names suggest, these two functions record data points so that custom evaluation metrics and loss curves can be displayed on the fateboard.

When calling callback_metric, provide the metric name, a float value, the metric type ('train' or 'validate'), and the epoch index. When calling callback_loss, provide a float loss value and the epoch index. The recorded values are displayed on the fateboard.
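The real calls need a running FATE job. The sketch below therefore uses a hypothetical stand-in class, `RecordingTrainer`, whose `callback_loss` and `callback_metric` methods mirror the TrainerBase signatures in this section and just store the points they receive. It shows the typical call pattern of one loss point and one metric point per epoch. The loss and accuracy values are made up for illustration.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


class RecordingTrainer:
    """Hypothetical stand-in that stores callback data instead of sending it to the fateboard."""

    def __init__(self) -> None:
        self.losses: List[Tuple[int, float]] = []
        self.metrics: Dict[Tuple[str, str], List[Tuple[int, float]]] = defaultdict(list)

    def callback_loss(self, loss: float, epoch_idx: int) -> None:
        self.losses.append((epoch_idx, float(loss)))

    def callback_metric(self, metric_name: str, value: float,
                        metric_type: str = 'train', epoch_idx: int = 0) -> None:
        if metric_type not in ('train', 'validate'):
            raise ValueError("metric_type must be 'train' or 'validate'")
        self.metrics[(metric_name, metric_type)].append((epoch_idx, float(value)))


trainer = RecordingTrainer()
# made-up (loss, accuracy) values for three epochs
for epoch, (epoch_loss, acc) in enumerate([(0.69, 0.61), (0.52, 0.78), (0.41, 0.85)]):
    trainer.callback_loss(epoch_loss, epoch)                      # one point on the loss curve
    trainer.callback_metric('my_accuracy', acc, epoch_idx=epoch)  # one point on the metric curve

print(trainer.losses)
```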

In [6]:
def callback_metric(self, metric_name: str, value: float, metric_type='train', epoch_idx=0):
 ...

def callback_loss(self, loss: float, epoch_idx: int):
 ...

### summary

This function allows you to save a summary of the training process, such as the loss history and the best epoch, in a dictionary. You can retrieve this summary from the pipeline once the task is completed.
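A summary is just a dictionary. The helper below is a hypothetical sketch, not part of FATE. It builds a summary from a per-epoch loss history and picks the epoch with the lowest loss as the best epoch. The result is the kind of dict you might pass to `summary`.

```python
from typing import Dict, List


def build_summary(loss_history: List[float]) -> Dict[str, object]:
    """Collect the loss history and the index of the lowest-loss epoch (-1 if empty)."""
    if not loss_history:
        return {'loss_history': [], 'best_epoch': -1}
    best_epoch = min(range(len(loss_history)), key=loss_history.__getitem__)
    return {'loss_history': list(loss_history), 'best_epoch': best_epoch}


print(build_summary([0.69, 0.52, 0.55]))
# {'loss_history': [0.69, 0.52, 0.55], 'best_epoch': 1}
```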

In [None]:
def summary(self, summary_dict: dict):
 ...

### save & checkpoint
 
You can save your model with the 'save' function and set model checkpoints with the 'checkpoint' function. Note that:

- 'save' only stores the model in memory, so the model that is finally kept is the one passed to the most recent 'save' call.
- 'checkpoint' directly saves the model to disk.
- 'save' should only be called on the client side (in the 'train' function), while 'checkpoint' should be called on both the client and server side (in the 'train' and 'server_aggregate_procedure' functions) to ensure that the checkpoint mechanism works correctly.

The 'extra_data' parameter of both functions lets you save additional data as a dictionary. This is useful for warm-starting a model, because you can retrieve the saved data through the 'extra_data' parameter of the 'train' and 'server_aggregate_procedure' functions.
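The following is a hypothetical sketch of the warm-start pattern. Suppose your trainer passed `extra_data={'finished_epoch': epoch}` when saving. The key name is our own assumption, not a FATE convention. A resumed `train` call could then read it back from its `extra_data` argument and skip the epochs that were already completed.

```python
from typing import Dict


def remaining_epochs(extra_data: Dict, total_epochs: int) -> range:
    """Epochs still to run, given data saved by a previous save/checkpoint call.

    'finished_epoch' is a hypothetical key written by our own trainer, not by FATE.
    """
    start = extra_data.get('finished_epoch', -1) + 1
    return range(start, total_epochs)


print(list(remaining_epochs({}, 3)))                     # [0, 1, 2]  fresh start
print(list(remaining_epochs({'finished_epoch': 1}, 3)))  # [2]        resume after epoch 1
```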

In [None]:
def save(
 self,
 model=None,
 epoch_idx=-1,
 optimizer=None,
 converge_status=False,
 loss_history=None,
 best_epoch=-1,
 extra_data={}): ...

def checkpoint(
 self,
 epoch_idx,
 model=None,
 optimizer=None,
 converge_status=False,
 loss_history=None,
 best_epoch=-1,
 extra_data={}): ...

### evaluation

This interface evaluates your model by automatically computing performance metrics.
The metrics computed depend on the type of your dataset and task:

- Binary classification: 'AUC' and 'ks'
- Multi-class classification: 'accuracy', 'precision', and 'recall'
- Regression: 'rmse' and 'mae'
 
You can specify the type of your dataset ('train' or 'validate') and the task type ('binary', 'multi', or 'regression') in the parameters. If no task type is specified, it will be automatically inferred from your scores and labels.
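The plain-Python sketch below makes the listed metrics concrete using their standard definitions. Binary AUC is computed by pairwise ranking. KS is the maximum gap between the true-positive and false-positive rates over all thresholds. RMSE and MAE cover regression. This is an illustration, not FATE's evaluation code, and the quadratic-time AUC is only suitable for small inputs.

```python
import math
from typing import Sequence


def auc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Probability that a random positive scores above a random negative (ties count 0.5)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs both positive and negative samples")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def ks(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Maximum of TPR - FPR when predicting positive for score >= threshold."""
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("KS needs both positive and negative samples")
    best = 0.0
    for threshold in sorted(set(scores), reverse=True):
        tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 0)
        best = max(best, tp / n_pos - fp / n_neg)
    return best


def rmse(preds: Sequence[float], targets: Sequence[float]) -> float:
    return math.sqrt(sum((p - y) ** 2 for p, y in zip(preds, targets)) / len(preds))


def mae(preds: Sequence[float], targets: Sequence[float]) -> float:
    return sum(abs(p - y) for p, y in zip(preds, targets)) / len(preds)


scores, labels = [0.9, 0.4, 0.6, 0.1], [1, 1, 0, 0]
print(auc(scores, labels), ks(scores, labels))  # 0.75 0.5
print(rmse([1.0, 2.0], [1.0, 4.0]), mae([1.0, 2.0], [1.0, 4.0]))
```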

In [None]:
def evaluation(self, sample_ids: list, pred_scores: t.Tensor, label: t.Tensor, dataset_type='train',
 epoch_idx=0, task_type=None):
 ...

## Improved FedProx Trainer

In this section, we use the interfaces introduced above to improve our FedProx trainer and make it a more complete training tool. We:

- implement the predict function so that it returns a formatted result
- evaluate the model with the evaluation function
- save the model at the end of training
- call callback_loss to record loss curves
- compute accuracy scores manually and display them with callback_metric
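The manual accuracy in the last item thresholds each predicted probability at 0.5 and compares the result with the true label. Without tensors, the computation reduces to the plain-Python sketch below. The helper name is hypothetical.

```python
from typing import Sequence


def binary_accuracy(probs: Sequence[float], labels: Sequence[int], threshold: float = 0.5) -> float:
    """Fraction of samples whose thresholded prediction (prob > threshold) equals the label."""
    if not labels or len(probs) != len(labels):
        raise ValueError("probs and labels must be non-empty and the same length")
    correct = sum(int(p > threshold) == y for p, y in zip(probs, labels))
    return correct / len(labels)


print(binary_accuracy([0.8, 0.3, 0.6, 0.2], [1, 0, 0, 0]))  # 0.75
```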

In [8]:
from pipeline.component.nn import save_to_fate

In [13]:
%%save_to_fate trainer fedprox_v2.py
import copy
import torch as t
from federatedml.nn.homo.trainer.trainer_base import TrainerBase
from federatedml.nn.dataset.base import Dataset
from torch.utils.data import DataLoader
# We need to use aggregator client&server class for federation
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
# We use LOGGER to output logs
from federatedml.util import LOGGER


class ToyFedProxTrainer(TrainerBase):

    def __init__(self, epochs, batch_size, u):
        super(ToyFedProxTrainer, self).__init__()
        # trainer parameters
        self.epochs = epochs
        self.batch_size = batch_size
        self.u = u

    # Given two models, compute the proximal term: the squared L2 distance between their parameters
    def _proximal_term(self, model_a, model_b):
        diff_ = 0
        for p1, p2 in zip(model_a.parameters(), model_b.parameters()):
            diff_ += ((p1 - p2.detach()) ** 2).sum()
        return diff_

    # implement the train function; this function is called on the client side
    # and contains the local training process and the federation part
    def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):

        sample_num = len(train_set)
        aggregator = None
        if self.fed_mode:
            aggregator = SecureAggregatorClient(True, aggregate_weight=sample_num,
                                                communicate_match_suffix='fedprox')  # initialize aggregator

        # set dataloader
        dl = DataLoader(train_set, batch_size=self.batch_size, num_workers=4)

        loss_history = []
        for epoch in range(self.epochs):

            # the local training process
            LOGGER.debug('running epoch {}'.format(epoch))
            global_model = copy.deepcopy(self.model)
            loss_sum = 0

            # batch training process
            for batch_data, label in dl:
                optimizer.zero_grad()
                pred = self.model(batch_data)
                loss_term_a = loss(pred, label)
                loss_term_b = self._proximal_term(self.model, global_model)
                loss_ = loss_term_a + (self.u / 2) * loss_term_b
                LOGGER.debug('loss is {} loss a is {} loss b is {}'.format(loss_, loss_term_a, loss_term_b))
                loss_.backward()
                loss_sum += float(loss_.detach().numpy())
                optimizer.step()

            # print loss
            LOGGER.debug('epoch loss is {}'.format(loss_sum))
            loss_history.append(loss_sum)

            # we callback loss here
            self.callback_loss(loss_sum, epoch)

            # we evaluate our model here
            sample_ids, preds, labels = self._predict(train_set)
            self.evaluation(sample_ids, preds, labels, 'train', task_type='binary', epoch_idx=epoch)

            # we manually compute accuracy: threshold probabilities at 0.5, then compare with the labels
            acc = ((preds > 0.5).float() == labels).sum() / len(labels)
            acc = float(acc.detach().numpy())
            self.callback_metric('my_accuracy', acc, epoch_idx=epoch)

            # the aggregation process
            if aggregator is not None:
                self.model = aggregator.model_aggregation(self.model)
                converge_status = aggregator.loss_aggregation(loss_sum)

        # We save the model at the end of training
        self.save(self.model, epoch, optimizer)
        # We save the training summary
        self.summary({'loss_history': loss_history})

    # implement the aggregation function; this function is called on the server side
    def server_aggregate_procedure(self, extra_data={}):

        # initialize aggregator
        if self.fed_mode:
            aggregator = SecureAggregatorServer(communicate_match_suffix='fedprox')

        # the aggregation process is simple: every epoch the server aggregates the model and loss once
        for i in range(self.epochs):
            aggregator.model_aggregation()
            merge_loss, _ = aggregator.loss_aggregation()

    # run the model over the whole dataset as a single batch; return sample ids, scores and labels
    def _predict(self, dataset: Dataset):
        len_ = len(dataset)
        dl = DataLoader(dataset, batch_size=len_)
        preds, labels = None, None
        with t.no_grad():  # inference only, no gradient tracking needed
            for data, l in dl:
                preds = self.model(data)
                labels = l
        sample_ids = dataset.get_sample_ids()
        return sample_ids, preds, labels

    # We implement the predict function here
    def predict(self, dataset):

        sample_ids, preds, labels = self._predict(dataset)
        return self.format_predict_result(sample_ids, preds, labels, 'binary')

## Submit a Pipeline

Here we submit a new pipeline to test our new trainer.

In [14]:
# torch
import torch as t
from torch import nn
from pipeline import fate_torch_hook
fate_torch_hook(t)
# pipeline
from pipeline.component.homo_nn import HomoNN, TrainerParam # HomoNN Component, TrainerParam for setting trainer parameter
from pipeline.backend.pipeline import PipeLine # pipeline class
from pipeline.component import Reader, DataTransform, Evaluation # Data I/O and Evaluation
from pipeline.interface import Data # Data interface for defining data flow


# create a pipeline to submit the job
guest = 9999
host = 10000
arbiter = 10000
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)

# read uploaded dataset
train_data_0 = {"name": "breast_homo_guest", "namespace": "experiment"}
train_data_1 = {"name": "breast_homo_host", "namespace": "experiment"}
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=train_data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=train_data_1)

# The transform component converts the uploaded data to the FATE standard format
data_transform_0 = DataTransform(name='data_transform_0')
data_transform_0.get_party_instance(
 role='guest', party_id=guest).component_param(
 with_label=True, output_format="dense")
data_transform_0.get_party_instance(
 role='host', party_id=host).component_param(
 with_label=True, output_format="dense")

"""
Define Pytorch model/ optimizer and loss
"""
model = nn.Sequential(
 nn.Linear(30, 1),
 nn.Sigmoid()
)
loss = nn.BCELoss()
optimizer = t.optim.Adam(model.parameters(), lr=0.01)


"""
Create Homo-NN Component
"""
nn_component = HomoNN(name='nn_0',
 model=model, # set model
 loss=loss, # set loss
 optimizer=optimizer, # set optimizer
 # Here we use our fedprox_v2 trainer
 # TrainerParam passes parameters (epochs, batch_size, u) to the trainer
 trainer=TrainerParam(trainer_name='fedprox_v2', epochs=3, batch_size=128, u=0.5),
 torch_seed=100 # random seed
 )

# define work flow
pipeline.add_component(reader_0)
pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))
pipeline.add_component(nn_component, data=Data(train_data=data_transform_0.output.data))
pipeline.compile()
pipeline.fit()

2022-12-26 16:13:37.650 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:83 - Job id is 202212261613368298170
2022-12-26 16:13:37.665 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:98 - Job is still waiting, time elapse: 0:00:00
2022-12-26 16:13:38.698 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:125 -
2022-12-26 16:13:38.700 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:01
2022-12-26 16:13:39.722 | INFO | pipeline.utils.invoker.job_submitter:monitor_job_status:127 - Running component reader_0, time elapse: 0:00:02
2022-12-26 16:13:40.754 | INFO |