{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "7a982d9b", "metadata": {}, "source": [ "# Customize loss function\n", "\n", "When PyTorch's built-in loss functions do not meet your needs, you can train your model with a custom loss.\n", "\n", "## A little problem with the MNIST example\n", "\n", "You might notice that in the MNIST example in the last tutorial, [Customize your Dataset](Homo-NN-Customize-your-Dataset.ipynb), the classifier output scores are the result of the Softmax function, and we use the torch built-in CrossEntropyLoss to compute the loss. However, the documentation ([CrossEntropyLoss Doc](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss)) states that the input is expected to contain the unnormalized logits for each class. CrossEntropyLoss applies its own softmax normalization internally, so in that example Softmax is computed twice.\n", "To tackle this problem, we can use a customized CrossEntropyLoss. " ] }, { "attachments": {}, "cell_type": "markdown", "id": "40b31519", "metadata": {}, "source": [ "## Develop a Custom loss\n", "\n", "A customized loss is a class that subclasses torch.nn.Module and implements the forward function. In the FATE trainer, the loss function is called with two arguments, the predicted scores and the label (loss_fn(pred, label)), so when you use FATE's trainer, your loss function needs to take two parameters as input (predicted scores and label). However, if you are using your own trainer and have defined your own training process, you are not restricted in how you use the loss function." ] }, { "attachments": {}, "cell_type": "markdown", "id": "054985e4", "metadata": {}, "source": [ "### A New CrossEntropy Loss\n", "\n", "Here we implement a new CrossEntropyLoss that skips the softmax computation and expects the predictions to already be probabilities. We can use the Jupyter interface save_to_fate to save the code to federatedml.nn.loss as ce.py. Alternatively, you can manually copy the code file to that directory." 
] }, { "cell_type": "code", "execution_count": 20, "id": "808626e6", "metadata": {}, "outputs": [], "source": [ "import torch as t\n", "from federatedml.util import consts\n", "from torch.nn.functional import one_hot\n", "\n", "\n", "def cross_entropy(p2, p1, reduction='mean'):\n", " p2 = p2 + consts.FLOAT_ZERO # to avoid nan\n", " assert p2.shape == p1.shape\n", " if reduction == 'sum':\n", " return -t.sum(p1 * t.log(p2))\n", " elif reduction == 'mean':\n", " return -t.mean(t.sum(p1 * t.log(p2), dim=1))\n", " elif reduction == 'none':\n", " return -t.sum(p1 * t.log(p2), dim=1)\n", " else:\n", " raise ValueError('unknown reduction')\n", "\n", "\n", "class CrossEntropyLoss(t.nn.Module):\n", "\n", " \"\"\"\n", " A CrossEntropy Loss that will not compute Softmax\n", " \"\"\"\n", "\n", " def __init__(self, reduction='mean'):\n", " super(CrossEntropyLoss, self).__init__()\n", " self.reduction = reduction\n", "\n", " def forward(self, pred, label):\n", "\n", " one_hot_label = one_hot(label.flatten())\n", " loss_ = cross_entropy(pred, one_hot_label, self.reduction)\n", "\n", " return loss_" ] }, { "attachments": {}, "cell_type": "markdown", "id": "2122f107", "metadata": {}, "source": [ "## Train with New Loss" ] }, { "attachments": {}, "cell_type": "markdown", "id": "d258b9d2", "metadata": {}, "source": [ "### Import Components" ] }, { "cell_type": "code", "execution_count": 21, "id": "1518af62", "metadata": {}, "outputs": [], "source": [ "import torch as t\n", "from torch import nn\n", "from pipeline import fate_torch_hook\n", "from pipeline.component import HomoNN\n", "from pipeline.backend.pipeline import PipeLine\n", "from pipeline.component import Reader, Evaluation, DataTransform\n", "from pipeline.interface import Data, Model\n", "\n", "t = fate_torch_hook(t)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "8315687c", "metadata": {}, "source": [ "### Bind data path to name & namespace" ] }, { "cell_type": "code", "execution_count": 22, "id": "d900c35a", 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'namespace': 'experiment', 'table_name': 'mnist_host'}" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os\n", "# bind data path to name & namespace\n", "fate_project_path = os.path.abspath('../../../../')\n", "arbiter = 10000\n", "host = 10000\n", "guest = 9999\n", "pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,\n", " arbiter=arbiter)\n", "\n", "data_0 = {\"name\": \"mnist_guest\", \"namespace\": \"experiment\"}\n", "data_1 = {\"name\": \"mnist_host\", \"namespace\": \"experiment\"}\n", "\n", "data_path_0 = fate_project_path + '/examples/data/mnist'\n", "data_path_1 = fate_project_path + '/examples/data/mnist'\n", "pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)\n", "pipeline.bind_table(name=data_1['name'], namespace=data_1['namespace'], path=data_path_1)" ] }, { "cell_type": "code", "execution_count": 23, "id": "d3af79ff", "metadata": {}, "outputs": [], "source": [ "reader_0 = Reader(name=\"reader_0\")\n", "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)\n", "reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_1)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "86d4085a", "metadata": {}, "source": [ "## Use CustLoss\n", "\n", "After fate_torch_hook, we can use t.nn.CustLoss to specify your own loss. We will specify the module name and the class name in the parameter, and behind is the initialization parameter for your loss class. 
**The initialization parameter must be JSON-serializable, otherwise this pipeline can not be submitted.**" ] }, { "cell_type": "code", "execution_count": 24, "id": "de9917a7", "metadata": {}, "outputs": [], "source": [ "from pipeline.component.homo_nn import TrainerParam, DatasetParam # Interface\n", "\n", "# your loss class\n", "loss = t.nn.CustLoss(loss_module_name='cross_entropy', class_name='CrossEntropyLoss', reduction='mean')\n", "\n", "# our simple classification model:\n", "model = t.nn.Sequential(\n", " t.nn.Linear(784, 32),\n", " t.nn.ReLU(),\n", " t.nn.Linear(32, 10),\n", " t.nn.Softmax(dim=1)\n", ")\n", "\n", "nn_component = HomoNN(name='nn_0',\n", " model=model, # model\n", " loss=loss, # loss\n", " optimizer=t.optim.Adam(model.parameters(), lr=0.01), # optimizer\n", " dataset=DatasetParam(dataset_name='mnist_dataset', flatten_feature=True), # dataset\n", " trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=2, batch_size=1024, validation_freqs=1),\n", " torch_seed=100 # random seed\n", " )" ] }, { "cell_type": "code", "execution_count": 25, "id": "62361f34", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline.add_component(reader_0)\n", "pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))\n", "pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=nn_component.output.data))" ] }, { "cell_type": "code", "execution_count": 26, "id": "1fa46219", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2022-12-19 18:39:12.858\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 202212191839119838210\n", "\u001b[0m\n", "\u001b[32m2022-12-19 18:39:12.890\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n", "\u001b[0mm2022-12-19 18:39:13.940\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 18:39:13.943\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:01\u001b[0m\n", "\u001b[32m2022-12-19 18:39:14.977\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:02\u001b[0m\n", "\u001b[32m2022-12-19 18:39:16.036\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:03\u001b[0m\n", "\u001b[32m2022-12-19 18:39:17.088\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:04\u001b[0m\n", "\u001b[32m2022-12-19 18:39:18.133\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:05\u001b[0m\n", "\u001b[32m2022-12-19 18:39:19.184\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:06\u001b[0m\n", "\u001b[32m2022-12-19 18:39:20.246\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:07\u001b[0m\n", "\u001b[32m2022-12-19 18:39:21.278\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:08\u001b[0m\n", "\u001b[32m2022-12-19 18:39:22.319\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:09\u001b[0m\n", "\u001b[32m2022-12-19 18:39:23.343\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:10\u001b[0m\n", "\u001b[32m2022-12-19 18:39:24.383\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:11\u001b[0m\n", "\u001b[0mm2022-12-19 18:39:26.565\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 18:39:26.568\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:13\u001b[0m\n", "\u001b[32m2022-12-19 18:39:27.611\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:14\u001b[0m\n", "\u001b[32m2022-12-19 18:39:28.656\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:15\u001b[0m\n", "\u001b[32m2022-12-19 18:39:29.713\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:16\u001b[0m\n", "\u001b[32m2022-12-19 18:39:30.774\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:17\u001b[0m\n", "\u001b[32m2022-12-19 18:39:31.812\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:18\u001b[0m\n", "\u001b[32m2022-12-19 18:39:32.857\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:19\u001b[0m\n", "\u001b[32m2022-12-19 18:39:33.981\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:21\u001b[0m\n", "\u001b[32m2022-12-19 18:39:35.004\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:22\u001b[0m\n", "\u001b[32m2022-12-19 18:39:36.092\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:23\u001b[0m\n", "\u001b[32m2022-12-19 18:39:37.129\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:24\u001b[0m\n", "\u001b[32m2022-12-19 18:39:38.166\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:25\u001b[0m\n", "\u001b[32m2022-12-19 18:39:39.244\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:26\u001b[0m\n", "\u001b[32m2022-12-19 18:39:40.286\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:27\u001b[0m\n", "\u001b[32m2022-12-19 18:39:41.429\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:28\u001b[0m\n", "\u001b[32m2022-12-19 18:39:42.479\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:29\u001b[0m\n", "\u001b[32m2022-12-19 18:39:43.621\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:30\u001b[0m\n", "\u001b[32m2022-12-19 18:39:44.665\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:31\u001b[0m\n", "\u001b[32m2022-12-19 18:39:45.717\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:32\u001b[0m\n", "\u001b[32m2022-12-19 18:39:46.758\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:33\u001b[0m\n", "\u001b[32m2022-12-19 18:39:47.802\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:34\u001b[0m\n", "\u001b[32m2022-12-19 18:39:48.847\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:35\u001b[0m\n", "\u001b[32m2022-12-19 18:39:49.895\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:37\u001b[0m\n", "\u001b[32m2022-12-19 18:39:50.946\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:38\u001b[0m\n", "\u001b[0mm2022-12-19 18:39:53.243\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 18:39:53.246\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:40\u001b[0m\n", "\u001b[32m2022-12-19 18:39:54.538\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:41\u001b[0m\n", "\u001b[32m2022-12-19 18:39:55.640\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:42\u001b[0m\n", "\u001b[32m2022-12-19 18:39:56.688\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:43\u001b[0m\n", "\u001b[32m2022-12-19 18:39:57.779\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:44\u001b[0m\n", "\u001b[32m2022-12-19 18:39:58.820\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:45\u001b[0m\n", "\u001b[32m2022-12-19 18:40:00.137\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:47\u001b[0m\n", "\u001b[32m2022-12-19 18:40:01.182\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:48\u001b[0m\n", "\u001b[32m2022-12-19 18:40:02.214\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:49\u001b[0m\n", "\u001b[32m2022-12-19 18:40:03.277\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:50\u001b[0m\n", "\u001b[32m2022-12-19 18:40:04.307\u001b[0m | \u001b[1mINFO 
\u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:51\u001b[0m\n", "\u001b[32m2022-12-19 18:40:05.342\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:52\u001b[0m\n", "\u001b[32m2022-12-19 18:40:06.416\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:53\u001b[0m\n", "\u001b[32m2022-12-19 18:40:07.456\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:54\u001b[0m\n", "\u001b[32m2022-12-19 18:40:10.543\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! Job id is 202212191839119838210\u001b[0m\n", "\u001b[32m2022-12-19 18:40:10.545\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:00:57\u001b[0m\n" ] } ], "source": [ "pipeline.compile()\n", "pipeline.fit()" ] }, { "cell_type": "code", "execution_count": 40, "id": "0edf9014", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
"  <thead>\n",
"    <tr><th></th><th>id</th><th>label</th><th>predict_result</th><th>predict_score</th><th>predict_detail</th><th>type</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>img_1</td><td>0</td><td>0</td><td>0.9070178270339966</td><td>{'0': 0.9070178270339966, '1': 0.0023874549660...</td><td>train</td></tr>\n",
"    <tr><th>1</th><td>img_3</td><td>4</td><td>6</td><td>0.19601570069789886</td><td>{'0': 0.19484134018421173, '1': 0.044997252523...</td><td>train</td></tr>\n",
"    <tr><th>2</th><td>img_4</td><td>0</td><td>0</td><td>0.9618675112724304</td><td>{'0': 0.9618675112724304, '1': 0.0010393995326...</td><td>train</td></tr>\n",
"    <tr><th>3</th><td>img_5</td><td>0</td><td>0</td><td>0.33044907450675964</td><td>{'0': 0.33044907450675964, '1': 0.033256266266...</td><td>train</td></tr>\n",
"    <tr><th>4</th><td>img_6</td><td>7</td><td>7</td><td>0.3145765960216522</td><td>{'0': 0.05851678550243378, '1': 0.075524508953...</td><td>train</td></tr>\n",
"    <tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
"    <tr><th>1304</th><td>img_32537</td><td>1</td><td>8</td><td>0.20599651336669922</td><td>{'0': 0.080563984811306, '1': 0.12380836158990...</td><td>train</td></tr>\n",
"    <tr><th>1305</th><td>img_32558</td><td>1</td><td>8</td><td>0.20311488211154938</td><td>{'0': 0.07224143296480179, '1': 0.130610913038...</td><td>train</td></tr>\n",
"    <tr><th>1306</th><td>img_32563</td><td>1</td><td>8</td><td>0.2071550488471985</td><td>{'0': 0.06843454390764236, '1': 0.129064396023...</td><td>train</td></tr>\n",
"    <tr><th>1307</th><td>img_32565</td><td>1</td><td>5</td><td>0.29367145895957947</td><td>{'0': 0.05658009275794029, '1': 0.086584843695...</td><td>train</td></tr>\n",
"    <tr><th>1308</th><td>img_32573</td><td>1</td><td>8</td><td>0.199515700340271</td><td>{'0': 0.08787216246128082, '1': 0.127247273921...</td><td>train</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>1309 rows × 6 columns</p>\n",
"</div>
" ], "text/plain": [ " id label predict_result predict_score \\\n", "0 img_1 0 0 0.9070178270339966 \n", "1 img_3 4 6 0.19601570069789886 \n", "2 img_4 0 0 0.9618675112724304 \n", "3 img_5 0 0 0.33044907450675964 \n", "4 img_6 7 7 0.3145765960216522 \n", "... ... ... ... ... \n", "1304 img_32537 1 8 0.20599651336669922 \n", "1305 img_32558 1 8 0.20311488211154938 \n", "1306 img_32563 1 8 0.2071550488471985 \n", "1307 img_32565 1 5 0.29367145895957947 \n", "1308 img_32573 1 8 0.199515700340271 \n", "\n", " predict_detail type \n", "0 {'0': 0.9070178270339966, '1': 0.0023874549660... train \n", "1 {'0': 0.19484134018421173, '1': 0.044997252523... train \n", "2 {'0': 0.9618675112724304, '1': 0.0010393995326... train \n", "3 {'0': 0.33044907450675964, '1': 0.033256266266... train \n", "4 {'0': 0.05851678550243378, '1': 0.075524508953... train \n", "... ... ... \n", "1304 {'0': 0.080563984811306, '1': 0.12380836158990... train \n", "1305 {'0': 0.07224143296480179, '1': 0.130610913038... train \n", "1306 {'0': 0.06843454390764236, '1': 0.129064396023... train \n", "1307 {'0': 0.05658009275794029, '1': 0.086584843695... train \n", "1308 {'0': 0.08787216246128082, '1': 0.127247273921... 
train \n", "\n", "[1309 rows x 6 columns]" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline.get_component('nn_0').get_output_data()" ] }, { "cell_type": "code", "execution_count": 41, "id": "8592212b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'best_epoch': 1,\n", " 'loss_history': [3.58235876026547, 3.4448592824914055],\n", " 'metrics_summary': {'train': {'accuracy': [0.25668449197860965,\n", " 0.4950343773873186],\n", " 'precision': [0.3708616690797323, 0.5928620913124757],\n", " 'recall': [0.21817632850241547, 0.4855654369784805]}},\n", " 'need_stop': False}" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline.get_component('nn_0').get_summary()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('venv': venv)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]" }, "vscode": { "interpreter": { "hash": "d29574a2ab71ec988cdcd4d29c58400bd2037cad632b9528d973466f7fb6f853" } } }, "nbformat": 4, "nbformat_minor": 5 }