{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "0b529eab", "metadata": {}, "source": [ "# Hetero-NN Quick Start: A Binary Classification Task" ] }, { "attachments": {}, "cell_type": "markdown", "id": "6f2a3025", "metadata": {}, "source": [ "In this tutorial, you will learn how to use Hetero-NN. Hetero-NN has been upgraded to work similarly to Homo-NN, so both models and datasets can be highly customized using the PyTorch backend. Customization is covered in a later chapter dedicated to Hetero-NN.\n", "\n", "Hetero-NN has also improved some interfaces, such as the Interactive-layer interface, which makes its usage logic clearer.\n", "\n", "This chapter walks through a basic binary classification task using Hetero-NN. The workflow is consistent with other FATE algorithms: you use the reader and transformer interfaces provided by FATE to read table data, and then feed the data into the algorithm component. The component then trains with the top/bottom models, optimizer, and loss function you define. Usage is basically the same as in older versions of FATE.\n", "\n", "To understand the principles of the Hetero-NN algorithm, refer to doc/federated_component/hetero_nn.md." ] }, { "attachments": {}, "cell_type": "markdown", "id": "485228ca", "metadata": {}, "source": [ "## Uploading Tabular Data\n", "\n", "First, we upload data to FATE, which can be done directly with the pipeline. Here we upload two files: breast_hetero_guest.csv for the guest and breast_hetero_host.csv for the host. Note that this tutorial uses the standalone version; if you are using a cluster version, you need to upload the corresponding data on each machine. 
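
To make the guest/host file layout concrete, here is a minimal sketch in plain Python of what a vertical split looks like. The toy rows and the helper name `vertical_split` are hypothetical and only for illustration; they are not part of FATE or of the breast dataset. Both parties keep the shared `id` column, the guest also keeps the label `y`, and the feature columns are divided between the two parties.

```python
# Toy illustration only: these rows and this helper are hypothetical,
# not part of FATE and not taken from the breast dataset.
full_table = [
    {'id': 1, 'y': 1, 'x0': 0.5, 'x1': -1.2, 'x2': 0.3},
    {'id': 2, 'y': 0, 'x0': 1.1, 'x1': 0.4, 'x2': -0.7},
]


def vertical_split(rows, guest_cols, host_cols):
    # Each party keeps the shared id column plus its own columns.
    guest = [{k: r[k] for k in ['id'] + guest_cols} for r in rows]
    host = [{k: r[k] for k in ['id'] + host_cols} for r in rows]
    return guest, host


# The guest holds the label y and one feature; the host holds the other features.
guest_rows, host_rows = vertical_split(full_table, ['y', 'x0'], ['x1', 'x2'])

# Both parties describe the same samples, so their id sets overlap completely.
shared_ids = {r['id'] for r in guest_rows} & {r['id'] for r in host_rows}
print(sorted(shared_ids))
```

The two uploaded CSV files follow the same pattern: each has an `id` column, and only the guest file carries the label.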
" ] }, { "cell_type": "code", "execution_count": 1, "id": "7cb9620a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " UPLOADING:||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||100.00%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2022-12-19 11:28:35.529\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212191128349034250\n", "\u001b[0m\n", "\u001b[32m2022-12-19 11:28:35.542\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n", "\u001b[32m2022-12-19 11:28:36.557\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:01\u001b[0m\n", "\u001b[32m2022-12-19 11:28:37.573\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:02\u001b[0m\n", "\u001b[0mm2022-12-19 11:28:38.594\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 11:28:38.595\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:03\u001b[0m\n", "\u001b[32m2022-12-19 11:28:39.616\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:04\u001b[0m\n", "\u001b[32m2022-12-19 11:28:40.643\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:05\u001b[0m\n", "\u001b[32m2022-12-19 11:28:41.670\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:06\u001b[0m\n", "\u001b[32m2022-12-19 11:28:42.697\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:07\u001b[0m\n", "\u001b[32m2022-12-19 11:28:43.719\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:08\u001b[0m\n", "\u001b[32m2022-12-19 11:28:44.732\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! 
Job id is 202212191128349034250\u001b[0m\n", "\u001b[32m2022-12-19 11:28:44.745\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:00:09\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " UPLOADING:||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||100.00%\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2022-12-19 11:28:45.106\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212191128447632710\n", "\u001b[0m\n", "\u001b[32m2022-12-19 11:28:45.118\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n", "\u001b[32m2022-12-19 11:28:46.133\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:01\u001b[0m\n", "\u001b[0mm2022-12-19 11:28:47.159\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 11:28:47.161\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:02\u001b[0m\n", "\u001b[32m2022-12-19 11:28:48.185\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:03\u001b[0m\n", "\u001b[32m2022-12-19 11:28:49.205\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:04\u001b[0m\n", "\u001b[32m2022-12-19 11:28:50.227\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:05\u001b[0m\n", "\u001b[32m2022-12-19 11:28:51.250\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:06\u001b[0m\n", "\u001b[32m2022-12-19 11:28:52.271\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component upload_0, time elapse: 0:00:07\u001b[0m\n", "\u001b[32m2022-12-19 11:28:53.291\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! 
Job id is 202212191128447632710\u001b[0m\n", "\u001b[32m2022-12-19 11:28:53.294\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:00:08\u001b[0m\n" ] } ], "source": [ "from pipeline.backend.pipeline import PipeLine # pipeline class\n", "\n", "# There are two parties: the guest, whose data carries the labels,\n", "# and the host, whose data has no labels.\n", "# The dataset is vertically split.\n", "\n", "dense_data_guest = {\"name\": \"breast_hetero_guest\", \"namespace\": \"experiment\"}\n", "dense_data_host = {\"name\": \"breast_hetero_host\", \"namespace\": \"experiment\"}\n", "\n", "guest = 9999\n", "host = 10000\n", "\n", "pipeline_upload = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host)\n", "\n", "partition = 4\n", "\n", "# upload the data\n", "pipeline_upload.add_upload_data(file=\"./examples/data/breast_hetero_guest.csv\",\n", " table_name=dense_data_guest[\"name\"], # table name\n", " namespace=dense_data_guest[\"namespace\"], # namespace\n", " head=1, partition=partition) # data info\n", "\n", "pipeline_upload.add_upload_data(file=\"./examples/data/breast_hetero_host.csv\",\n", " table_name=dense_data_host[\"name\"],\n", " namespace=dense_data_host[\"namespace\"],\n", " head=1, partition=partition) # data info\n", "\n", "pipeline_upload.upload(drop=1)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ff9cbea4", "metadata": {}, "source": [ "The breast dataset is a binary classification dataset with 30 features, and it is vertically split:\n", "the guest holds 10 features and the label, while the host holds 20 features." ] }, { "cell_type": "code", "execution_count": 2, "id": "f0ef786c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " id y x0 x1 x2 x3 x4 x5 \\\n", "0 133 1 0.254879 -1.046633 0.209656 0.074214 -0.441366 -0.377645 \n", "1 273 1 -1.142928 -0.781198 -1.166747 -0.923578 0.628230 -1.021418 \n", "2 175 1 -1.451067 -1.406518 -1.456564 -1.092337 -0.708765 -1.168557 \n", "3 551 1 -0.879933 0.420589 -0.877527 -0.780484 -1.037534 -0.483880 \n", "4 199 0 0.426758 0.723479 0.316885 0.287273 1.000835 0.962702 \n", ".. ... .. ... ... ... ... ... ... \n", "564 529 1 -0.583805 -1.613330 -0.605880 -0.581312 0.864944 -0.579301 \n", "565 40 0 -0.070240 0.744648 -0.141817 -0.162929 -1.006849 -0.317847 \n", "566 115 1 -0.538247 0.076989 -0.587413 -0.523125 0.772888 -0.091382 \n", "567 2 0 1.511870 -0.023974 1.347475 1.456285 0.527407 1.082932 \n", "568 39 0 -0.153073 0.055819 0.001155 -0.246430 1.255083 1.070209 \n", "\n", " x6 x7 x8 x9 \n", "0 -0.485934 0.347072 -0.287570 -0.733474 \n", "1 -1.111867 -0.959523 -0.096672 -0.121683 \n", "2 -1.305831 -1.745063 -0.499499 -0.302893 \n", "3 -0.555498 -0.768581 0.433960 -0.200928 \n", "4 1.077099 1.053586 2.996525 0.961696 \n", ".. ... ... ... ... \n", "564 -0.527672 -0.619360 -0.193738 -0.189844 \n", "565 -0.305547 -0.051865 0.150849 -0.691912 \n", "566 -0.584763 -0.641591 -0.748637 0.081139 \n", "567 0.854974 1.955000 1.152255 0.201391 \n", "568 1.107324 1.693103 -0.151676 1.283108 \n", "\n", "[569 rows x 12 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "df = pd.read_csv('../../../../examples/data/breast_hetero_guest.csv')\n", "df" ] }, { "cell_type": "code", "execution_count": 3, "id": "9a1b4f5c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " id x0 x1 x2 x3 x4 x5 \\\n", "0 133 0.449512 -1.247226 0.413178 0.303781 -0.123848 -0.184227 \n", "1 273 -1.245485 -0.842317 -1.255026 -1.038066 -0.426301 -1.088781 \n", "2 175 -1.549664 -1.126219 -1.546652 -1.216392 -0.354424 -1.167051 \n", "3 551 -0.851273 0.733108 -0.843535 -0.786363 -0.049836 -0.424532 \n", "4 199 0.091654 0.216499 0.103839 -0.034667 0.167930 0.308132 \n", ".. ... ... ... ... ... ... ... \n", "564 529 -0.584300 -1.361252 -0.582390 -0.596377 0.970677 -0.270077 \n", "565 40 -0.195201 0.532980 -0.238451 -0.261342 -1.048999 -0.834452 \n", "566 115 -0.624062 0.521345 -0.635937 -0.615148 0.093918 -0.489914 \n", "567 2 1.579888 0.456187 1.566503 1.558884 0.942210 1.052926 \n", "568 39 -0.183840 0.356123 -0.147009 -0.272150 0.372887 0.400995 \n", "\n", " x6 x7 x8 ... x10 x11 x12 \\\n", "0 -0.219076 0.268537 0.015996 ... -0.337360 -0.728193 -0.442587 \n", "1 -0.976392 -0.898898 0.983496 ... -0.493639 0.348620 -0.552483 \n", "2 -1.114873 -1.261820 -0.327193 ... -0.666881 -0.779358 -0.708418 \n", "3 -0.509221 -0.679649 0.797298 ... -0.451772 0.453852 -0.431696 \n", "4 0.366614 0.280661 0.505223 ... -0.707304 -1.026834 -0.702973 \n", ".. ... ... ... ... ... ... ... \n", "564 -0.640169 -0.540104 -0.564504 ... -0.555357 -1.293361 -0.570305 \n", "565 -0.724413 -0.737944 -0.100834 ... -0.601554 -0.708235 -0.640599 \n", "566 -0.697043 -0.743876 -0.451325 ... -0.336999 -0.533695 -0.428726 \n", "567 1.363478 2.037231 0.939685 ... 1.228676 -0.780083 0.850928 \n", "568 0.219721 0.141115 -0.334494 ... 
-0.693589 -1.134788 -0.653965 \n", "\n", " x13 x14 x15 x16 x17 x18 x19 \n", "0 -0.272757 -0.608018 -0.577235 -0.501126 0.143371 -0.466431 -0.554102 \n", "1 -0.526877 2.253098 -0.827620 -0.780739 -0.376997 -0.310239 0.176301 \n", "2 -0.637545 0.710369 -0.976454 -1.057501 -1.913447 0.795207 -0.149751 \n", "3 -0.494754 -1.182041 0.281228 0.084759 -0.252420 1.038575 0.351054 \n", "4 -0.460212 -0.999033 -0.531406 -0.394360 -0.728830 -0.644416 -0.688003 \n", ".. ... ... ... ... ... ... ... \n", "564 -0.479573 0.095344 -0.779555 -0.461337 -0.618041 -0.111671 -0.590414 \n", "565 -0.435790 -1.253710 -0.808059 -0.596618 -0.797283 -0.816347 -0.948996 \n", "566 -0.342062 0.254017 -0.022811 -0.449069 -0.662649 -0.939848 0.023110 \n", "567 1.181336 -0.297005 0.814974 0.213076 1.424827 0.237036 0.293559 \n", "568 -0.480013 -0.558016 -0.172595 -0.046543 0.133639 -0.819980 -0.229940 \n", "\n", "[569 rows x 21 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "df = pd.read_csv('../../../../examples/data/breast_hetero_host.csv')\n", "df" ] }, { "attachments": {}, "cell_type": "markdown", "id": "858d4578", "metadata": {}, "source": [ "## Write the Pipeline script and execute it\n", "\n", "After the upload is complete, we can start writing the pipeline script to submit a FATE task." 
] }, { "cell_type": "code", "execution_count": 14, "id": "3acf33d1", "metadata": {}, "outputs": [], "source": [ "import torch as t\n", "from torch import nn\n", "from pipeline.backend.pipeline import PipeLine # pipeline class\n", "from pipeline import fate_torch_hook\n", "from pipeline.component import HeteroNN, Reader, DataTransform, Intersection # Hetero NN Component, Data IO component, PSI component\n", "from pipeline.interface import Data, Model # data, model for defining the work flow" ] }, { "attachments": {}, "cell_type": "markdown", "id": "48edc92d", "metadata": {}, "source": [ "### fate_torch_hook\n", "\n", "Be sure to execute the following fate_torch_hook function. It modifies some torch classes so that the torch layers, Sequential models, optimizers, and loss functions you define in the script can be parsed and submitted by the pipeline. " ] }, { "cell_type": "code", "execution_count": 15, "id": "955db238", "metadata": {}, "outputs": [], "source": [ "from pipeline import fate_torch_hook\n", "t = fate_torch_hook(t)" ] }, { "cell_type": "code", "execution_count": 16, "id": "38706fcd", "metadata": {}, "outputs": [], "source": [ "\n", "guest = 9999\n", "host = 10000\n", "pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host)\n", "\n", "guest_train_data = {\"name\": \"breast_hetero_guest\", \"namespace\": \"experiment\"}\n", "host_train_data = {\"name\": \"breast_hetero_host\", \"namespace\": \"experiment\"}\n", "\n", "# read uploaded dataset\n", "reader_0 = Reader(name=\"reader_0\")\n", "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=guest_train_data)\n", "reader_0.get_party_instance(role='host', party_id=host).component_param(table=host_train_data)\n", "# The transform component converts the uploaded data to the FATE standard format\n", "data_transform_0 = 
DataTransform(name=\"data_transform_0\")\n", "data_transform_0.get_party_instance(role='guest', party_id=guest).component_param(with_label=True)\n", "data_transform_0.get_party_instance(role='host', party_id=host).component_param(with_label=False)\n", "# intersection\n", "intersection_0 = Intersection(name=\"intersection_0\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "f9d890fb", "metadata": {}, "source": [ "### The Hetero NN Component\n", "\n", "Here we initialize the Hetero-NN component. We use get_party_instance to obtain the guest component and host component respectively. As the model architectures of the two parties differ, we must specify the model parameters for each party using the respective components." ] }, { "cell_type": "code", "execution_count": 17, "id": "0b0c0ae4", "metadata": {}, "outputs": [], "source": [ "hetero_nn_0 = HeteroNN(name=\"hetero_nn_0\", epochs=2,\n", " interactive_layer_lr=0.01, batch_size=-1, validation_freqs=1, task_type='classification', seed=114514)\n", "guest_nn_0 = hetero_nn_0.get_party_instance(role='guest', party_id=guest)\n", "host_nn_0 = hetero_nn_0.get_party_instance(role='host', party_id=host)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "982bdd56", "metadata": {}, "source": [ "### Defining Guest & Host Model" ] }, { "cell_type": "code", "execution_count": 18, "id": "3004a3a6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "InteractiveLayer(\n", " (activation): ReLU()\n", " (guest_model): Linear(in_features=2, out_features=2, bias=True)\n", " (host_model): ModuleList(\n", " (0): Linear(in_features=2, out_features=2, bias=True)\n", " )\n", " (act_seq): Sequential(\n", " (0): ReLU()\n", " )\n", ")\n" ] } ], "source": [ "# Guest Bottom, Top Model\n", "guest_bottom = t.nn.Sequential(\n", " nn.Linear(10, 2),\n", " nn.ReLU()\n", ")\n", "guest_top = t.nn.Sequential(\n", " nn.Linear(2, 1),\n", " nn.Sigmoid()\n", ")\n", "\n", "# Host Bottom Model\n", "host_bottom = 
t.nn.Sequential(\n", " nn.Linear(20, 2),\n", " nn.ReLU()\n", ")\n", "\n", "# After fate_torch_hook, the nn module provides InteractiveLayer; print it to view its structure\n", "interactive_layer = t.nn.InteractiveLayer(out_dim=2, guest_dim=2, host_dim=2, host_num=1)\n", "print(interactive_layer)\n", "\n", "guest_nn_0.add_top_model(guest_top)\n", "guest_nn_0.add_bottom_model(guest_bottom)\n", "host_nn_0.add_bottom_model(host_bottom)\n", "\n", "optimizer = t.optim.Adam(lr=0.01) # Note: after fate_torch_hook, the optimizer can be initialized without model parameters\n", "loss = t.nn.BCELoss()\n", "\n", "hetero_nn_0.set_interactive_layer(interactive_layer)\n", "hetero_nn_0.compile(optimizer=optimizer, loss=loss)" ] }, { "cell_type": "code", "execution_count": 19, "id": "5e683308", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline.add_component(reader_0)\n", "pipeline.add_component(data_transform_0, data=Data(data=reader_0.output.data))\n", "pipeline.add_component(intersection_0, data=Data(data=data_transform_0.output.data))\n", "pipeline.add_component(hetero_nn_0, data=Data(train_data=intersection_0.output.data))\n", "pipeline.compile()" ] }, { "cell_type": "code", "execution_count": 20, "id": "1c6f7199", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2022-12-19 11:59:51.084\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212191159500270390\n", "\u001b[0m\n", "\u001b[32m2022-12-19 11:59:51.107\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n", 
"\u001b[32m2022-12-19 11:59:52.127\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:01\u001b[0m\n", "\u001b[0mm2022-12-19 11:59:53.155\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 11:59:53.157\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:02\u001b[0m\n", "\u001b[32m2022-12-19 11:59:54.195\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:03\u001b[0m\n", "\u001b[32m2022-12-19 11:59:55.255\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:04\u001b[0m\n", "\u001b[32m2022-12-19 11:59:56.289\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:05\u001b[0m\n", "\u001b[32m2022-12-19 11:59:57.348\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:06\u001b[0m\n", "\u001b[32m2022-12-19 11:59:58.424\u001b[0m | \u001b[1mINFO 
\u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:07\u001b[0m\n", "\u001b[32m2022-12-19 11:59:59.470\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:08\u001b[0m\n", "\u001b[32m2022-12-19 12:00:00.521\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:09\u001b[0m\n", "\u001b[32m2022-12-19 12:00:01.598\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:10\u001b[0m\n", "\u001b[0mm2022-12-19 12:00:02.740\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 12:00:02.743\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:11\u001b[0m\n", "\u001b[32m2022-12-19 12:00:03.781\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:12\u001b[0m\n", "\u001b[32m2022-12-19 12:00:04.819\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:13\u001b[0m\n", "\u001b[32m2022-12-19 12:00:05.852\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:14\u001b[0m\n", "\u001b[32m2022-12-19 12:00:06.886\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:15\u001b[0m\n", "\u001b[32m2022-12-19 12:00:07.922\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:16\u001b[0m\n", "\u001b[32m2022-12-19 12:00:08.995\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:17\u001b[0m\n", "\u001b[32m2022-12-19 12:00:10.031\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:18\u001b[0m\n", "\u001b[32m2022-12-19 12:00:11.065\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 
0:00:19\u001b[0m\n", "\u001b[32m2022-12-19 12:00:12.187\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:21\u001b[0m\n", "\u001b[32m2022-12-19 12:00:13.225\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component data_transform_0, time elapse: 0:00:22\u001b[0m\n", "\u001b[0mm2022-12-19 12:00:14.299\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 12:00:14.304\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:23\u001b[0m\n", "\u001b[32m2022-12-19 12:00:15.347\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:24\u001b[0m\n", "\u001b[32m2022-12-19 12:00:16.381\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:25\u001b[0m\n", "\u001b[32m2022-12-19 12:00:17.419\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:26\u001b[0m\n", 
"\u001b[32m2022-12-19 12:00:18.443\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:27\u001b[0m\n", "\u001b[32m2022-12-19 12:00:19.476\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:28\u001b[0m\n", "\u001b[32m2022-12-19 12:00:20.504\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:29\u001b[0m\n", "\u001b[32m2022-12-19 12:00:21.570\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:30\u001b[0m\n", "\u001b[32m2022-12-19 12:00:22.600\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:31\u001b[0m\n", "\u001b[32m2022-12-19 12:00:23.634\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:32\u001b[0m\n", "\u001b[32m2022-12-19 12:00:24.680\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - 
\u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:33\u001b[0m\n", "\u001b[32m2022-12-19 12:00:25.720\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:34\u001b[0m\n", "\u001b[32m2022-12-19 12:00:26.811\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:35\u001b[0m\n", "\u001b[32m2022-12-19 12:00:27.877\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component intersection_0, time elapse: 0:00:36\u001b[0m\n", "\u001b[0mm2022-12-19 12:00:28.975\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 12:00:28.980\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:37\u001b[0m\n", "\u001b[32m2022-12-19 12:00:30.034\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:38\u001b[0m\n", "\u001b[32m2022-12-19 12:00:31.062\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - 
\u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:39\u001b[0m\n", "\u001b[32m2022-12-19 12:00:32.103\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:41\u001b[0m\n", "\u001b[32m2022-12-19 12:00:33.140\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:42\u001b[0m\n", "\u001b[32m2022-12-19 12:00:34.257\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:43\u001b[0m\n", "\u001b[32m2022-12-19 12:00:35.331\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:44\u001b[0m\n", "\u001b[32m2022-12-19 12:00:36.362\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:45\u001b[0m\n", "\u001b[32m2022-12-19 12:00:37.437\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:46\u001b[0m\n", "\u001b[32m2022-12-19 12:00:38.460\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:47\u001b[0m\n", "\u001b[32m2022-12-19 12:00:39.487\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:48\u001b[0m\n", "\u001b[32m2022-12-19 12:00:40.526\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:49\u001b[0m\n", "\u001b[32m2022-12-19 12:00:41.556\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:50\u001b[0m\n", "\u001b[32m2022-12-19 12:00:42.602\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:51\u001b[0m\n", "\u001b[32m2022-12-19 12:00:43.628\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:52\u001b[0m\n", "\u001b[32m2022-12-19 12:00:44.652\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:53\u001b[0m\n", "\u001b[32m2022-12-19 
12:00:45.687\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:54\u001b[0m\n", "\u001b[32m2022-12-19 12:00:46.715\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:55\u001b[0m\n", "\u001b[32m2022-12-19 12:00:47.742\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:56\u001b[0m\n", "\u001b[32m2022-12-19 12:00:48.769\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:57\u001b[0m\n", "\u001b[32m2022-12-19 12:00:49.792\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:58\u001b[0m\n", "\u001b[32m2022-12-19 12:00:50.852\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:00:59\u001b[0m\n", "\u001b[32m2022-12-19 12:00:51.880\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time 
elapse: 0:01:00\u001b[0m\n", "\u001b[32m2022-12-19 12:00:52.914\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:01:01\u001b[0m\n", "\u001b[32m2022-12-19 12:00:53.936\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:01:02\u001b[0m\n", "\u001b[32m2022-12-19 12:00:54.998\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:01:03\u001b[0m\n", "\u001b[32m2022-12-19 12:00:56.084\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:01:05\u001b[0m\n", "\u001b[32m2022-12-19 12:00:57.117\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component hetero_nn_0, time elapse: 0:01:06\u001b[0m\n", "\u001b[32m2022-12-19 12:00:59.208\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! 
Job id is 202212191159500270390\u001b[0m\n", "\u001b[32m2022-12-19 12:00:59.210\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:01:08\u001b[0m\n" ] } ], "source": [ "pipeline.fit()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "8902124a", "metadata": {}, "source": [ "## Get Component Output" ] }, { "cell_type": "code", "execution_count": 21, "id": "4f7325c1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n",
" <tr style=\"text-align: right;\"><th></th><th>id</th><th>label</th><th>predict_result</th><th>predict_score</th><th>predict_detail</th><th>type</th></tr>\n",
" </thead>\n", " <tbody>\n",
" <tr><th>0</th><td>0</td><td>0.0</td><td>0</td><td>0.13979218900203705</td><td>{'0': 0.860207810997963, '1': 0.13979218900203...</td><td>train</td></tr>\n",
" <tr><th>1</th><td>1</td><td>0.0</td><td>0</td><td>0.19935783743858337</td><td>{'0': 0.8006421625614166, '1': 0.1993578374385...</td><td>train</td></tr>\n",
" <tr><th>2</th><td>2</td><td>0.0</td><td>0</td><td>0.2489972561597824</td><td>{'0': 0.7510027438402176, '1': 0.2489972561597...</td><td>train</td></tr>\n",
" <tr><th>3</th><td>3</td><td>0.0</td><td>0</td><td>0.25491416454315186</td><td>{'0': 0.7450858354568481, '1': 0.2549141645431...</td><td>train</td></tr>\n",
" <tr><th>4</th><td>4</td><td>1.0</td><td>0</td><td>0.2584167718887329</td><td>{'0': 0.7415832281112671, '1': 0.2584167718887...</td><td>train</td></tr>\n",
" <tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
" <tr><th>564</th><td>564</td><td>0.0</td><td>0</td><td>0.19034752249717712</td><td>{'0': 0.8096524775028229, '1': 0.1903475224971...</td><td>train</td></tr>\n",
" <tr><th>565</th><td>565</td><td>1.0</td><td>0</td><td>0.261306494474411</td><td>{'0': 0.738693505525589, '1': 0.261306494474411}</td><td>train</td></tr>\n",
" <tr><th>566</th><td>566</td><td>1.0</td><td>0</td><td>0.26077690720558167</td><td>{'0': 0.7392230927944183, '1': 0.2607769072055...</td><td>train</td></tr>\n",
" <tr><th>567</th><td>567</td><td>1.0</td><td>0</td><td>0.2625167667865753</td><td>{'0': 0.7374832332134247, '1': 0.2625167667865...</td><td>train</td></tr>\n",
" <tr><th>568</th><td>568</td><td>0.0</td><td>0</td><td>0.24826040863990784</td><td>{'0': 0.7517395913600922, '1': 0.2482604086399...</td><td>train</td></tr>\n",
" </tbody>\n", "</table>\n", "<p>569 rows × 6 columns</p>\n", "</div>\n", "
" ], "text/plain": [ " id label predict_result predict_score \\\n", "0 0 0.0 0 0.13979218900203705 \n", "1 1 0.0 0 0.19935783743858337 \n", "2 2 0.0 0 0.2489972561597824 \n", "3 3 0.0 0 0.25491416454315186 \n", "4 4 1.0 0 0.2584167718887329 \n", ".. ... ... ... ... \n", "564 564 0.0 0 0.19034752249717712 \n", "565 565 1.0 0 0.261306494474411 \n", "566 566 1.0 0 0.26077690720558167 \n", "567 567 1.0 0 0.2625167667865753 \n", "568 568 0.0 0 0.24826040863990784 \n", "\n", " predict_detail type \n", "0 {'0': 0.860207810997963, '1': 0.13979218900203... train \n", "1 {'0': 0.8006421625614166, '1': 0.1993578374385... train \n", "2 {'0': 0.7510027438402176, '1': 0.2489972561597... train \n", "3 {'0': 0.7450858354568481, '1': 0.2549141645431... train \n", "4 {'0': 0.7415832281112671, '1': 0.2584167718887... train \n", ".. ... ... \n", "564 {'0': 0.8096524775028229, '1': 0.1903475224971... train \n", "565 {'0': 0.738693505525589, '1': 0.261306494474411} train \n", "566 {'0': 0.7392230927944183, '1': 0.2607769072055... train \n", "567 {'0': 0.7374832332134247, '1': 0.2625167667865... train \n", "568 {'0': 0.7517395913600922, '1': 0.2482604086399... 
train \n", "\n", "[569 rows x 6 columns]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# get predict scores\n", "pipeline.get_component('hetero_nn_0').get_output_data()" ] }, { "cell_type": "code", "execution_count": 23, "id": "0b1c79c1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'best_iteration': -1,\n", " 'history_loss': [0.9929580092430115, 0.9658427238464355],\n", " 'is_converged': False,\n", " 'validation_metrics': {'train': {'auc': [0.8850615717985308,\n", " 0.9316368056656624],\n", " 'ks': [0.6326568363194334, 0.7479123724961683]}}}" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# get summary\n", "pipeline.get_component('hetero_nn_0').get_summary()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "163272f1", "metadata": {}, "source": [ "So far, we have gained a basic understanding of Hetero-NN and used it to complete a basic modeling task. Hetero-NN also supports more complex models and datasets. For further information, refer to the additional tutorials provided." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('venv': venv)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]" }, "vscode": { "interpreter": { "hash": "d29574a2ab71ec988cdcd4d29c58400bd2037cad632b9528d973466f7fb6f853" } } }, "nbformat": 4, "nbformat_minor": 5 }