{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "3ee4aa0e", "metadata": {}, "source": [ "# Homo-NN Quick Start: A Binary Classification Task" ] }, { "attachments": {}, "cell_type": "markdown", "id": "deff6c23", "metadata": {}, "source": [ "This tutorial helps you get started with Homo-NN quickly. By default, the Homo-NN component follows the same workflow as other FATE algorithm components: the reader and data-transform components that ship with FATE load the table data and convert it into the required format, and the converted data is then fed into the algorithm component. The NN component then trains with the model, optimizer, and loss function you define, and aggregates the models across parties.\n", "\n", "In FATE-1.10, Homo-NN in the pipeline adds support for PyTorch. You can define a model the same way you would with PyTorch `Sequential`, building it from PyTorch's built-in layers, and submit it to the component. You can also use PyTorch's built-in loss functions and optimizers.\n", "\n", "The following is a basic Homo-NN binary classification task. There are two clients, with party IDs 9999 (guest) and 10000 (host), and party 10000 also acts as the server (arbiter) that aggregates the models." ] }, { "attachments": {}, "cell_type": "markdown", "id": "43fd3f74", "metadata": {}, "source": [ "## Uploading Tabular Data\n", "\n", "First, we upload the data to FATE, which can be done directly through the pipeline. Here we upload two files: breast_homo_guest.csv for the guest and breast_homo_host.csv for the host. Note that this tutorial uses the standalone version. If you are using a cluster version, you need to upload the corresponding data on each machine.
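In the homo (horizontal) setting, the parties hold different samples described by the same columns, so the guest and host files should share one header. The sketch below is a hypothetical local check you could run before uploading. It is not part of FATE. It assumes the column layout shown in this tutorial's data preview: an `id` column, a `y` label column, then the features `x0` to `x29`. All helper names here are illustrative.

```python
from typing import List, Sequence


def expected_columns(n_features: int = 30) -> List[str]:
    # Assumed layout, taken from the data preview: id, label y, then x0 .. x{n-1}.
    return ['id', 'y'] + ['x%d' % i for i in range(n_features)]


def parse_header(line: str) -> List[str]:
    # Split a CSV header line such as 'id,y,x0,...' into stripped column names.
    return [name.strip() for name in line.split(',')]


def read_header(path: str) -> List[str]:
    # Hypothetical helper: read only the first line of a local CSV file.
    with open(path, encoding='utf-8') as f:
        return parse_header(f.readline())


def check_homo_headers(guest_header: Sequence[str],
                       host_header: Sequence[str],
                       n_features: int = 30) -> bool:
    # Homo learning needs every party to share one feature schema,
    # so both headers must match the expected layout exactly.
    expected = expected_columns(n_features)
    for role, header in (('guest', guest_header), ('host', host_header)):
        if list(header) != expected:
            raise ValueError('%s header does not match the expected layout, got %r'
                             % (role, list(header)[:5]))
    return True
```

With local copies of the two files, you would call `check_homo_headers(read_header(guest_path), read_header(host_path))` before `pipeline_upload.upload(drop=1)`. This catches a schema mismatch locally, with a `ValueError`, before any upload job is submitted.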
" ] }, { "cell_type": "code", "execution_count": 7, "id": "611ca3bf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " UPLOADING:||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||100.00%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2022-12-19 10:40:32.733\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212191040322910830\n", "\u001b[0m\n", "\u001b[32m2022-12-19 10:40:32.747\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n", "\u001b[0mm2022-12-19 10:40:33.781\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 10:40:33.788\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:01\u001b[0m\n", "\u001b[32m2022-12-19 10:40:34.810\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:02\u001b[0m\n", "\u001b[32m2022-12-19 10:40:35.835\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:03\u001b[0m\n", "\u001b[32m2022-12-19 10:40:36.856\u001b[0m | 
\u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:04\u001b[0m\n", "\u001b[32m2022-12-19 10:40:37.887\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:05\u001b[0m\n", "\u001b[32m2022-12-19 10:40:38.912\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:06\u001b[0m\n", "\u001b[32m2022-12-19 10:40:39.998\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! 
Job id is 202212191040322910830\u001b[0m\n", "\u001b[32m2022-12-19 10:40:40.001\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:00:07\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " UPLOADING:||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||100.00%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2022-12-19 10:40:40.706\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212191040400256350\n", "\u001b[0m\n", "\u001b[32m2022-12-19 10:40:40.748\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n", "\u001b[32m2022-12-19 10:40:41.769\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:01\u001b[0m\n", "\u001b[32m2022-12-19 10:40:42.806\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:02\u001b[0m\n", "\u001b[0mm2022-12-19 10:40:43.830\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 10:40:43.832\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:03\u001b[0m\n", "\u001b[32m2022-12-19 10:40:44.852\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:04\u001b[0m\n", "\u001b[32m2022-12-19 10:40:45.872\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:05\u001b[0m\n", "\u001b[32m2022-12-19 10:40:46.893\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:06\u001b[0m\n", "\u001b[32m2022-12-19 10:40:47.925\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:07\u001b[0m\n", "\u001b[32m2022-12-19 10:40:48.951\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:08\u001b[0m\n", "\u001b[32m2022-12-19 10:40:49.969\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! 
Job id is 202212191040400256350\u001b[0m\n", "\u001b[32m2022-12-19 10:40:49.971\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:00:09\u001b[0m\n" ] } ], "source": [ "from pipeline.backend.pipeline import PipeLine  # pipeline class\n", "\n", "# [9999(guest), 10000(host)] as clients\n", "# [10000(arbiter)] as server\n", "\n", "guest = 9999\n", "host = 10000\n", "arbiter = 10000\n", "pipeline_upload = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host, arbiter=arbiter)\n", "\n", "partition = 4  # number of partitions for each uploaded table\n", "\n", "# upload a dataset\n", "path_to_fate_project = '../../../../'  # not used in this cell\n", "guest_data = {\"name\": \"breast_homo_guest\", \"namespace\": \"experiment\"}\n", "host_data = {\"name\": \"breast_homo_host\", \"namespace\": \"experiment\"}\n", "pipeline_upload.add_upload_data(file=\"examples/data/breast_homo_guest.csv\",  # file in examples/data\n", "                                table_name=guest_data[\"name\"],  # table name\n", "                                namespace=guest_data[\"namespace\"],  # namespace\n", "                                head=1, partition=partition)  # head=1: the file has a header row\n", "pipeline_upload.add_upload_data(file=\"examples/data/breast_homo_host.csv\",  # file in examples/data\n", "                                table_name=host_data[\"name\"],  # table name\n", "                                namespace=host_data[\"namespace\"],  # namespace\n", "                                head=1, partition=partition)  # head=1: the file has a header row\n", "\n", "\n", "pipeline_upload.upload(drop=1)  # drop=1: replace existing tables with the same name and namespace" ] }, { "attachments": {}, "cell_type": "markdown", "id": "afa1afad", "metadata": {}, "source": [ "The breast dataset is a binary classification dataset with 30 features:" ] }, { "cell_type": "code", "execution_count": 11, "id": "d9580f9e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | id | y | x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 | ... | x20 | x21 | x22 | x23 | x24 | x25 | x26 | x27 | x28 | x29 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 133 | 1 | 0.254879 | -1.046633 | 0.209656 | 0.074214 | -0.441366 | -0.377645 | -0.485934 | 0.347072 | ... | -0.337360 | -0.728193 | -0.442587 | -0.272757 | -0.608018 | -0.577235 | -0.501126 | 0.143371 | -0.466431 | -0.554102 |
| 1 | 273 | 1 | -1.142928 | -0.781198 | -1.166747 | -0.923578 | 0.628230 | -1.021418 | -1.111867 | -0.959523 | ... | -0.493639 | 0.348620 | -0.552483 | -0.526877 | 2.253098 | -0.827620 | -0.780739 | -0.376997 | -0.310239 | 0.176301 |
| 2 | 175 | 1 | -1.451067 | -1.406518 | -1.456564 | -1.092337 | -0.708765 | -1.168557 | -1.305831 | -1.745063 | ... | -0.666881 | -0.779358 | -0.708418 | -0.637545 | 0.710369 | -0.976454 | -1.057501 | -1.913447 | 0.795207 | -0.149751 |
| 3 | 551 | 1 | -0.879933 | 0.420589 | -0.877527 | -0.780484 | -1.037534 | -0.483880 | -0.555498 | -0.768581 | ... | -0.451772 | 0.453852 | -0.431696 | -0.494754 | -1.182041 | 0.281228 | 0.084759 | -0.252420 | 1.038575 | 0.351054 |
| 4 | 199 | 0 | 0.426758 | 0.723479 | 0.316885 | 0.287273 | 1.000835 | 0.962702 | 1.077099 | 1.053586 | ... | -0.707304 | -1.026834 | -0.702973 | -0.460212 | -0.999033 | -0.531406 | -0.394360 | -0.728830 | -0.644416 | -0.688003 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 222 | 105 | 0 | 0.008451 | -0.533675 | -0.025652 | -0.093843 | 2.359748 | 0.990056 | 1.753070 | 1.278939 | ... | -0.051872 | -0.531700 | -0.225763 | -0.124905 | 0.040342 | 0.203542 | 0.757183 | 0.338023 | -0.614146 | 1.249401 |
| 223 | 242 | 1 | -0.763967 | 0.371736 | -0.598731 | -0.716671 | 0.102199 | 1.466524 | 2.261608 | 0.109537 | ... | -0.586035 | 0.771362 | -0.246060 | -0.526877 | -0.125998 | 1.881346 | 1.886843 | 0.217988 | -0.071715 | 1.845903 |
| 224 | 312 | 1 | -0.430564 | -1.510738 | -0.453377 | -0.460192 | -0.568490 | -0.212884 | -0.457149 | -0.464354 | ... | -0.283944 | -1.011412 | -0.257445 | -0.333482 | -0.182334 | 0.123061 | -0.017365 | -0.179426 | -0.391362 | 0.225852 |
| 225 | 473 | 1 | -0.583805 | 2.014830 | -0.660686 | -0.565491 | -1.672278 | -1.285861 | -1.305831 | -1.745063 | ... | 0.145552 | 4.409122 | 0.008881 | -0.114565 | 0.099344 | -0.963264 | -1.057501 | -1.913447 | 1.315845 | -0.249231 |
| 226 | 364 | 1 | -0.318739 | -0.647666 | -0.402145 | -0.381613 | -0.485202 | -0.551311 | -0.651448 | -0.681180 | ... | -0.890652 | -1.096686 | -0.905935 | -0.596622 | -0.882362 | -0.725342 | -0.576392 | -1.023890 | -0.924107 | -0.650934 |

227 rows × 32 columns

|  | id | label | predict_result | predict_score | predict_detail | type |
|---|---|---|---|---|---|---|
| 0 | 0 | 0.0 | 0 | 0.05885133519768715 | {'0': 0.9411486648023129, '1': 0.0588513351976... | train |
| 1 | 3 | 0.0 | 1 | 0.5971069931983948 | {'0': 0.4028930068016052, '1': 0.5971069931983... | train |
| 2 | 5 | 1.0 | 1 | 0.7218729257583618 | {'0': 0.2781270742416382, '1': 0.7218729257583... | train |
| 3 | 7 | 1.0 | 1 | 0.6514894962310791 | {'0': 0.3485105037689209, '1': 0.6514894962310... | train |
| 4 | 14 | 0.0 | 0 | 0.2351398915052414 | {'0': 0.7648601084947586, '1': 0.2351398915052... | train |
| ... | ... | ... | ... | ... | ... | ... |
| 222 | 551 | 1.0 | 0 | 0.38658156991004944 | {'0': 0.6134184300899506, '1': 0.3865815699100... | train |
| 223 | 559 | 0.0 | 1 | 0.5517507195472717 | {'0': 0.44824928045272827, '1': 0.551750719547... | train |
| 224 | 562 | 0.0 | 0 | 0.39873841404914856 | {'0': 0.6012615859508514, '1': 0.3987384140491... | train |
| 225 | 567 | 1.0 | 1 | 0.6306618452072144 | {'0': 0.36933815479278564, '1': 0.630661845207... | train |
| 226 | 568 | 0.0 | 1 | 0.5063760876655579 | {'0': 0.49362391233444214, '1': 0.506376087665... | train |

227 rows × 6 columns