{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Using FATE-interfaces\n", "\n", "In this tutorial, we will demonstrate how to use the trainer user interfaces to return formatted prediction results, evaluate the performance of your model, save your model, and display loss curves and performance scores on the dashboard. These interfaces allow your trainer to integrate with the FATE framework and make it easier to work with.\n", "\n", "In this tutorial, we will continue to develop our toy FedProx trainer." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## The Toy Implementation of FedProx\n", "\n", "[In last tutorial](./Homo-NN-Customize-Trainer.ipynb), we provide a concrete example by demonstrating a toy implement of the FedProx algorithm from https://arxiv.org/abs/1812.06127. In FedProx, the training process is slightly different from the standard FedAVG algorithm as it requires the computation of a proximal term from the current model and the global model when calculating the loss. The codes is here:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from pipeline.component.nn import save_to_fate" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "%%save_to_fate trainer fedprox.py\n", "import copy\n", "import torch as t\n", "from federatedml.nn.homo.trainer.trainer_base import TrainerBase\n", "from torch.utils.data import DataLoader\n", "# We need to use aggregator client&server class for federation\n", "from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer\n", "# We use LOGGER to output logs\n", "from federatedml.util import LOGGER\n", "\n", "\n", "class ToyFedProxTrainer(TrainerBase):\n", "\n", " def __init__(self, epochs, batch_size, u):\n", " super(ToyFedProxTrainer, self).__init__()\n", " # trainer parameters\n", " self.epochs = epochs\n", " self.batch_size = batch_size\n", " self.u = u\n", "\n", " # Given two model, we compute the proximal term\n", " def _proximal_term(self, model_a, model_b):\n", " diff_ = 0\n", " for p1, p2 in zip(model_a.parameters(), model_b.parameters()):\n", " diff_ += (p1-p2.detach()).sum()\n", " return diff_\n", "\n", " # implement the train function, this function will be called by client side\n", " # contains the local training process and the federation part\n", " def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):\n", " \n", " sample_num = len(train_set)\n", " aggregator = None\n", " if self.fed_mode:\n", " aggregator = SecureAggregatorClient(True, aggregate_weight=sample_num, \n", " communicate_match_suffix='fedprox') # initialize aggregator\n", "\n", " # set dataloader\n", " dl = DataLoader(train_set, batch_size=self.batch_size, num_workers=4)\n", "\n", " for epoch in range(self.epochs):\n", " \n", " # the local training process\n", " LOGGER.debug('running epoch {}'.format(epoch))\n", " global_model = copy.deepcopy(self.model)\n", " loss_sum = 0\n", "\n", " # batch training process\n", " for batch_data, label in dl:\n", " optimizer.zero_grad()\n", " pred = self.model(batch_data)\n", " loss_term_a = loss(pred, label)\n", " loss_term_b = self._proximal_term(self.model, global_model)\n", " loss_ = loss_term_a + (self.u/2) * loss_term_b\n", " loss_.backward()\n", " loss_sum += float(loss_.detach().numpy())\n", " optimizer.step()\n", "\n", " # pring loss\n", " LOGGER.debug('epoch loss is {}'.format(loss_sum))\n", "\n", " # the aggregation process\n", " if aggregator is not None:\n", " self.model = aggregator.model_aggregation(self.model)\n", " converge_status = aggregator.loss_aggregation(loss_sum)\n", "\n", " # implement the aggregation function, this function will be called by the sever side\n", " def server_aggregate_procedure(self, extra_data={}):\n", " \n", " # initialize aggregator\n", " if self.fed_mode:\n", " aggregator = SecureAggregatorServer(communicate_match_suffix='fedprox')\n", "\n", " # the aggregation process is simple: every epoch the server aggregate model and loss once\n", " for i in range(self.epochs):\n", " aggregator.model_aggregation()\n", " merge_loss, _ = aggregator.loss_aggregation()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## User interfaces\n", "\n", "Now we introduce you the user-interfaces offered by the TrainerBase class, we will use these function to improve our trainer.\n", "\n", "### format_predict_result\n", "\n", "This function will organize your prediction results and return a StdReturnFormat object, which wraps the results. You can use this function at the end of your prediction function to return a standardized format that the FATE framework can parse and display them on the fateboard. This standardized format also allows downstream components, such as the evaluation component, to use the prediction results.\n", "\n", "This function takes four arguments:\n", "- sample_ids: a list of sample IDs\n", "- predict_result: a tensor of prediction scores\n", "- true_label: a tensor of true labels\n", "- task_type: the type of task being performed. The default is 'auto', which will automatically infer the task type. Other options include 'binary', 'multi', and 'regression'. Currently, the FATE dashboard only supports the display of binary/multi classification and regression tasks. If 'auto' is chosen, the task type will be inferred automatically.\n", "\n", "**We will implement a prediction in FedProx trainer later.**" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import torch as t \n", "from typing import List\n", "\n", "def format_predict_result(self, sample_ids: List, predict_result: t.Tensor,\n", " true_label: t.Tensor, task_type: str = None):\n", " ..." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### callback_metric & callback_loss\n", "\n", "As the names suggest, these two functions enable you to save data points and display custom evaluation metrics and loss curves on the fateboard.\n", "\n", "When using the callback metric function, you need to provide the metric name, a float value, and specify the metric type ('train' or 'validate') and the epoch index. When using the callback loss function, you need to provide a float loss value and the epoch index. Your data will be displayed on the fateboard." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def callback_metric(self, metric_name: str, value: float, metric_type='train', epoch_idx=0):\n", " ...\n", "\n", "def callback_loss(self, loss: float, epoch_idx: int):\n", " ..." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### summary\n", "\n", "This function allows you to save a summary of the training process, such as the loss history and the best epoch, in a dictionary. You can retrieve this summary from the pipeline once the task is completed." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def summary(self, summary_dict: dict):\n", " ..." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### save & checkpoint\n", " \n", " You can save your models using the 'save' and set model checkpoint using 'checkpoint' function. It's important to note that:\n", "\n", "- 'save' only stores the model in memory, so the model you save will be the one that was last saved using the 'save' function.\n", "- 'checkpoint' directly saves the model to disk.\n", "- 'save' should only be called on the client side (in the 'train' function), while 'checkpoint' should be called on both the client and server side (in the 'train' and 'server_aggregate_procedure' functions) to ensure that the checkpoint mechanism works correctly.\n", "\n", "The 'extra_data' parameter in the function allows you to save additional data in a dictionary. This can be useful when warm-starting a model, as you can retrieve the saved data using the 'extra_data' parameter in the 'train' and 'server_aggregate_procedure' functions." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def save(\n", " self,\n", " model=None,\n", " epoch_idx=-1,\n", " optimizer=None,\n", " converge_status=False,\n", " loss_history=None,\n", " best_epoch=-1,\n", " extra_data={}): ...\n", "\n", "def checkpoint(\n", " self,\n", " epoch_idx,\n", " model=None,\n", " optimizer=None,\n", " converge_status=False,\n", " loss_history=None,\n", " best_epoch=-1,\n", " extra_data={}): ..." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### evaluation\n", "\n", "This interface allows you to evaluate your model by automatically computing various performance metrics\n", "The metrics that are computed depend on the type of your dataset and task\n", "\n", "- Binary classification: 'AUC' and 'ks'\n", "- Multi-class classification: 'accuracy', 'precision', and 'recall'\n", "- Regression: 'rmse' and 'mae'\n", " \n", "You can specify the type of your dataset ('train' or 'validate') and the task type ('binary', 'multi', or 'regression') in the parameters. If no task type is specified, it will be automatically inferred from your scores and labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def evaluation(self, sample_ids: list, pred_scores: t.Tensor, label: t.Tensor, dataset_type='train',\n", " epoch_idx=0, task_type=None):\n", " ..." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Improved FedProx Trainer\n", "\n", "In this section, we will use the interfaces introduced earlier to improve our FedProx Trainer and make it a more comprehensive training tool. We:\n", "\n", "- we implement the predict function, and it return a formatted result\n", "- add evaluation function\n", "- save model at the end of the training\n", "- callback loss to save loss curves\n", "- we compute accuracy scores and display then using callback metrics." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from pipeline.component.nn import save_to_fate" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "%%save_to_fate trainer fedprox_v2.py\n", "import copy\n", "import torch as t\n", "from federatedml.nn.homo.trainer.trainer_base import TrainerBase\n", "from federatedml.nn.dataset.base import Dataset\n", "from torch.utils.data import DataLoader\n", "# We need to use aggregator client&server class for federation\n", "from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient, SecureAggregatorServer\n", "# We use LOGGER to output logs\n", "from federatedml.util import LOGGER\n", "\n", "\n", "class ToyFedProxTrainer(TrainerBase):\n", "\n", " def __init__(self, epochs, batch_size, u):\n", " super(ToyFedProxTrainer, self).__init__()\n", " # trainer parameters\n", " self.epochs = epochs\n", " self.batch_size = batch_size\n", " self.u = u\n", "\n", " # Given two model, we compute the proximal term\n", " def _proximal_term(self, model_a, model_b):\n", " diff_ = 0\n", " for p1, p2 in zip(model_a.parameters(), model_b.parameters()):\n", " diff_ += t.sqrt((p1-p2.detach())**2).sum()\n", " return diff_\n", "\n", " # implement the train function, this function will be called by client side\n", " # contains the local training process and the federation part\n", " def train(self, train_set, validate_set=None, optimizer=None, loss=None, extra_data={}):\n", " \n", " sample_num = len(train_set)\n", " aggregator = None\n", " if self.fed_mode:\n", " aggregator = SecureAggregatorClient(True, aggregate_weight=sample_num, \n", " communicate_match_suffix='fedprox') # initialize aggregator\n", "\n", " # set dataloader\n", " dl = DataLoader(train_set, batch_size=self.batch_size, num_workers=4)\n", "\n", " loss_history = []\n", " for epoch in range(self.epochs):\n", " \n", " # the local training process\n", " LOGGER.debug('running epoch {}'.format(epoch))\n", " global_model = copy.deepcopy(self.model)\n", " loss_sum = 0\n", "\n", " # batch training process\n", " for batch_data, label in dl:\n", " optimizer.zero_grad()\n", " pred = self.model(batch_data)\n", " loss_term_a = loss(pred, label)\n", " loss_term_b = self._proximal_term(self.model, global_model)\n", " loss_ = loss_term_a + (self.u/2) * loss_term_b\n", " LOGGER.debug('loss is {} loss a is {} loss b is {}'.format(loss_, loss_term_a, loss_term_b))\n", " loss_.backward()\n", " loss_sum += float(loss_.detach().numpy())\n", " optimizer.step()\n", " \n", " # print loss\n", " LOGGER.debug('epoch loss is {}'.format(loss_sum))\n", " loss_history.append(loss_sum)\n", "\n", " # we callback loss here\n", " self.callback_loss(loss_sum, epoch)\n", "\n", " # we evaluate out model here\n", " sample_ids, preds, labels = self._predict(train_set)\n", " self.evaluation(sample_ids, preds, labels, 'train', task_type='binary', epoch_idx=epoch)\n", "\n", " # we manually compute accuracy:\n", " acc = ((preds > 0.5 + 0) == labels).sum() / len(labels)\n", " acc = float(acc.detach().numpy())\n", " self.callback_metric('my_accuracy', acc, epoch_idx=epoch)\n", "\n", " # the aggregation process\n", " if aggregator is not None:\n", " self.model = aggregator.model_aggregation(self.model)\n", " converge_status = aggregator.loss_aggregation(loss_sum)\n", "\n", " # We will save model at the end of the training\n", " self.save(self.model, epoch, optimizer)\n", " # We will save model summary\n", " self.summary({'loss_history': loss_history})\n", "\n", " # implement the aggregation function, this function will be called by the sever side\n", " def server_aggregate_procedure(self, extra_data={}):\n", " \n", " # initialize aggregator\n", " if self.fed_mode:\n", " aggregator = SecureAggregatorServer(communicate_match_suffix='fedprox')\n", "\n", " # the aggregation process is simple: every epoch the server aggregate model and loss once\n", " for i in range(self.epochs):\n", " aggregator.model_aggregation()\n", " merge_loss, _ = aggregator.loss_aggregation()\n", "\n", "\n", " def _predict(self, dataset: Dataset):\n", " len_ = len(dataset)\n", " dl = DataLoader(dataset, batch_size=len_)\n", " preds, labels = None, None\n", " for data, l in dl:\n", " preds = self.model(data)\n", " labels = l\n", " sample_ids = dataset.get_sample_ids()\n", " return sample_ids, preds, labels\n", "\n", " # We implement the predict function here\n", " def predict(self, dataset):\n", " \n", " sample_ids, preds, labels = self._predict(dataset)\n", " return self.format_predict_result(sample_ids, preds, labels, 'binary')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit a Pipeline\n", "\n", "Here we submit a new pipeline to test our new trainer" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2022-12-26 16:13:37.650\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212261613368298170\n", "\u001b[0m\n", "\u001b[32m2022-12-26 16:13:37.665\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n", "\u001b[0mm2022-12-26 16:13:38.698\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-26 16:13:38.700\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:01\u001b[0m\n", "\u001b[32m2022-12-26 16:13:39.722\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:02\u001b[0m\n", "\u001b[32m2022-12-26 16:13:40.754\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:03\u001b[0m\n", "\u001b[32m2022-12-26 16:13:41.783\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:04\u001b[0m\n", "\u001b[32m2022-12-26 16:13:42.810\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:05\u001b[0m\n", "\u001b[32m2022-12-26 16:13:43.934\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:06\u001b[0m\n", "\u001b[32m2022-12-26 16:13:45.045\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:07\u001b[0m\n", "\u001b[32m2022-12-26 16:13:46.081\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:08\u001b[0m\n", "\u001b[32m2022-12-26 16:13:47.107\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:09\u001b[0m\n", "\u001b[0mm2022-12-26 16:13:48.155\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-26 16:13:48.160\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:10\u001b[0m\n", "\u001b[32m2022-12-26 16:13:49.233\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:11\u001b[0m\n", "\u001b[32m2022-12-26 16:13:50.274\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:12\u001b[0m\n", "\u001b[32m2022-12-26 16:13:51.299\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:13\u001b[0m\n", "\u001b[32m2022-12-26 16:13:52.342\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:14\u001b[0m\n", "\u001b[32m2022-12-26 16:13:53.377\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:15\u001b[0m\n", "\u001b[32m2022-12-26 16:13:54.448\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:16\u001b[0m\n", "\u001b[32m2022-12-26 16:13:55.483\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:17\u001b[0m\n", "\u001b[32m2022-12-26 16:13:56.546\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:18\u001b[0m\n", "\u001b[0mm2022-12-26 16:13:58.785\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-26 16:13:58.791\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:21\u001b[0m\n", "\u001b[32m2022-12-26 16:13:59.883\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:22\u001b[0m\n", "\u001b[32m2022-12-26 16:14:00.921\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:23\u001b[0m\n", "\u001b[32m2022-12-26 16:14:01.997\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:24\u001b[0m\n", "\u001b[32m2022-12-26 16:14:03.025\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:25\u001b[0m\n", "\u001b[32m2022-12-26 16:14:04.053\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:26\u001b[0m\n", "\u001b[32m2022-12-26 16:14:05.094\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:27\u001b[0m\n", "\u001b[32m2022-12-26 16:14:06.134\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:28\u001b[0m\n", "\u001b[32m2022-12-26 16:14:07.175\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:29\u001b[0m\n", "\u001b[32m2022-12-26 16:14:08.221\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:30\u001b[0m\n", "\u001b[32m2022-12-26 16:14:09.268\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:31\u001b[0m\n", "\u001b[32m2022-12-26 16:14:10.292\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:32\u001b[0m\n", "\u001b[32m2022-12-26 16:14:11.476\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:33\u001b[0m\n", "\u001b[32m2022-12-26 16:14:12.497\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:34\u001b[0m\n", "\u001b[32m2022-12-26 16:14:13.542\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:35\u001b[0m\n", "\u001b[32m2022-12-26 16:14:14.577\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:36\u001b[0m\n", "\u001b[32m2022-12-26 16:14:15.608\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:37\u001b[0m\n", "\u001b[32m2022-12-26 16:14:17.720\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! Job id is 202212261613368298170\u001b[0m\n", "\u001b[32m2022-12-26 16:14:17.721\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:00:40\u001b[0m\n" ] } ], "source": [ "# torch\n", "import torch as t\n", "from torch import nn\n", "from pipeline import fate_torch_hook\n", "fate_torch_hook(t)\n", "# pipeline\n", "from pipeline.component.homo_nn import HomoNN, TrainerParam # HomoNN Component, TrainerParam for setting trainer parameter\n", "from pipeline.backend.pipeline import PipeLine # pipeline class\n", "from pipeline.component import Reader, DataTransform, Evaluation # Data I/O and Evaluation\n", "from pipeline.interface import Data # Data Interaces for defining data flow\n", "\n", "\n", "# create a pipeline to submitting the job\n", "guest = 9999\n", "host = 10000\n", "arbiter = 10000\n", "pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)\n", "\n", "# read uploaded dataset\n", "train_data_0 = {\"name\": \"breast_homo_guest\", \"namespace\": \"experiment\"}\n", "train_data_1 = {\"name\": \"breast_homo_host\", \"namespace\": \"experiment\"}\n", "reader_0 = Reader(name=\"reader_0\")\n", "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=train_data_0)\n", "reader_0.get_party_instance(role='host', party_id=host).component_param(table=train_data_1)\n", "\n", "# The transform component converts the uploaded data to the DATE standard format\n", "data_transform_0 = DataTransform(name='data_transform_0')\n", "data_transform_0.get_party_instance(\n", " role='guest', party_id=guest).component_param(\n", " with_label=True, output_format=\"dense\")\n", "data_transform_0.get_party_instance(\n", " role='host', party_id=host).component_param(\n", " with_label=True, output_format=\"dense\")\n", "\n", "\"\"\"\n", "Define Pytorch model/ optimizer and loss\n", "\"\"\"\n", "model = nn.Sequential(\n", " nn.Linear(30, 1),\n", " nn.Sigmoid()\n", ")\n", "loss = nn.BCELoss()\n", "optimizer = t.optim.Adam(model.parameters(), lr=0.01)\n", "\n", "\n", "\"\"\"\n", "Create Homo-NN Component\n", "\"\"\"\n", "nn_component = HomoNN(name='nn_0',\n", " model=model, # set model\n", " loss=loss, # set loss\n", " optimizer=optimizer, # set optimizer\n", " # Here we use fedavg trainer\n", " # TrainerParam passes parameters to fedavg_trainer, see below for details about Trainer\n", " trainer=TrainerParam(trainer_name='fedprox_v2', epochs=3, batch_size=128, u=0.5),\n", " torch_seed=100 # random seed\n", " )\n", "\n", "# define work flow\n", "pipeline.add_component(reader_0)\n", "pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))\n", "pipeline.add_component(nn_component, data=Data(train_data=data_transform_0.output.data))\n", "pipeline.compile()\n", "pipeline.fit()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]" }, "vscode": { "interpreter": { "hash": "d29574a2ab71ec988cdcd4d29c58400bd2037cad632b9528d973466f7fb6f853" } } }, "nbformat": 4, "nbformat_minor": 2 }