{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Customize trainer to control the training process\n", "\n", "In this tutorial, you will learn how to create and customize your own trainer to control the training process, make predictions, and aggregate results to meet your specific needs. We will first introduce you to the interfaces of the TrainerBase class that you need to implement. Then, we will provide a toy example of the FedProx algorithm (please note that this is just a toy example and should not be used in production) to help you better understand the concept of trainer customization." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## TrainerBase Class\n", "\n", "### Basics\n", "\n", "The TrainerBase Class is the base for all Homo-NN trainer in FATE. To create a custom trainer, you need to subclass the TrainerBase class located in [federatedml.homo.trainer_base](../../../../python/federatedml/nn/homo/trainer/trainer_base.py). There are two required functions that you must implement:\n", "\n", "- The 'train()' function: This function takes five parameters: a training dataset instance (must be a subclass of Dataset), a validation dataset instance (also a subclass of Dataset), an optimizer instance with initialized training parameters, a loss function, and an extra data dictionary that may contain additional data for a warmstart task. In this function, you can define the process of client-side training and federation for a Homo-NN task.\n", "\n", "- The 'server_aggregate_procedure()' function: This function takes one parameter, an extra data dictionary that may contain additional data for a warmstart task. It is called by the server and is where you can define the aggregation process.\n", "\n", "There is also an optional 'predict()' function that takes one parameter, a dataset, and allows you to define how your trainer makes predictions. If you want to use the FATE framework, you need to ensure that your return data is formatted correctly so that FATE can display it correctly (we will cover this in a later tutorial).\"\n", "\n", "In the Homo-NN client component, the 'set_model()' function is used to set the initialized model to the trainer. 
"\n",
"Here is the source code of these interfaces:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TrainerBase(object):\n",
"\n",
"    def __init__(self, **kwargs):\n",
"        ...\n",
"        self._model = None\n",
"        ...\n",
"    \n",
"    ...\n",
"    \n",
"    @property\n",
"    def model(self):\n",
"        if not hasattr(self, '_model'):\n",
"            raise AttributeError(\n",
"                'model variable is not initialized, remember to call'\n",
"                ' super(your_class, self).__init__()')\n",
"        if self._model is None:\n",
"            raise AttributeError(\n",
"                'model is not set, use set_model() function to set training model')\n",
"\n",
"        return self._model\n",
"\n",
"    @model.setter\n",
"    def model(self, val):\n",
"        self._model = val\n",
"\n",
"    @abc.abstractmethod\n",
"    def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):\n",
"        \"\"\"\n",
"        train_set : A Dataset instance; must be an instance of a subclass of Dataset (federatedml.nn.dataset.base),\n",
"                    for example, TableDataset() (from federatedml.nn.dataset.table)\n",
"\n",
"        validate_set : A Dataset instance, optional; must be an instance of a subclass of Dataset\n",
"                       (federatedml.nn.dataset.base), for example, TableDataset() (from federatedml.nn.dataset.table)\n",
"\n",
"        optimizer : A pytorch optimizer class instance, for example, t.optim.Adam(), t.optim.SGD()\n",
"\n",
"        loss : A pytorch Loss class, for example, nn.BCELoss(), nn.CrossEntropyLoss()\n",
"        \"\"\"\n",
"        pass\n",
"\n",
"    @abc.abstractmethod\n",
"    def predict(self, dataset):\n",
"        pass\n",
"\n",
"    @abc.abstractmethod\n",
"    def server_aggregate_procedure(self, extra_data={}):\n",
"        pass" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Fed mode / local mode\n",
"\n",
"The trainer has an attribute 'self.fed_mode' which is set to True when running a federated task. You can use this variable to determine whether your trainer is running in federated mode or in local debug mode. If you want to test the trainer locally, you can use the 'local_mode()' function to set 'self.fed_mode' to False." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Example: Develop A Toy FedProx\n",
"\n",
"To help you understand how to implement these functions, we will provide a concrete example: a toy implementation of the FedProx algorithm from https://arxiv.org/abs/1812.06127. In FedProx, the training process differs slightly from the standard FedAVG algorithm: when computing the loss, a proximal term between the current local model and the global model is added. We will walk you through the code with comments step by step.\n",
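"\n",
"Concretely, following the paper, at round $t$ each client $k$ minimizes a proximal objective instead of only its local loss $F_k(w)$:\n",
"\n",
"$$ h_k(w; w^t) = F_k(w) + \\frac{\\mu}{2} \\lVert w - w^t \\rVert^2 $$\n",
"\n",
"Here $w^t$ is the global model received from the server and $\\mu$ controls the strength of the proximal term; $\\mu$ corresponds to the parameter 'u' in the toy code below."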
] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "from pipeline.component.nn import save_to_fate" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "%%save_to_fate trainer fedprox.py\n", "import copy\n", "from federatedml.nn.homo.trainer.trainer_base import TrainerBase\n", "from torch.utils.data import DataLoader\n", "# We need to use aggregator client&server class for federation\n", "from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer\n", "# We use LOGGER to output logs\n", "from federatedml.util import LOGGER\n", "\n", "\n", "class ToyFedProxTrainer(TrainerBase):\n", "\n", " def __init__(self, epochs, batch_size, u):\n", " super(ToyFedProxTrainer, self).__init__()\n", " # trainer parameters\n", " self.epochs = epochs\n", " self.batch_size = batch_size\n", " self.u = u\n", "\n", " # Given two model, we compute the proximal term\n", " def _proximal_term(self, model_a, model_b):\n", " diff_ = 0\n", " for p1, p2 in zip(model_a.parameters(), model_b.parameters()):\n", " diff_ += t.sqrt((p1-p2.detach())**2).sum()\n", " return diff_\n", "\n", " # implement the train function, this function will be called by client side\n", " # contains the local training process and the federation part\n", " def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):\n", " \n", " sample_num = len(train_set)\n", " aggregator = None\n", " if self.fed_mode:\n", " aggregator = SecureAggregatorClient(True, aggregate_weight=sample_num, \n", " communicate_match_suffix='fedprox') # initialize aggregator\n", "\n", " # set dataloader\n", " dl = DataLoader(train_set, batch_size=self.batch_size, num_workers=4)\n", "\n", " for epoch in range(self.epochs):\n", " \n", " # the local training process\n", " LOGGER.debug('running epoch {}'.format(epoch))\n", " global_model = copy.deepcopy(self.model)\n", " loss_sum = 0\n", "\n", " # batch training process\n", " for batch_data, label in dl:\n", " optimizer.zero_grad()\n", " pred = self.model(batch_data)\n", " loss_term_a = loss(pred, label)\n", " loss_term_b = self._proximal_term(self.model, global_model)\n", " loss_ = loss_term_a + (self.u/2) * loss_term_b\n", " loss_.backward()\n", " loss_sum += float(loss_.detach().numpy())\n", " optimizer.step()\n", "\n", " # print loss\n", " LOGGER.debug('epoch loss is {}'.format(loss_sum))\n", "\n", " # the aggregation process\n", " if aggregator is not None:\n", " self.model = aggregator.model_aggregation(self.model)\n", " converge_status = aggregator.loss_aggregation(loss_sum)\n", "\n", " # implement the aggregation function, this function will be called by the sever side\n", " def server_aggregate_procedure(self, extra_data={}):\n", " \n", " # initialize aggregator\n", " if self.fed_mode:\n", " aggregator = SecureAggregatorServer(communicate_match_suffix='fedprox')\n", "\n", " # the aggregation process is simple: every epoch the server aggregate model and loss once\n", " for i in range(self.epochs):\n", " aggregator.model_aggregation()\n", " merge_loss, _ = aggregator.loss_aggregation()\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Local Test\n", "\n", "We can use local_mode() to locally test our new FedProx trainer." 
] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "running epoch 0\n", "epoch loss is 1.0665020644664764\n", "running epoch 1\n", "epoch loss is 0.9155551195144653\n", "running epoch 2\n", "epoch loss is 0.8021544218063354\n", "running epoch 3\n", "epoch loss is 0.7173515558242798\n", "running epoch 4\n", "epoch loss is 0.6532197296619415\n", "running epoch 5\n", "epoch loss is 0.6034933030605316\n", "running epoch 6\n", "epoch loss is 0.5636875331401825\n", "running epoch 7\n", "epoch loss is 0.5307579338550568\n", "running epoch 8\n", "epoch loss is 0.5026698857545853\n", "running epoch 9\n", "epoch loss is 0.47806812822818756\n" ] } ], "source": [ "import torch as t\n", "from federatedml.nn.dataset.table import TableDataset\n", "\n", "model = t.nn.Sequential(\n", " t.nn.Linear(30, 1),\n", " t.nn.Sigmoid()\n", ")\n", "\n", "ds = TableDataset()\n", "ds.load('../../../../examples/data/breast_homo_guest.csv')\n", "\n", "trainer = ToyFedProxTrainer(10, 128, u=0.1)\n", "trainer.set_model(model)\n", "opt = t.optim.Adam(model.parameters(), lr=0.01)\n", "loss = t.nn.BCELoss()\n", "\n", "trainer.local_mode()\n", "trainer.train(ds, None, opt, loss)\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Great! It can work! Then we will submit a federated task to see if our trainer works correctly.\n", "\n", "## Submit a New Task to Test ToyFedProx" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2022-12-26 12:22:09.789\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212261222090031360\n", "\u001b[0m\n", "\u001b[32m2022-12-26 12:22:09.821\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n", "\u001b[32m2022-12-26 12:22:10.837\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:01\u001b[0m\n", "\u001b[0mm2022-12-26 12:22:11.890\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-26 12:22:11.892\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:02\u001b[0m\n", "\u001b[32m2022-12-26 12:22:12.916\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:03\u001b[0m\n", "\u001b[32m2022-12-26 12:22:14.015\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:04\u001b[0m\n", "\u001b[32m2022-12-26 12:22:15.080\u001b[0m | 
\u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:05\u001b[0m\n", "\u001b[32m2022-12-26 12:22:16.241\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:06\u001b[0m\n", "\u001b[32m2022-12-26 12:22:17.336\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:07\u001b[0m\n", "\u001b[32m2022-12-26 12:22:18.413\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:08\u001b[0m\n", "\u001b[32m2022-12-26 12:22:19.478\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:09\u001b[0m\n", "\u001b[32m2022-12-26 12:22:20.570\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:10\u001b[0m\n", "\u001b[0mm2022-12-26 12:22:22.743\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-26 12:22:22.750\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:12\u001b[0m\n", "\u001b[32m2022-12-26 12:22:23.844\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:14\u001b[0m\n", "\u001b[32m2022-12-26 12:22:24.902\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:15\u001b[0m\n", "\u001b[32m2022-12-26 12:22:26.040\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:16\u001b[0m\n", "\u001b[32m2022-12-26 12:22:27.112\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:17\u001b[0m\n", "\u001b[32m2022-12-26 12:22:28.163\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:18\u001b[0m\n", "\u001b[32m2022-12-26 12:22:29.234\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:19\u001b[0m\n", "\u001b[32m2022-12-26 12:22:30.286\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:20\u001b[0m\n", "\u001b[32m2022-12-26 12:22:31.338\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:21\u001b[0m\n", "\u001b[32m2022-12-26 12:22:32.421\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:22\u001b[0m\n", "\u001b[32m2022-12-26 12:22:33.498\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:23\u001b[0m\n", "\u001b[32m2022-12-26 12:22:34.584\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:24\u001b[0m\n", "\u001b[32m2022-12-26 12:22:35.629\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:25\u001b[0m\n", "\u001b[0mm2022-12-26 12:22:37.781\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-26 12:22:37.788\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:27\u001b[0m\n", "\u001b[32m2022-12-26 12:22:38.862\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:29\u001b[0m\n", "\u001b[32m2022-12-26 12:22:39.914\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:30\u001b[0m\n", "\u001b[32m2022-12-26 12:22:40.966\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - 
\u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:31\u001b[0m\n", "\u001b[32m2022-12-26 12:22:42.090\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:32\u001b[0m\n", "\u001b[32m2022-12-26 12:22:43.141\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:33\u001b[0m\n", "\u001b[32m2022-12-26 12:22:44.184\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:34\u001b[0m\n", "\u001b[32m2022-12-26 12:22:45.894\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:36\u001b[0m\n", "\u001b[32m2022-12-26 12:22:46.966\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:37\u001b[0m\n", "\u001b[32m2022-12-26 12:22:48.009\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:38\u001b[0m\n", "\u001b[32m2022-12-26 12:22:49.069\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:39\u001b[0m\n", "\u001b[32m2022-12-26 12:22:50.120\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:40\u001b[0m\n", "\u001b[32m2022-12-26 12:22:51.172\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:41\u001b[0m\n", "\u001b[32m2022-12-26 12:22:52.225\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:42\u001b[0m\n", "\u001b[32m2022-12-26 12:22:53.272\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:43\u001b[0m\n", "\u001b[32m2022-12-26 12:22:54.318\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:44\u001b[0m\n", "\u001b[32m2022-12-26 
12:22:55.357\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:45\u001b[0m\n", "\u001b[32m2022-12-26 12:22:56.444\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:46\u001b[0m\n", "\u001b[32m2022-12-26 12:22:59.842\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! Job id is 202212261222090031360\u001b[0m\n", "\u001b[32m2022-12-26 12:22:59.847\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:00:50\u001b[0m\n" ] } ], "source": [ "# torch\n",
"import torch as t\n",
"from torch import nn\n",
"from pipeline import fate_torch_hook\n",
"fate_torch_hook(t)\n",
"# pipeline\n",
"from pipeline.component.homo_nn import HomoNN, TrainerParam  # HomoNN Component, TrainerParam for setting trainer parameters\n",
"from pipeline.backend.pipeline import PipeLine  # pipeline class\n",
"from pipeline.component import Reader, DataTransform, Evaluation  # Data I/O and Evaluation\n",
"from pipeline.interface import Data  # Data Interfaces for defining data flow\n",
"\n",
"\n",
"# create a pipeline for submitting the job\n",
"guest = 9999\n",
"host = 10000\n",
"arbiter = 10000\n",
"pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)\n",
"\n",
"# read the uploaded dataset\n",
"train_data_0 = {\"name\": \"breast_homo_guest\", \"namespace\": \"experiment\"}\n",
"train_data_1 = {\"name\": \"breast_homo_host\", \"namespace\": \"experiment\"}\n",
"reader_0 = Reader(name=\"reader_0\")\n",
"reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=train_data_0)\n",
"reader_0.get_party_instance(role='host', party_id=host).component_param(table=train_data_1)\n",
"\n",
"# The DataTransform component converts the uploaded data to the FATE standard format\n",
"data_transform_0 = DataTransform(name='data_transform_0')\n",
"data_transform_0.get_party_instance(\n",
"    role='guest', party_id=guest).component_param(\n",
"    with_label=True, output_format=\"dense\")\n",
"data_transform_0.get_party_instance(\n",
"    role='host', party_id=host).component_param(\n",
"    with_label=True, output_format=\"dense\")\n",
"\n",
"\"\"\"\n",
"Define the PyTorch model, optimizer and loss\n",
"\"\"\"\n",
"model = nn.Sequential(\n",
"    nn.Linear(30, 1),\n",
"    nn.Sigmoid()\n",
")\n",
"loss = nn.BCELoss()\n",
"optimizer = t.optim.Adam(model.parameters(), lr=0.01)\n",
"\n",
"\n",
"\"\"\"\n",
"Create the Homo-NN Component\n",
"\"\"\"\n",
"nn_component = HomoNN(name='nn_0',\n",
"                      model=model,  # set model\n",
"                      loss=loss,  # set loss\n",
"                      optimizer=optimizer,  # set optimizer\n",
"                      # Here we use our ToyFedProx trainer: trainer_name='fedprox' loads the\n",
"                      # fedprox module saved above, and the remaining parameters are passed to it\n",
"                      trainer=TrainerParam(trainer_name='fedprox', epochs=3, batch_size=128, u=0.5),\n",
"                      torch_seed=100  # random seed\n",
"                      )\n",
"\n",
"# define the workflow\n",
"pipeline.add_component(reader_0)\n",
"pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))\n",
"pipeline.add_component(nn_component, data=Data(train_data=data_transform_0.output.data))\n",
"pipeline.compile()\n",
"pipeline.fit()" ] }
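, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "As a quick recap of how the custom trainer is wired in (all names here come from the cells above): trainer_name='fedprox' points to the fedprox.py module we saved with %%save_to_fate, and the remaining keyword arguments are forwarded to ToyFedProxTrainer's constructor:\n",
"\n",
"```python\n",
"# trainer_name='fedprox' -> the fedprox.py we saved above via %%save_to_fate\n",
"# epochs / batch_size / u are forwarded to ToyFedProxTrainer.__init__\n",
"TrainerParam(trainer_name='fedprox', epochs=3, batch_size=128, u=0.5)\n",
"```" ] }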
"pipeline.add_component(nn_component, data=Data(train_data=data_transform_0.output.data))\n", "pipeline.compile()\n", "pipeline.fit()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Yes! This trainer can work correctly. In the next tutorial, we will show you how to use the trainer user interfaces to improve this trainer. These interfaces allow you to return formatted prediction results, evaluate your model, save your model, and display loss curves and performance scores on the fateboard. By using these interfaces, you can enhance the functionality of the trainer and make it more user-friendly." ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "d29574a2ab71ec988cdcd4d29c58400bd2037cad632b9528d973466f7fb6f853" } } }, "nbformat": 4, "nbformat_minor": 2 }