{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "3ee4aa0e", "metadata": {}, "source": [ "# Homo-NN Quick Start: A Binary Classification Task" ] }, { "attachments": {}, "cell_type": "markdown", "id": "deff6c23", "metadata": {}, "source": [ "This tutorial allows you to quickly get started using Homo-NN. By default, you can use the Homo-NN component in the same process as other FATE algorithm components: use the reader and transformer interfaces that come with FATE to input table data and convert the data format, and then input the data into the algorithm component. Then NN component will use your defined model, optimizer and loss function for training and model aggregation.\n", "\n", "In FATE-1.10, Homo-NN in the pipeline has added support for pytorch. You can follow the usage of pytorch Sequential, use the built-in layers of Pytorch to define the Sequential model and submit the model. At the same time, you can use the loss function and optimizer that comes with Pytorch.\n", "\n", "The following is a basic binary classification task Homo-NN task. There are two clients with party ids of 10000 and 9999 respectively, and 10000 is specified as the server-side aggregation model." ] }, { "attachments": {}, "cell_type": "markdown", "id": "43fd3f74", "metadata": {}, "source": [ "## Uploading Tabular Data\n", "\n", "At the very beginning, we upload data to FATE. We can directly upload data using the pipeline. Here we upload two files: breast_homo_guest.csv for the guest, and breast_homo_host.csv for the host. Please notice that in this tutorial we are using a standalone version, if you are using a cluster version, you need to upload corresponding data on each machine. 
" ] }, { "cell_type": "code", "execution_count": 7, "id": "611ca3bf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " UPLOADING:||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||100.00%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2022-12-19 10:40:32.733\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212191040322910830\n", "\u001b[0m\n", "\u001b[32m2022-12-19 10:40:32.747\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n", "\u001b[0mm2022-12-19 10:40:33.781\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 10:40:33.788\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:01\u001b[0m\n", "\u001b[32m2022-12-19 10:40:34.810\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:02\u001b[0m\n", "\u001b[32m2022-12-19 10:40:35.835\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:03\u001b[0m\n", "\u001b[32m2022-12-19 10:40:36.856\u001b[0m | 
\u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:04\u001b[0m\n", "\u001b[32m2022-12-19 10:40:37.887\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:05\u001b[0m\n", "\u001b[32m2022-12-19 10:40:38.912\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:06\u001b[0m\n", "\u001b[32m2022-12-19 10:40:39.998\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! 
Job id is 202212191040322910830\u001b[0m\n", "\u001b[32m2022-12-19 10:40:40.001\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:00:07\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " UPLOADING:||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||100.00%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2022-12-19 10:40:40.706\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212191040400256350\n", "\u001b[0m\n", "\u001b[32m2022-12-19 10:40:40.748\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n", "\u001b[32m2022-12-19 10:40:41.769\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:01\u001b[0m\n", "\u001b[32m2022-12-19 10:40:42.806\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:02\u001b[0m\n", "\u001b[0mm2022-12-19 10:40:43.830\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 10:40:43.832\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:03\u001b[0m\n", "\u001b[32m2022-12-19 10:40:44.852\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:04\u001b[0m\n", "\u001b[32m2022-12-19 10:40:45.872\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:05\u001b[0m\n", "\u001b[32m2022-12-19 10:40:46.893\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:06\u001b[0m\n", "\u001b[32m2022-12-19 10:40:47.925\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:07\u001b[0m\n", "\u001b[32m2022-12-19 10:40:48.951\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:08\u001b[0m\n", "\u001b[32m2022-12-19 10:40:49.969\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! 
Job id is 202212191040400256350\u001b[0m\n", "\u001b[32m2022-12-19 10:40:49.971\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:00:09\u001b[0m\n" ] } ], "source": [ "from pipeline.backend.pipeline import PipeLine # pipeline Class\n", "\n", "# [9999(guest), 10000(host)] as client\n", "# [10000(arbiter)] as server\n", "\n", "guest = 9999\n", "host = 10000\n", "arbiter = 10000\n", "pipeline_upload = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)\n", "\n", "partition = 4\n", "\n", "# upload a dataset\n", "path_to_fate_project = '../../../../'\n", "guest_data = {\"name\": \"breast_homo_guest\", \"namespace\": \"experiment\"}\n", "host_data = {\"name\": \"breast_homo_host\", \"namespace\": \"experiment\"}\n", "pipeline_upload.add_upload_data(file=\"examples/data/breast_homo_guest.csv\", # file under examples/data\n", " table_name=guest_data[\"name\"], # table name\n", " namespace=guest_data[\"namespace\"], # namespace\n", " head=1, partition=partition) # data info\n", "pipeline_upload.add_upload_data(file=\"examples/data/breast_homo_host.csv\", # file under examples/data\n", " table_name=host_data[\"name\"], # table name\n", " namespace=host_data[\"namespace\"], # namespace\n", " head=1, partition=partition) # data info\n", "\n", "\n", "pipeline_upload.upload(drop=1)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "afa1afad", "metadata": {}, "source": [ "The breast dataset is a binary classification dataset with 30 features:" ] }, { "cell_type": "code", "execution_count": 11, "id": "d9580f9e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idyx0x1x2x3x4x5x6x7...x20x21x22x23x24x25x26x27x28x29
013310.254879-1.0466330.2096560.074214-0.441366-0.377645-0.4859340.347072...-0.337360-0.728193-0.442587-0.272757-0.608018-0.577235-0.5011260.143371-0.466431-0.554102
12731-1.142928-0.781198-1.166747-0.9235780.628230-1.021418-1.111867-0.959523...-0.4936390.348620-0.552483-0.5268772.253098-0.827620-0.780739-0.376997-0.3102390.176301
21751-1.451067-1.406518-1.456564-1.092337-0.708765-1.168557-1.305831-1.745063...-0.666881-0.779358-0.708418-0.6375450.710369-0.976454-1.057501-1.9134470.795207-0.149751
35511-0.8799330.420589-0.877527-0.780484-1.037534-0.483880-0.555498-0.768581...-0.4517720.453852-0.431696-0.494754-1.1820410.2812280.084759-0.2524201.0385750.351054
419900.4267580.7234790.3168850.2872731.0008350.9627021.0770991.053586...-0.707304-1.026834-0.702973-0.460212-0.999033-0.531406-0.394360-0.728830-0.644416-0.688003
..................................................................
22210500.008451-0.533675-0.025652-0.0938432.3597480.9900561.7530701.278939...-0.051872-0.531700-0.225763-0.1249050.0403420.2035420.7571830.338023-0.6141461.249401
2232421-0.7639670.371736-0.598731-0.7166710.1021991.4665242.2616080.109537...-0.5860350.771362-0.246060-0.526877-0.1259981.8813461.8868430.217988-0.0717151.845903
2243121-0.430564-1.510738-0.453377-0.460192-0.568490-0.212884-0.457149-0.464354...-0.283944-1.011412-0.257445-0.333482-0.1823340.123061-0.017365-0.179426-0.3913620.225852
2254731-0.5838052.014830-0.660686-0.565491-1.672278-1.285861-1.305831-1.745063...0.1455524.4091220.008881-0.1145650.099344-0.963264-1.057501-1.9134471.315845-0.249231
2263641-0.318739-0.647666-0.402145-0.381613-0.485202-0.551311-0.651448-0.681180...-0.890652-1.096686-0.905935-0.596622-0.882362-0.725342-0.576392-1.023890-0.924107-0.650934
\n", "

227 rows × 32 columns

\n", "
" ], "text/plain": [ " id y x0 x1 x2 x3 x4 x5 \\\n", "0 133 1 0.254879 -1.046633 0.209656 0.074214 -0.441366 -0.377645 \n", "1 273 1 -1.142928 -0.781198 -1.166747 -0.923578 0.628230 -1.021418 \n", "2 175 1 -1.451067 -1.406518 -1.456564 -1.092337 -0.708765 -1.168557 \n", "3 551 1 -0.879933 0.420589 -0.877527 -0.780484 -1.037534 -0.483880 \n", "4 199 0 0.426758 0.723479 0.316885 0.287273 1.000835 0.962702 \n", ".. ... .. ... ... ... ... ... ... \n", "222 105 0 0.008451 -0.533675 -0.025652 -0.093843 2.359748 0.990056 \n", "223 242 1 -0.763967 0.371736 -0.598731 -0.716671 0.102199 1.466524 \n", "224 312 1 -0.430564 -1.510738 -0.453377 -0.460192 -0.568490 -0.212884 \n", "225 473 1 -0.583805 2.014830 -0.660686 -0.565491 -1.672278 -1.285861 \n", "226 364 1 -0.318739 -0.647666 -0.402145 -0.381613 -0.485202 -0.551311 \n", "\n", " x6 x7 ... x20 x21 x22 x23 \\\n", "0 -0.485934 0.347072 ... -0.337360 -0.728193 -0.442587 -0.272757 \n", "1 -1.111867 -0.959523 ... -0.493639 0.348620 -0.552483 -0.526877 \n", "2 -1.305831 -1.745063 ... -0.666881 -0.779358 -0.708418 -0.637545 \n", "3 -0.555498 -0.768581 ... -0.451772 0.453852 -0.431696 -0.494754 \n", "4 1.077099 1.053586 ... -0.707304 -1.026834 -0.702973 -0.460212 \n", ".. ... ... ... ... ... ... ... \n", "222 1.753070 1.278939 ... -0.051872 -0.531700 -0.225763 -0.124905 \n", "223 2.261608 0.109537 ... -0.586035 0.771362 -0.246060 -0.526877 \n", "224 -0.457149 -0.464354 ... -0.283944 -1.011412 -0.257445 -0.333482 \n", "225 -1.305831 -1.745063 ... 0.145552 4.409122 0.008881 -0.114565 \n", "226 -0.651448 -0.681180 ... -0.890652 -1.096686 -0.905935 -0.596622 \n", "\n", " x24 x25 x26 x27 x28 x29 \n", "0 -0.608018 -0.577235 -0.501126 0.143371 -0.466431 -0.554102 \n", "1 2.253098 -0.827620 -0.780739 -0.376997 -0.310239 0.176301 \n", "2 0.710369 -0.976454 -1.057501 -1.913447 0.795207 -0.149751 \n", "3 -1.182041 0.281228 0.084759 -0.252420 1.038575 0.351054 \n", "4 -0.999033 -0.531406 -0.394360 -0.728830 -0.644416 -0.688003 \n", ".. ... ... 
... ... ... ... \n", "222 0.040342 0.203542 0.757183 0.338023 -0.614146 1.249401 \n", "223 -0.125998 1.881346 1.886843 0.217988 -0.071715 1.845903 \n", "224 -0.182334 0.123061 -0.017365 -0.179426 -0.391362 0.225852 \n", "225 0.099344 -0.963264 -1.057501 -1.913447 1.315845 -0.249231 \n", "226 -0.882362 -0.725342 -0.576392 -1.023890 -0.924107 -0.650934 \n", "\n", "[227 rows x 32 columns]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "df = pd.read_csv('../../../../examples/data/breast_homo_guest.csv')\n", "df" ] }, { "attachments": {}, "cell_type": "markdown", "id": "53e6fc84", "metadata": {}, "source": [ "## Write the Pipeline script and execute it\n", "\n", "After the upload is complete, we can start writing the pipeline script to submit a FATE task." ] }, { "attachments": {}, "cell_type": "markdown", "id": "9a5ccfe0", "metadata": {}, "source": [ "### Import Pipeline Components" ] }, { "cell_type": "code", "execution_count": 8, "id": "d4eec107", "metadata": { "pycharm": { "is_executing": true } }, "outputs": [], "source": [ "# torch\n", "import torch as t\n", "from torch import nn\n", "\n", "# pipeline\n", "from pipeline.component.homo_nn import HomoNN, TrainerParam # HomoNN Component, TrainerParam for setting trainer parameter\n", "from pipeline.backend.pipeline import PipeLine # pipeline class\n", "from pipeline.component import Reader, DataTransform, Evaluation # Data I/O and Evaluation\n", "from pipeline.interface import Data # Data Interaces for defining data flow" ] }, { "attachments": {}, "cell_type": "markdown", "id": "b7f7d9d5", "metadata": {}, "source": [ "We can check the parameters of the Homo-NN component:" ] }, { "cell_type": "code", "execution_count": 19, "id": "fbdf01b1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", " Parameters\n", " ----------\n", " name, name of this component\n", " trainer, trainer param\n", " dataset, 
dataset param\n", "    torch_seed, global random seed\n", "    loss, loss function from fate_torch\n", "    optimizer, optimizer from fate_torch\n", "    model, a fate torch sequential defining the model structure\n", "    \n" ] } ], "source": [ "print(HomoNN.__doc__)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "48edc92d", "metadata": {}, "source": [ "### fate_torch_hook\n", "\n", "Be sure to execute the fate_torch_hook function below before building your model: it patches several torch classes so that the torch layers, Sequential models, optimizers, and loss functions you define in your script can be parsed and submitted by the pipeline. " ] }, { "cell_type": "code", "execution_count": 9, "id": "955db238", "metadata": {}, "outputs": [], "source": [ "from pipeline import fate_torch_hook\n", "t = fate_torch_hook(t)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ee5be800", "metadata": {}, "source": [ "### pipeline" ] }, { "cell_type": "code", "execution_count": 12, "id": "cc9a174a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2022-12-19 10:50:38.106\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212191050351408970\n", "\u001b[0m\n", "\u001b[32m2022-12-19 10:50:38.118\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n", "\u001b[32m2022-12-19 10:50:39.139\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:01\u001b[0m\n", "\u001b[0mm2022-12-19 10:50:40.176\u001b[0m | \u001b[1mINFO    \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 10:50:40.183\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:02\u001b[0m\n", "\u001b[32m2022-12-19 10:50:41.217\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:03\u001b[0m\n", "\u001b[32m2022-12-19 10:50:42.256\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:04\u001b[0m\n", "\u001b[32m2022-12-19 10:50:43.287\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:05\u001b[0m\n", "\u001b[32m2022-12-19 10:50:44.330\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:06\u001b[0m\n", "\u001b[32m2022-12-19 10:50:45.357\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:07\u001b[0m\n", "\u001b[0mm2022-12-19 10:50:47.457\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 10:50:47.459\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:09\u001b[0m\n", "\u001b[32m2022-12-19 10:50:48.484\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:10\u001b[0m\n", "\u001b[32m2022-12-19 10:50:49.512\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:11\u001b[0m\n", "\u001b[32m2022-12-19 10:50:50.539\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:12\u001b[0m\n", "\u001b[32m2022-12-19 10:50:51.590\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:13\u001b[0m\n", "\u001b[32m2022-12-19 10:50:52.624\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:14\u001b[0m\n", "\u001b[32m2022-12-19 10:50:53.677\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:15\u001b[0m\n", "\u001b[32m2022-12-19 10:50:54.699\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:16\u001b[0m\n", "\u001b[32m2022-12-19 10:50:55.723\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:17\u001b[0m\n", "\u001b[0mm2022-12-19 10:50:56.782\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 10:50:56.785\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:18\u001b[0m\n", "\u001b[32m2022-12-19 10:50:57.814\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:19\u001b[0m\n", "\u001b[32m2022-12-19 10:50:58.839\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:20\u001b[0m\n", "\u001b[32m2022-12-19 10:50:59.865\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:21\u001b[0m\n", "\u001b[32m2022-12-19 10:51:00.898\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:22\u001b[0m\n", "\u001b[32m2022-12-19 10:51:01.925\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:23\u001b[0m\n", "\u001b[32m2022-12-19 10:51:02.953\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:24\u001b[0m\n", "\u001b[32m2022-12-19 10:51:04.020\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:25\u001b[0m\n", "\u001b[32m2022-12-19 10:51:05.055\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:26\u001b[0m\n", "\u001b[32m2022-12-19 10:51:06.080\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:27\u001b[0m\n", "\u001b[32m2022-12-19 10:51:07.132\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:29\u001b[0m\n", "\u001b[32m2022-12-19 10:51:08.164\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:30\u001b[0m\n", "\u001b[32m2022-12-19 10:51:09.207\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:31\u001b[0m\n", "\u001b[32m2022-12-19 10:51:10.237\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:32\u001b[0m\n", "\u001b[32m2022-12-19 10:51:11.258\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:33\u001b[0m\n", "\u001b[0mm2022-12-19 10:51:13.377\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 10:51:13.379\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:35\u001b[0m\n", "\u001b[32m2022-12-19 10:51:14.400\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:36\u001b[0m\n", "\u001b[32m2022-12-19 10:51:15.439\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:37\u001b[0m\n", "\u001b[32m2022-12-19 10:51:16.469\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:38\u001b[0m\n", "\u001b[32m2022-12-19 10:51:17.491\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:39\u001b[0m\n", "\u001b[32m2022-12-19 10:51:18.564\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:40\u001b[0m\n", "\u001b[32m2022-12-19 10:51:19.589\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:41\u001b[0m\n", "\u001b[32m2022-12-19 10:51:20.626\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:42\u001b[0m\n", "\u001b[32m2022-12-19 10:51:21.650\u001b[0m | \u001b[1mINFO 
\u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:43\u001b[0m\n", "\u001b[32m2022-12-19 10:51:23.691\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! Job id is 202212191050351408970\u001b[0m\n", "\u001b[32m2022-12-19 10:51:23.693\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:00:45\u001b[0m\n" ] } ], "source": [ "# create a pipeline for submitting the job\n", "guest = 9999\n", "host = 10000\n", "arbiter = 10000\n", "pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)\n", "\n", "# read the uploaded dataset\n", "train_data_0 = {\"name\": \"breast_homo_guest\", \"namespace\": \"experiment\"}\n", "train_data_1 = {\"name\": \"breast_homo_host\", \"namespace\": \"experiment\"}\n", "reader_0 = Reader(name=\"reader_0\")\n", "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=train_data_0)\n", "reader_0.get_party_instance(role='host', party_id=host).component_param(table=train_data_1)\n", "\n", "# The DataTransform component converts the uploaded data to the FATE standard format\n", "data_transform_0 = DataTransform(name='data_transform_0')\n", "data_transform_0.get_party_instance(\n", " role='guest', party_id=guest).component_param(\n", " with_label=True, output_format=\"dense\")\n", "data_transform_0.get_party_instance(\n", " role='host', party_id=host).component_param(\n", " with_label=True, output_format=\"dense\")\n", "\n", "\"\"\"\n", "Define the PyTorch model, optimizer and loss\n", "\"\"\"\n", "model = nn.Sequential(\n", " nn.Linear(30, 1),\n", " nn.Sigmoid()\n", ")\n", 
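"\n", "# Alternatively, a deeper model can be defined the same way with plain torch\n", "# layers once fate_torch_hook has patched torch. A commented sketch (the hidden\n", "# size of 16 is an illustrative assumption, not part of this tutorial's setup):\n", "# model = nn.Sequential(nn.Linear(30, 16), nn.ReLU(), nn.Linear(16, 1), nn.Sigmoid())\n",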
"loss = nn.BCELoss()\n", "optimizer = t.optim.Adam(model.parameters(), lr=0.01)\n", "\n", "\n", "\"\"\"\n", "Create Homo-NN Component\n", "\"\"\"\n", "nn_component = HomoNN(name='nn_0',\n", " model=model, # set model\n", " loss=loss, # set loss\n", " optimizer=optimizer, # set optimizer\n", " # Here we use the fedavg trainer\n", " # TrainerParam passes parameters to fedavg_trainer, see below for details about the Trainer\n", " trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=3, batch_size=128, validation_freqs=1),\n", " torch_seed=100 # random seed\n", " )\n", "\n", "# define the workflow\n", "pipeline.add_component(reader_0)\n", "pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))\n", "pipeline.add_component(nn_component, data=Data(train_data=data_transform_0.output.data))\n", "pipeline.add_component(Evaluation(name='eval_0'), data=Data(data=nn_component.output.data))\n", "\n", "pipeline.compile()\n", "pipeline.fit()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "af94b45d", "metadata": {}, "source": [ "## Get Component Output" ] }, { "cell_type": "code", "execution_count": 13, "id": "4f7325c1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " id label predict_result predict_score \\\n", "0 0 0.0 0 0.05885133519768715 \n", "1 3 0.0 1 0.5971069931983948 \n", "2 5 1.0 1 0.7218729257583618 \n", "3 7 1.0 1 0.6514894962310791 \n", "4 14 0.0 0 0.2351398915052414 \n", ".. ... ... ... ... \n", "222 551 1.0 0 0.38658156991004944 \n", "223 559 0.0 1 0.5517507195472717 \n", "224 562 0.0 0 0.39873841404914856 \n", "225 567 1.0 1 0.6306618452072144 \n", "226 568 0.0 1 0.5063760876655579 \n", "\n", " predict_detail type \n", "0 {'0': 0.9411486648023129, '1': 0.0588513351976... train \n", "1 {'0': 0.4028930068016052, '1': 0.5971069931983... train \n", "2 {'0': 0.2781270742416382, '1': 0.7218729257583... train \n", "3 {'0': 0.3485105037689209, '1': 0.6514894962310... train \n", "4 {'0': 0.7648601084947586, '1': 0.2351398915052... train \n", ".. ... ... \n", "222 {'0': 0.6134184300899506, '1': 0.3865815699100... train \n", "223 {'0': 0.44824928045272827, '1': 0.551750719547... train \n", "224 {'0': 0.6012615859508514, '1': 0.3987384140491... train \n", "225 {'0': 0.36933815479278564, '1': 0.630661845207... train \n", "226 {'0': 0.49362391233444214, '1': 0.506376087665... 
train \n", "\n", "[227 rows x 6 columns]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# get predict scores\n", "pipeline.get_component('nn_0').get_output_data()" ] }, { "cell_type": "code", "execution_count": 14, "id": "ab8afcfd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'best_epoch': 2,\n", " 'loss_history': [0.8317702709315632, 0.683187778825802, 0.5690162255375396],\n", " 'metrics_summary': {'train': {'auc': [0.732987012987013,\n", " 0.9094372294372294,\n", " 0.9561904761904763],\n", " 'ks': [0.4153246753246753, 0.6851948051948051, 0.7908225108225109]}},\n", " 'need_stop': False}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# get summary\n", "pipeline.get_component('nn_0').get_summary()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7d1df84b", "metadata": {}, "source": [ "## TrainerParam and the Trainer" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7c4a6605", "metadata": {}, "source": [ "In this version, Homo-NN's training logic and federated aggregation logic are both implemented in the Trainer class. fedavg_trainer is the default Trainer of FATE Homo-NN and implements the standard FedAvg algorithm. The function of TrainerParam is:\n", "\n", "- Use trainer_name='{module name}' to specify the trainer to use. Trainers live in the federatedml.nn.homo.trainer directory, so you can also customize your own trainer; a dedicated tutorial covers customizing the trainer\n", "- The remaining parameters will be passed to the \_\_init\_\_() interface of the trainer\n", "\n", "We can check the parameters of fedavg_trainer in FATE; these available parameters can be filled in TrainerParam." 
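, "\n", "\n", "For example, a TrainerParam that enables secure aggregation and loss-difference early stopping might look like the following (an illustrative sketch; the values are arbitrary, and all parameter names come from fedavg_trainer's documented parameters):\n", "\n", "```python\n", "trainer = TrainerParam(trainer_name='fedavg_trainer', epochs=10, batch_size=256,\n", "                       secure_aggregate=True, early_stop='diff', tol=0.0001,\n", "                       validation_freqs=1)\n", "```"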
] }, { "cell_type": "code", "execution_count": 16, "id": "d742b424", "metadata": {}, "outputs": [], "source": [ "from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer" ] }, { "attachments": {}, "cell_type": "markdown", "id": "a0e9a681", "metadata": {}, "source": [ "Check the documentation of FedAVGTrainer to learn about the available parameters. When submitting tasks, these parameters can be passed via TrainerParam." ] }, { "cell_type": "code", "execution_count": 17, "id": "041f1937", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", " Parameters\n", " ----------\n", " epochs: int >0, epochs to train\n", " batch_size: int, -1 means full batch\n", " secure_aggregate: bool, default is True, whether to use secure aggregation. if enabled, will add random number\n", " mask to local models. These random number masks will eventually cancel out to get 0.\n", " weighted_aggregation: bool, whether add weight to each local model when doing aggregation.\n", " if True, According to origin paper, weight of a client is: n_local / n_global, where n_local\n", " is the sample number locally and n_global is the sample number of all clients.\n", " if False, simply averaging these models.\n", "\n", " early_stop: None, 'diff' or 'abs'. if None, disable early stop; if 'diff', use the loss difference between\n", " two epochs as early stop condition, if differences < tol, stop training ; if 'abs', if loss < tol,\n", " stop training\n", " tol: float, tol value for early stop\n", "\n", " aggregate_every_n_epoch: None or int. if None, aggregate model on the end of every epoch, if int, aggregate\n", " every n epochs.\n", " cuda: bool, use cuda or not\n", " pin_memory: bool, for pytorch DataLoader\n", " shuffle: bool, for pytorch DataLoader\n", " data_loader_worker: int, for pytorch DataLoader, number of workers when loading data\n", " validation_freqs: None or int. 
if int, validate your model and send validate results to fate-board every n epoch.\n", " if is binary classification task, will use metrics 'auc', 'ks', 'gain', 'lift', 'precision'\n", " if is multi classification task, will use metrics 'precision', 'recall', 'accuracy'\n", " if is regression task, will use metrics 'mse', 'mae', 'rmse', 'explained_variance', 'r2_score'\n", " checkpoint_save_freqs: save model every n epoch, if None, will not save checkpoint.\n", " task_type: str, 'auto', 'binary', 'multi', 'regression'\n", " this option decides the return format of this trainer, and the evaluation type when running validation.\n", " if auto, will automatically infer your task type from labels and predict results.\n", " \n" ] } ], "source": [ "print(FedAVGTrainer.__doc__)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "bde0e5e8", "metadata": {}, "source": [ "So far, we have gained a basic understanding of Homo-NN and have used it to perform a basic modeling task. In addition, Homo-NN offers the ability to customize models, datasets, and Trainers for more advanced use cases. For further information, refer to the additional tutorials provided." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('venv': venv)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]" }, "vscode": { "interpreter": { "hash": "d29574a2ab71ec988cdcd4d29c58400bd2037cad632b9528d973466f7fb6f853" } } }, "nbformat": 4, "nbformat_minor": 5 }