# Customize trainer to control the training process

In this tutorial, you will learn how to create and customize your own trainer to control the training process, make predictions, and aggregate results to meet your specific needs. We will first introduce you to the interfaces of the TrainerBase class that you need to implement. Then, we will provide a toy example of the FedProx algorithm (please note that this is just a toy example and should not be used in production) to help you better understand the concept of trainer customization.

## TrainerBase Class

### Basics

The TrainerBase Class is the base for all Homo-NN trainer in FATE. To create a custom trainer, you need to subclass the TrainerBase class located in [federatedml.homo.trainer_base](../../../../python/federatedml/nn/homo/trainer/trainer_base.py). There are two required functions that you must implement:

- The 'train()' function: This function takes five parameters: a training dataset instance (must be a subclass of Dataset), a validation dataset instance (also a subclass of Dataset), an optimizer instance with initialized training parameters, a loss function, and an extra data dictionary that may contain additional data for a warmstart task. In this function, you can define the process of client-side training and federation for a Homo-NN task.

- The 'server_aggregate_procedure()' function: This function takes one parameter, an extra data dictionary that may contain additional data for a warmstart task. It is called by the server and is where you can define the aggregation process.

There is also an optional 'predict()' function that takes one parameter, a dataset, and allows you to define how your trainer makes predictions. If you want to use the FATE framework, you need to ensure that your return data is formatted correctly so that FATE can display it correctly (we will cover this in a later tutorial)."

In the Homo-NN client component, the 'set_model()' function is used to set the initialized model to the trainer. When developing your trainer, you can use 'set_model()' to set your model, and then access it using 'self.model' within the trainer.

Here display the source code of these interfaces:

In [None]:
class TrainerBase(object):

 def __init__(self, **kwargs):
 ...
 self._model = None
 ...
 
 ...
 
 @property
 def model(self):
 if not hasattr(self, '_model'):
 raise AttributeError(
 'model variable is not initialized, remember to call'
 ' super(your_class, self).__init__()')
 if self._model is None:
 raise AttributeError(
 'model is not set, use set_model() function to set training model')

 return self._model

 @model.setter
 def model(self, val):
 self._model = val

 @abc.abstractmethod
 def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):
 """
 train_set : A Dataset Instance, must be a instance of subclass of Dataset (federatedml.nn.dataset.base),
 for example, TableDataset() (from federatedml.nn.dataset.table)

 validate_set : A Dataset Instance, but optional must be a instance of subclass of Dataset
 (federatedml.nn.dataset.base), for example, TableDataset() (from federatedml.nn.dataset.table)

 optimizer : A pytorch optimizer class instance, for example, t.optim.Adam(), t.optim.SGD()

 loss : A pytorch Loss class, for example, nn.BECLoss(), nn.CrossEntropyLoss()
 """
 pass

 @abc.abstractmethod
 def predict(self, dataset):
 pass

 @abc.abstractmethod
 def server_aggregate_procedure(self, extra_data={}):
 pass

### Fed mode/ local mode

The Trainer has an attribute 'self.fed_mode' which is set to True when running a federated task. You can use this variable to determine whether your trainer is running in federated mode or in local debug mode. If you want to test the trainer locally, you can use the 'local_mode()' function to set 'self.fed_mode' to False.

## Example: Develop A Toy FedProx

To help you understand how to implement these functions, we will provide a concrete example by demonstrating a toy implement of the FedProx algorithm from https://arxiv.org/abs/1812.06127. In FedProx, the training process is slightly different from the standard FedAVG algorithm as it requires the computation of a proximal term from the current model and the global model when calculating the loss. We will walk you through the code with comments step by step.

### Toy FedProx

Here is the code for the trainer, which is saved in the federatedml.nn.homo.trainer module. This trainer implements two functions: train and server_aggregate_procedure. These functions enable the completion of a simple training task. The code includes comments to provide further details.

In [32]:
from pipeline.component.nn import save_to_fate

In [37]:
%%save_to_fate trainer fedprox.py
import copy
from federatedml.nn.homo.trainer.trainer_base import TrainerBase
from torch.utils.data import DataLoader
# We need to use aggregator client&server class for federation
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer
# We use LOGGER to output logs
from federatedml.util import LOGGER


class ToyFedProxTrainer(TrainerBase):

 def __init__(self, epochs, batch_size, u):
 super(ToyFedProxTrainer, self).__init__()
 # trainer parameters
 self.epochs = epochs
 self.batch_size = batch_size
 self.u = u

 # Given two model, we compute the proximal term
 def _proximal_term(self, model_a, model_b):
 diff_ = 0
 for p1, p2 in zip(model_a.parameters(), model_b.parameters()):
 diff_ += t.sqrt((p1-p2.detach())**2).sum()
 return diff_

 # implement the train function, this function will be called by client side
 # contains the local training process and the federation part
 def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):
 
 sample_num = len(train_set)
 aggregator = None
 if self.fed_mode:
 aggregator = SecureAggregatorClient(True, aggregate_weight=sample_num, 
 communicate_match_suffix='fedprox') # initialize aggregator

 # set dataloader
 dl = DataLoader(train_set, batch_size=self.batch_size, num_workers=4)

 for epoch in range(self.epochs):
 
 # the local training process
 LOGGER.debug('running epoch {}'.format(epoch))
 global_model = copy.deepcopy(self.model)
 loss_sum = 0

 # batch training process
 for batch_data, label in dl:
 optimizer.zero_grad()
 pred = self.model(batch_data)
 loss_term_a = loss(pred, label)
 loss_term_b = self._proximal_term(self.model, global_model)
 loss_ = loss_term_a + (self.u/2) * loss_term_b
 loss_.backward()
 loss_sum += float(loss_.detach().numpy())
 optimizer.step()

 # print loss
 LOGGER.debug('epoch loss is {}'.format(loss_sum))

 # the aggregation process
 if aggregator is not None:
 self.model = aggregator.model_aggregation(self.model)
 converge_status = aggregator.loss_aggregation(loss_sum)

 # implement the aggregation function, this function will be called by the sever side
 def server_aggregate_procedure(self, extra_data={}):
 
 # initialize aggregator
 if self.fed_mode:
 aggregator = SecureAggregatorServer(communicate_match_suffix='fedprox')

 # the aggregation process is simple: every epoch the server aggregate model and loss once
 for i in range(self.epochs):
 aggregator.model_aggregation()
 merge_loss, _ = aggregator.loss_aggregation()


# Local Test

We can use local_mode() to locally test our new FedProx trainer.

In [31]:
import torch as t
from federatedml.nn.dataset.table import TableDataset

model = t.nn.Sequential(
 t.nn.Linear(30, 1),
 t.nn.Sigmoid()
)

ds = TableDataset()
ds.load('../../../../examples/data/breast_homo_guest.csv')

trainer = ToyFedProxTrainer(10, 128, u=0.1)
trainer.set_model(model)
opt = t.optim.Adam(model.parameters(), lr=0.01)
loss = t.nn.BCELoss()

trainer.local_mode()
trainer.train(ds, None, opt, loss)


running epoch 0
epoch loss is 1.0665020644664764
running epoch 1
epoch loss is 0.9155551195144653
running epoch 2
epoch loss is 0.8021544218063354
running epoch 3
epoch loss is 0.7173515558242798
running epoch 4
epoch loss is 0.6532197296619415
running epoch 5
epoch loss is 0.6034933030605316
running epoch 6
epoch loss is 0.5636875331401825
running epoch 7
epoch loss is 0.5307579338550568
running epoch 8
epoch loss is 0.5026698857545853
running epoch 9
epoch loss is 0.47806812822818756


Great! It can work! Then we will submit a federated task to see if our trainer works correctly.

## Submit a New Task to Test ToyFedProx

In [38]:
# torch
import torch as t
from torch import nn
from pipeline import fate_torch_hook
fate_torch_hook(t)
# pipeline
from pipeline.component.homo_nn import HomoNN, TrainerParam # HomoNN Component, TrainerParam for setting trainer parameter
from pipeline.backend.pipeline import PipeLine # pipeline class
from pipeline.component import Reader, DataTransform, Evaluation # Data I/O and Evaluation
from pipeline.interface import Data # Data Interaces for defining data flow


# create a pipeline to submitting the job
guest = 9999
host = 10000
arbiter = 10000
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)

# read uploaded dataset
train_data_0 = {"name": "breast_homo_guest", "namespace": "experiment"}
train_data_1 = {"name": "breast_homo_host", "namespace": "experiment"}
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=train_data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=train_data_1)

# The transform component converts the uploaded data to the DATE standard format
data_transform_0 = DataTransform(name='data_transform_0')
data_transform_0.get_party_instance(
 role='guest', party_id=guest).component_param(
 with_label=True, output_format="dense")
data_transform_0.get_party_instance(
 role='host', party_id=host).component_param(
 with_label=True, output_format="dense")

"""
Define Pytorch model/ optimizer and loss
"""
model = nn.Sequential(
 nn.Linear(30, 1),
 nn.Sigmoid()
)
loss = nn.BCELoss()
optimizer = t.optim.Adam(model.parameters(), lr=0.01)


"""
Create Homo-NN Component
"""
nn_component = HomoNN(name='nn_0',
 model=model, # set model
 loss=loss, # set loss
 optimizer=optimizer, # set optimizer
 # Here we use fedavg trainer
 # TrainerParam passes parameters to fedavg_trainer, see below for details about Trainer
 trainer=TrainerParam(trainer_name='fedprox', epochs=3, batch_size=128, u=0.5),
 torch_seed=100 # random seed
 )

# define work flow
pipeline.add_component(reader_0)
pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))
pipeline.add_component(nn_component, data=Data(train_data=data_transform_0.output.data))
pipeline.compile()
pipeline.fit()

[32m2022-12-26 12:22:09.789[0m | [1mINFO [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m83[0m - [1mJob id is 202212261222090031360
[0m
[32m2022-12-26 12:22:09.821[0m | [1mINFO [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m98[0m - [1m[80D[1A[KJob is still waiting, time elapse: 0:00:00[0m
[32m2022-12-26 12:22:10.837[0m | [1mINFO [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m98[0m - [1m[80D[1A[KJob is still waiting, time elapse: 0:00:01[0m
[0mm2022-12-26 12:22:11.890[0m | [1mINFO [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m125[0m - [1m
[32m2022-12-26 12:22:11.892[0m | [1mINFO [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m127[0m - [1m[80D[1A[KRunning component reader_0, time elapse: 0:00:02[0m
[32m2022-12-26 12:22:12.916[0m | [1mINFO [0m | [36mpipeli

Yes! This trainer can work correctly. In the next tutorial, we will show you how to use the trainer user interfaces to improve this trainer. These interfaces allow you to return formatted prediction results, evaluate your model, save your model, and display loss curves and performance scores on the fateboard. By using these interfaces, you can enhance the functionality of the trainer and make it more user-friendly.