{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "7a982d9b", "metadata": {}, "source": [ "# Homo-NN: Customize your Dataset" ] },
{ "cell_type": "markdown", "id": "40b31519", "metadata": {}, "source": [ "FATE primarily supports tabular data as its standard data format. However, non-tabular data, such as images, text, mixed data, or relational data, can also be used in neural networks through the Dataset feature of the NN module, which allows datasets to be customized for more complex data scenarios. This tutorial covers how to use the Dataset feature in the Homo-NN module and how to customize datasets, using the MNIST handwriting recognition task as an example." ] },
{ "cell_type": "markdown", "id": "2122f107", "metadata": {}, "source": [ "## Prepare MNIST Data\n", "\n", "Please download the MNIST dataset from the link below and place it in the examples/data folder of the project:\n", "https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/mnist.zip\n", "\n", "This is a simplified version of the MNIST dataset with ten classes: the images are stored in ten folders, named 0 to 9, according to their labels. The dataset has been subsampled to reduce the number of samples.\n", "\n", "The original MNIST dataset is available at:\n", "http://yann.lecun.com/exdb/mnist/" ] },
{ "cell_type": "code", "execution_count": 5, "id": "87b15585", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 1 2 3 4 5 6 7\t8 9\n" ] } ], "source": [ "! ls ../../../../examples/data/mnist" ] },
{ "cell_type": "markdown", "id": "e255e886", "metadata": {}, "source": [ "## Dataset" ] },
{ "attachments": {}, "cell_type": "markdown", "id": "a2cc655b", "metadata": {}, "source": [ "FATE-1.10 introduces a new base class for datasets called [Dataset](../../../../python/federatedml/nn/dataset/base.py), which is based on PyTorch's Dataset class. This class allows users to create custom datasets according to their specific needs. Its usage is similar to that of PyTorch's Dataset class, with the added requirement of implementing two additional interfaces that FATE-NN uses for data reading and training: load() and get_sample_ids().\n", "\n", "To create a custom dataset in Homo-NN, users need to:\n", "\n", "- Develop a new dataset class that inherits from the Dataset class\n", "- Implement the \\_\\_len\\_\\_() and \\_\\_getitem\\_\\_() methods, consistent with PyTorch's Dataset usage: \\_\\_len\\_\\_() should return the length of the dataset, and \\_\\_getitem\\_\\_() should return the data at the specified index\n", "- Implement the load() and get_sample_ids() methods\n", "\n", "For those unfamiliar with PyTorch's Dataset class, more information can be found in the PyTorch documentation: [PyTorch Dataset Documentation](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)" ] },
{ "cell_type": "markdown", "id": "028a993a", "metadata": {}, "source": [ "### load()\n", "\n", "The first additional interface is load(). It receives a file path and allows users to read data directly from the local file system. When submitting a task, the data path can be specified through the Reader component. Homo-NN uses the user-specified Dataset class and calls its load() interface to read data from the specified path, completing the loading of the dataset for training. For more information, please refer to the source code in /federatedml/nn/dataset/base.py.\n", "\n", "### get_sample_ids()\n", "\n", "The second additional interface is get_sample_ids(). It should return a list of sample IDs, which can be either integers or strings, with the same length as the dataset. In practice, you can skip implementing this interface when using Homo-NN, as the Homo-NN component will automatically generate IDs for the samples." ] },
{ "cell_type": "markdown", "id": "77f97084", "metadata": {}, "source": [ "## Example: Implement a simple image dataset\n", "\n", "To better understand how to customize a Dataset, here we implement a simple image dataset that reads MNIST images, and then use it to complete a federated image classification task in a horizontal federated setting.\n", "\n", "For convenience, we use the save_to_fate Jupyter magic to save the code to federatedml.nn.dataset as mnist_dataset.py. Alternatively, you can manually copy the code file into that directory." ] },
{ "cell_type": "markdown", "id": "f972ba9f", "metadata": {}, "source": [ "### jupyter: save_to_fate()" ] },
{ "cell_type": "code", "execution_count": 1, "id": "0038d15f", "metadata": {}, "outputs": [], "source": [ "from pipeline.component.nn import save_to_fate" ] },
{ "cell_type": "markdown", "id": "94f205eb", "metadata": {}, "source": [ "### The MNIST Dataset\n", "\n", "Here we implement the dataset and save it with save_to_fate()."
] },
{ "cell_type": "code", "execution_count": 8, "id": "7c5ce623", "metadata": {}, "outputs": [], "source": [ "%%save_to_fate dataset mnist_dataset.py\n", "from federatedml.nn.dataset.base import Dataset\n", "from torchvision.datasets import ImageFolder\n", "from torchvision import transforms\n", "\n", "\n", "class MNISTDataset(Dataset):\n", "\n", "    def __init__(self, flatten_feature=False):  # whether to flatten each image into a 1-D feature vector\n", "        super(MNISTDataset, self).__init__()\n", "        self.image_folder = None\n", "        self.ids = None\n", "        self.flatten_feature = flatten_feature\n", "\n", "    def load(self, path):  # read data from path and set the sample ids\n", "        # read the images with ImageFolder; the sub-folder names (0-9) become the class labels\n", "        self.image_folder = ImageFolder(root=path, transform=transforms.Compose([transforms.ToTensor()]))\n", "        # use the file name (without the .jpg suffix) as the sample id\n", "        ids = []\n", "        for image_name in self.image_folder.imgs:\n", "            ids.append(image_name[0].split('/')[-1].replace('.jpg', ''))\n", "        self.ids = ids\n", "        return self\n", "\n", "    def get_sample_ids(self):  # implement the get_sample_ids interface: simply return the ids\n", "        return self.ids\n", "\n", "    def __len__(self):  # return the length of the dataset\n", "        return len(self.image_folder)\n", "\n", "    def __getitem__(self, idx):  # return the sample at index idx\n", "        ret = self.image_folder[idx]\n", "        if self.flatten_feature:\n", "            img = ret[0][0].flatten()  # take the first channel and flatten it into a 784-dim tensor\n", "            return img, ret[1]  # return the tensor and the label\n", "        else:\n", "            return ret" ] },
{ "cell_type": "markdown", "id": "b20dc3bb", "metadata": {}, "source": [ "After implementing the dataset, we can test it locally:" ] },
{ "cell_type": "code", "execution_count": 9, "id": "e8ddcac0", "metadata": {}, "outputs": [], "source": [ "from federatedml.nn.dataset.mnist_dataset import MNISTDataset\n", "\n", "ds = MNISTDataset(flatten_feature=True)\n" ] },
{ "cell_type": "code", "execution_count": 10, "id": "44f5ee89", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", 
"text": [ "1309\n", "(tensor([0.0118, 0.0000, 0.0000, 0.0118, 0.0275, 0.0118, 0.0000, 0.0118, 0.0000,\n", " 0.0431, 0.0000, 0.0000, 0.0118, 0.0000, 0.0000, 0.0118, 0.0314, 0.0000,\n", " 0.0000, 0.0118, 0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039,\n", " 0.0196, 0.0000, 0.0471, 0.0000, 0.0627, 0.0000, 0.0000, 0.0157, 0.0000,\n", " 0.0078, 0.0314, 0.0118, 0.0000, 0.0157, 0.0314, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.0000, 0.0000, 0.0039,\n", " 0.0078, 0.0039, 0.0471, 0.0000, 0.0314, 0.0000, 0.0000, 0.0235, 0.0000,\n", " 0.0431, 0.0000, 0.0000, 0.0235, 0.0275, 0.0078, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118, 0.0000, 0.0000, 0.0078,\n", " 0.0118, 0.0000, 0.0000, 0.0000, 0.0471, 0.0000, 0.0000, 0.0902, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0431, 0.0118, 0.0000, 0.0000, 0.0157, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0039, 0.0000, 0.0000,\n", " 0.0078, 0.0000, 0.0000, 0.0235, 0.0000, 0.0980, 0.1059, 0.5333, 0.5294,\n", " 0.7373, 0.3490, 0.3294, 0.0980, 0.0000, 0.0000, 0.0118, 0.0039, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0157, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.3451, 0.9686, 0.9255, 1.0000,\n", " 0.9765, 0.9804, 0.8902, 0.9412, 0.5333, 0.1451, 0.0039, 0.0000, 0.0078,\n", " 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.0000,\n", " 0.0118, 0.0000, 0.0000, 0.0157, 0.1059, 0.7569, 0.9843, 0.9922, 1.0000,\n", " 1.0000, 1.0000, 1.0000, 0.9412, 0.9961, 1.0000, 0.8353, 0.3490, 0.0000,\n", " 0.0000, 0.0549, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0235, 0.0000, 0.0000, 0.0706, 0.2196, 0.9647, 1.0000, 0.9922,\n", " 0.9529, 0.9843, 1.0000, 0.9608, 1.0000, 1.0000, 0.9961, 1.0000, 0.9059,\n", " 0.4667, 0.0275, 0.0000, 0.0196, 0.0000, 0.0000, 0.0000, 0.0000, 0.0157,\n", " 0.0000, 0.0000, 
0.0471, 0.0510, 0.0000, 0.2549, 0.7451, 0.9647, 1.0000,\n", " 1.0000, 0.9843, 1.0000, 0.4275, 0.3451, 0.7804, 1.0000, 0.9686, 0.9804,\n", " 1.0000, 0.9176, 0.3608, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0392, 0.0039, 0.0000, 0.0000, 0.0706, 0.6392, 0.9725, 1.0000,\n", " 0.9216, 0.8471, 0.5882, 0.5020, 0.1765, 0.0235, 0.0314, 0.0863, 0.8314,\n", " 1.0000, 1.0000, 0.9882, 0.6745, 0.0000, 0.0588, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0039, 0.0157, 0.0196, 0.0000, 0.0000, 0.7333, 1.0000,\n", " 0.9961, 0.3686, 0.2235, 0.0275, 0.0039, 0.0000, 0.0235, 0.0000, 0.0000,\n", " 0.5451, 0.9490, 1.0000, 1.0000, 0.8549, 0.2431, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0196, 0.0078, 0.0000, 0.0000, 0.0431, 0.2196, 0.9882,\n", " 0.9216, 0.9922, 0.0784, 0.0196, 0.0078, 0.0196, 0.0039, 0.0000, 0.0039,\n", " 0.0078, 0.0000, 0.3804, 0.9765, 0.9725, 0.9765, 0.6510, 0.0314, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.0000, 0.2745,\n", " 1.0000, 1.0000, 0.9608, 0.0980, 0.0392, 0.0000, 0.0000, 0.0039, 0.0000,\n", " 0.0157, 0.0392, 0.0000, 0.0392, 1.0000, 0.9647, 0.9804, 0.6078, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.0275, 0.0471, 0.0000,\n", " 0.3412, 0.8863, 1.0000, 0.7216, 0.0000, 0.0118, 0.0000, 0.0392, 0.0196,\n", " 0.0000, 0.0000, 0.0000, 0.0353, 0.0000, 0.7176, 0.9843, 1.0000, 0.8706,\n", " 0.0588, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0196, 0.0039, 0.0000,\n", " 0.0745, 0.9020, 1.0000, 0.9529, 1.0000, 0.1373, 0.0078, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0353, 0.0314, 0.0000, 0.0000, 0.2745, 0.9608, 0.9490,\n", " 1.0000, 0.0549, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0157, 0.0118,\n", " 0.0000, 0.0745, 0.9843, 0.9373, 1.0000, 0.9686, 0.1176, 0.0039, 0.0000,\n", " 0.0157, 0.0157, 0.0549, 0.0000, 0.0000, 0.0078, 0.0000, 0.1843, 1.0000,\n", " 1.0000, 0.9686, 0.0824, 0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0000,\n", " 0.0078, 0.0078, 0.0000, 0.6784, 0.9686, 0.9882, 0.9804, 0.1098, 
0.0392,\n", " 0.0000, 0.0000, 0.0314, 0.0000, 0.0000, 0.0000, 0.0314, 0.0000, 0.2627,\n", " 0.9765, 1.0000, 1.0000, 0.0471, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0235, 0.0118, 0.0000, 0.3451, 1.0000, 0.9843, 1.0000, 0.7373,\n", " 0.0824, 0.0000, 0.0588, 0.0000, 0.0314, 0.0078, 0.0627, 0.0000, 0.1373,\n", " 0.7843, 0.9686, 0.9843, 0.5255, 0.0157, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0118, 0.0118, 0.0039, 0.0000, 0.0431, 0.8275, 0.9686, 0.9765,\n", " 1.0000, 0.7412, 0.2980, 0.0000, 0.0000, 0.0157, 0.0000, 0.0078, 0.0000,\n", " 0.6627, 1.0000, 1.0000, 0.9686, 0.1843, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0235, 0.0000, 0.0000, 0.0078, 0.0000, 0.2314, 0.8039,\n", " 1.0000, 0.9412, 1.0000, 0.7137, 0.1608, 0.2196, 0.1098, 0.1294, 0.1647,\n", " 0.9373, 0.9647, 0.9843, 0.9333, 0.6157, 0.0000, 0.0039, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0078, 0.0039, 0.0000, 0.0000, 0.0078, 0.0392, 0.0000,\n", " 0.4078, 0.9373, 1.0000, 0.9412, 1.0000, 0.9922, 0.9686, 0.9294, 1.0000,\n", " 1.0000, 0.9804, 1.0000, 0.9373, 1.0000, 0.3922, 0.0000, 0.0039, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0118, 0.0000, 0.0000, 0.0275,\n", " 0.0000, 0.0157, 0.4471, 1.0000, 1.0000, 1.0000, 1.0000, 0.9686, 0.9765,\n", " 0.9922, 0.9843, 0.9961, 0.9294, 0.9843, 0.3490, 0.0000, 0.0000, 0.0039,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0353, 0.0000, 0.0000,\n", " 0.0039, 0.0510, 0.0000, 0.0549, 0.6549, 1.0000, 0.9647, 0.9922, 1.0000,\n", " 1.0000, 0.9961, 0.9490, 1.0000, 0.9569, 0.2392, 0.0000, 0.0745, 0.0000,\n", " 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0039, 0.0275, 0.0000,\n", " 0.0000, 0.0157, 0.0000, 0.0549, 0.0000, 0.1059, 0.2392, 0.5608, 1.0000,\n", " 1.0000, 0.9882, 1.0000, 0.5843, 0.0824, 0.0235, 0.0627, 0.0000, 0.0000,\n", " 0.0275, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 
0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000]), 0)\n", "img_1\n" ] } ], "source": [ "# load MNIST data and check \n", "ds.load('../../../../examples/data/mnist/')\n", "print(len(ds))\n", "print(ds[0])\n", "print(ds.get_sample_ids()[0])" ] }, { "cell_type": "markdown", "id": "0311ed01", "metadata": {}, "source": [ "## Test Your Dataset\n", "\n", "Before submitting a task, it is possible to test locally. As we mentioned in [1.1 Homo-NN Quick Start: A Binary Classification Task](Homo-NN-Quick-Start.ipynb), in Homo-NN, FATE uses the fedavg_trainer by default. Custom datasets, models, and trainers can be used for local debugging to test if the program runs correctly. 
**Note that during local testing, all federation processes will be skipped and the model will not perform federated averaging.**" ] },
{ "cell_type": "code", "execution_count": 11, "id": "c53366f3", "metadata": {}, "outputs": [], "source": [ "from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer\n", "trainer = FedAVGTrainer(epochs=3, batch_size=256, shuffle=True, data_loader_worker=8, pin_memory=False)  # set parameters" ] },
{ "cell_type": "code", "execution_count": 12, "id": "711ef7fa", "metadata": {}, "outputs": [], "source": [ "trainer.local_mode()  # !! Be sure to enable local_mode to skip the federation process !!" ] },
{ "cell_type": "code", "execution_count": 15, "id": "d5a64b68", "metadata": {}, "outputs": [], "source": [ "import torch as t\n", "from pipeline import fate_torch_hook\n", "fate_torch_hook(t)\n", "# our simple classification model:\n", "model = t.nn.Sequential(\n", "    t.nn.Linear(784, 32),\n", "    t.nn.ReLU(),\n", "    t.nn.Linear(32, 10),\n", "    t.nn.Softmax(dim=1)\n", ")\n", "\n", "trainer.set_model(model)  # set model" ] },
{ "cell_type": "code", "execution_count": 16, "id": "0d65f9b8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "epoch is 0\n", "100%|██████████| 6/6 [00:00<00:00, 11.28it/s]\n", "epoch loss is 2.5860611556412336\n", "epoch is 1\n", "100%|██████████| 6/6 [00:00<00:00, 12.40it/s]\n", "epoch loss is 2.2709667185411098\n", "epoch is 2\n", "100%|██████████| 6/6 [00:00<00:00, 11.20it/s]\n", "epoch loss is 2.0878872911469277\n" ] } ], "source": [ "optimizer = t.optim.Adam(model.parameters(), lr=0.01)  # optimizer\n", "loss = t.nn.CrossEntropyLoss()  # loss function\n", "trainer.train(train_set=ds, optimizer=optimizer, loss=loss)  # use the dataset we just developed" ] },
{ "cell_type": "markdown", "id": "e08ed729", "metadata": {}, "source": [ "In the Trainer's train() function, your dataset is iterated with a PyTorch DataLoader.\n", "The program runs correctly. 
Now we can submit a federated task." ] },
{ "cell_type": "markdown", "id": "413aefa9", "metadata": {}, "source": [ "## Submit a task with your dataset" ] },
{ "cell_type": "markdown", "id": "d258b9d2", "metadata": {}, "source": [ "### Import Components" ] },
{ "cell_type": "code", "execution_count": 42, "id": "1518af62", "metadata": {}, "outputs": [], "source": [ "import torch as t\n", "from torch import nn\n", "from pipeline import fate_torch_hook\n", "from pipeline.component import HomoNN\n", "from pipeline.backend.pipeline import PipeLine\n", "from pipeline.component import Reader, Evaluation, DataTransform\n", "from pipeline.interface import Data, Model\n", "\n", "t = fate_torch_hook(t)\n" ] },
{ "cell_type": "markdown", "id": "8315687c", "metadata": {}, "source": [ "### Bind data path to name & namespace\n", "\n", "Here, we use the pipeline to bind a path to a name & namespace. The Reader component then passes this path to the load() interface of the dataset.\n", "The trainer receives this dataset in train() and iterates over it with a PyTorch DataLoader. 
**Please note that this tutorial uses the standalone version. If you are using the cluster version, you need to bind the data with the corresponding name & namespace on each machine.**" ] },
{ "cell_type": "code", "execution_count": 34, "id": "d900c35a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'namespace': 'experiment', 'table_name': 'mnist_host'}" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os\n", "# bind data path to name & namespace\n", "fate_project_path = os.path.abspath('../../../../')\n", "host = 10000\n", "guest = 9999\n", "arbiter = 10000\n", "pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,\n", " arbiter=arbiter)\n", "\n", "data_0 = {\"name\": \"mnist_guest\", \"namespace\": \"experiment\"}\n", "data_1 = {\"name\": \"mnist_host\", \"namespace\": \"experiment\"}\n", "\n", "data_path_0 = fate_project_path + '/examples/data/mnist'\n", "data_path_1 = fate_project_path + '/examples/data/mnist'\n", "pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)\n", "pipeline.bind_table(name=data_1['name'], namespace=data_1['namespace'], path=data_path_1)" ] },
{ "cell_type": "code", "execution_count": 35, "id": "d3af79ff", "metadata": {}, "outputs": [], "source": [ "reader_0 = Reader(name=\"reader_0\")\n", "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)\n", "reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_1)" ] },
{ "cell_type": "markdown", "id": "673fe620", "metadata": {}, "source": [ "### DatasetParam\n", "\n", "Use dataset_name to specify the module name of your dataset, followed by its initialization parameters; these parameters will be passed to the \\_\\_init\\_\\_ interface of your dataset. 
**Please note that your Dataset parameters must be JSON-serializable; otherwise, they cannot be parsed by the pipeline.**" ] },
{ "cell_type": "code", "execution_count": 36, "id": "21875f9f", "metadata": {}, "outputs": [], "source": [ "from pipeline.component.nn import DatasetParam\n", "\n", "dataset_param = DatasetParam(dataset_name='mnist_dataset', flatten_feature=True)  # specify the dataset and its init parameters" ] },
{ "cell_type": "code", "execution_count": 37, "id": "de9917a7", "metadata": {}, "outputs": [], "source": [ "from pipeline.component.homo_nn import TrainerParam  # Interface\n", "\n", "# our simple classification model:\n", "model = t.nn.Sequential(\n", "    t.nn.Linear(784, 32),\n", "    t.nn.ReLU(),\n", "    t.nn.Linear(32, 10),\n", "    t.nn.Softmax(dim=1)\n", ")\n", "\n", "nn_component = HomoNN(name='nn_0',\n", "                      model=model,  # model\n", "                      loss=t.nn.CrossEntropyLoss(),  # loss\n", "                      optimizer=t.optim.Adam(model.parameters(), lr=0.01),  # optimizer\n", "                      dataset=dataset_param,  # dataset\n", "                      trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=2, batch_size=1024, validation_freqs=1),\n", "                      torch_seed=100  # random seed\n", "                      )" ] },
{ "cell_type": "code", "execution_count": 38, "id": "62361f34", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline.add_component(reader_0)\n", "pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))\n", "pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=nn_component.output.data))" ] },
{ "cell_type": "code", "execution_count": 39, "id": "1fa46219", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[32m2022-12-19 16:26:20.771\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mJob id is 
202212191626190908350\n", "\u001b[0m\n", "\u001b[32m2022-12-19 16:26:20.805\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:00\u001b[0m\n", "\u001b[32m2022-12-19 16:26:21.840\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:01\u001b[0m\n", "\u001b[32m2022-12-19 16:26:22.865\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:02\u001b[0m\n", "\u001b[32m2022-12-19 16:26:23.899\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:03\u001b[0m\n", "\u001b[32m2022-12-19 16:26:25.292\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:04\u001b[0m\n", "\u001b[32m2022-12-19 16:26:26.322\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:05\u001b[0m\n", "\u001b[32m2022-12-19 16:26:27.347\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 
0:00:06\u001b[0m\n", "\u001b[32m2022-12-19 16:26:28.377\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:07\u001b[0m\n", "\u001b[32m2022-12-19 16:26:29.411\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:08\u001b[0m\n", "\u001b[32m2022-12-19 16:26:30.448\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m98\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KJob is still waiting, time elapse: 0:00:09\u001b[0m\n", "\u001b[0mm2022-12-19 16:26:32.538\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 16:26:32.542\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:11\u001b[0m\n", "\u001b[32m2022-12-19 16:26:33.582\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:12\u001b[0m\n", "\u001b[32m2022-12-19 16:26:34.617\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:13\u001b[0m\n", "\u001b[32m2022-12-19 16:26:35.668\u001b[0m | \u001b[1mINFO 
\u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:14\u001b[0m\n", "\u001b[32m2022-12-19 16:26:36.723\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:15\u001b[0m\n", "\u001b[32m2022-12-19 16:26:37.766\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:16\u001b[0m\n", "\u001b[32m2022-12-19 16:26:38.852\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:18\u001b[0m\n", "\u001b[32m2022-12-19 16:26:39.969\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:19\u001b[0m\n", "\u001b[32m2022-12-19 16:26:41.040\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:20\u001b[0m\n", "\u001b[32m2022-12-19 16:26:42.137\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:21\u001b[0m\n", "\u001b[32m2022-12-19 
16:26:43.181\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component reader_0, time elapse: 0:00:22\u001b[0m\n", "\u001b[0mm2022-12-19 16:26:44.253\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 16:26:44.258\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:23\u001b[0m\n", "\u001b[32m2022-12-19 16:26:45.346\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:24\u001b[0m\n", "\u001b[32m2022-12-19 16:26:46.423\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:25\u001b[0m\n", "\u001b[32m2022-12-19 16:26:47.463\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:26\u001b[0m\n", "\u001b[32m2022-12-19 16:26:48.529\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:27\u001b[0m\n", "\u001b[32m2022-12-19 16:26:49.662\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:28\u001b[0m\n", "\u001b[32m2022-12-19 16:26:50.739\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:29\u001b[0m\n", "\u001b[32m2022-12-19 16:26:51.778\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:31\u001b[0m\n", "\u001b[32m2022-12-19 16:26:52.814\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:32\u001b[0m\n", "\u001b[32m2022-12-19 16:26:53.849\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:33\u001b[0m\n", "\u001b[32m2022-12-19 16:26:54.892\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:34\u001b[0m\n", "\u001b[32m2022-12-19 16:26:55.937\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:35\u001b[0m\n", "\u001b[32m2022-12-19 16:26:56.978\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:36\u001b[0m\n", "\u001b[32m2022-12-19 16:26:58.019\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:37\u001b[0m\n", "\u001b[32m2022-12-19 16:26:59.070\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:38\u001b[0m\n", "\u001b[32m2022-12-19 16:27:00.113\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:39\u001b[0m\n", "\u001b[32m2022-12-19 16:27:01.236\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:40\u001b[0m\n", "\u001b[32m2022-12-19 16:27:02.274\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:41\u001b[0m\n", "\u001b[32m2022-12-19 16:27:03.397\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:42\u001b[0m\n", "\u001b[32m2022-12-19 16:27:04.466\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:43\u001b[0m\n", "\u001b[32m2022-12-19 16:27:05.501\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:44\u001b[0m\n", "\u001b[32m2022-12-19 16:27:06.544\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:45\u001b[0m\n", "\u001b[32m2022-12-19 16:27:07.632\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:46\u001b[0m\n", "\u001b[32m2022-12-19 16:27:08.675\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:47\u001b[0m\n", "\u001b[32m2022-12-19 16:27:09.860\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component nn_0, time elapse: 0:00:49\u001b[0m\n", "\u001b[0mm2022-12-19 16:27:12.111\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m125\u001b[0m - \u001b[1m\n", "\u001b[32m2022-12-19 16:27:12.118\u001b[0m | \u001b[1mINFO \u001b[0m | 
\u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:51\u001b[0m\n", "\u001b[32m2022-12-19 16:27:13.185\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:52\u001b[0m\n", "\u001b[32m2022-12-19 16:27:14.237\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:53\u001b[0m\n", "\u001b[32m2022-12-19 16:27:15.346\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:54\u001b[0m\n", "\u001b[32m2022-12-19 16:27:16.514\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:55\u001b[0m\n", "\u001b[32m2022-12-19 16:27:17.566\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:56\u001b[0m\n", "\u001b[32m2022-12-19 16:27:18.625\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:57\u001b[0m\n", "\u001b[32m2022-12-19 16:27:19.676\u001b[0m | \u001b[1mINFO 
\u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:58\u001b[0m\n", "\u001b[32m2022-12-19 16:27:20.729\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:00:59\u001b[0m\n", "\u001b[32m2022-12-19 16:27:21.832\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:01:01\u001b[0m\n", "\u001b[32m2022-12-19 16:27:22.912\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:01:02\u001b[0m\n", "\u001b[32m2022-12-19 16:27:23.967\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:01:03\u001b[0m\n", "\u001b[32m2022-12-19 16:27:25.023\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:01:04\u001b[0m\n", "\u001b[32m2022-12-19 16:27:26.094\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m127\u001b[0m - \u001b[1m\u001b[80D\u001b[1A\u001b[KRunning component eval_0, time elapse: 0:01:05\u001b[0m\n", "\u001b[32m2022-12-19 16:27:29.383\u001b[0m | 
\u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m89\u001b[0m - \u001b[1mJob is success!!! Job id is 202212191626190908350\u001b[0m\n", "\u001b[32m2022-12-19 16:27:29.384\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mpipeline.utils.invoker.job_submitter\u001b[0m:\u001b[36mmonitor_job_status\u001b[0m:\u001b[36m90\u001b[0m - \u001b[1mTotal time: 0:01:08\u001b[0m\n" ] } ], "source": [ "pipeline.compile()\n", "pipeline.fit()" ] }, { "cell_type": "code", "execution_count": 40, "id": "0edf9014", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
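<div>\n", "<style scoped>\n", "    .dataframe tbody tr th:only-of-type {\n", "        vertical-align: middle;\n", "    }\n", "\n", "    .dataframe tbody tr th {\n", "        vertical-align: top;\n", "    }\n", "\n", "    .dataframe thead th {\n", "        text-align: right;\n", "    }\n", "</style>\n",
"<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>id</th>\n", "      <th>label</th>\n", "      <th>predict_result</th>\n", "      <th>predict_score</th>\n", "      <th>predict_detail</th>\n", "      <th>type</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>img_1</td>\n", "      <td>0</td>\n", "      <td>0</td>\n", "      <td>0.9070178270339966</td>\n", "      <td>{'0': 0.9070178270339966, '1': 0.0023874549660...</td>\n", "      <td>train</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>img_3</td>\n", "      <td>4</td>\n", "      <td>6</td>\n", "      <td>0.19601570069789886</td>\n", "      <td>{'0': 0.19484134018421173, '1': 0.044997252523...</td>\n", "      <td>train</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>2</th>\n", "      <td>img_4</td>\n", "      <td>0</td>\n", "      <td>0</td>\n", "      <td>0.9618675112724304</td>\n", "      <td>{'0': 0.9618675112724304, '1': 0.0010393995326...</td>\n", "      <td>train</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>img_5</td>\n", "      <td>0</td>\n", "      <td>0</td>\n", "      <td>0.33044907450675964</td>\n", "      <td>{'0': 0.33044907450675964, '1': 0.033256266266...</td>\n", "      <td>train</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>img_6</td>\n", "      <td>7</td>\n", "      <td>7</td>\n", "      <td>0.3145765960216522</td>\n", "      <td>{'0': 0.05851678550243378, '1': 0.075524508953...</td>\n", "      <td>train</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>...</th>\n", "      <td>...</td>\n", "      <td>...</td>\n", "      <td>...</td>\n", "      <td>...</td>\n", "      <td>...</td>\n", "      <td>...</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1304</th>\n", "      <td>img_32537</td>\n", "      <td>1</td>\n", "      <td>8</td>\n", "      <td>0.20599651336669922</td>\n", "      <td>{'0': 0.080563984811306, '1': 0.12380836158990...</td>\n", "      <td>train</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1305</th>\n", "      <td>img_32558</td>\n", "      <td>1</td>\n", "      <td>8</td>\n", "      <td>0.20311488211154938</td>\n", "      <td>{'0': 0.07224143296480179, '1': 0.130610913038...</td>\n", "      <td>train</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1306</th>\n", "      <td>img_32563</td>\n", "      <td>1</td>\n", "      <td>8</td>\n", "      <td>0.2071550488471985</td>\n", "      <td>{'0': 0.06843454390764236, '1': 0.129064396023...</td>\n", "      <td>train</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1307</th>\n", "      <td>img_32565</td>\n", "      <td>1</td>\n", "      <td>5</td>\n", "      <td>0.29367145895957947</td>\n", "      <td>{'0': 0.05658009275794029, '1': 0.086584843695...</td>\n", "      <td>train</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1308</th>\n", "      <td>img_32573</td>\n", "      <td>1</td>\n", "      <td>8</td>\n", "      <td>0.199515700340271</td>\n", "      <td>{'0': 0.08787216246128082, '1': 0.127247273921...</td>\n", "      <td>train</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "<p>1309 rows × 6 columns</p>\n", "</div>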
" ], "text/plain": [ " id label predict_result predict_score \\\n", "0 img_1 0 0 0.9070178270339966 \n", "1 img_3 4 6 0.19601570069789886 \n", "2 img_4 0 0 0.9618675112724304 \n", "3 img_5 0 0 0.33044907450675964 \n", "4 img_6 7 7 0.3145765960216522 \n", "... ... ... ... ... \n", "1304 img_32537 1 8 0.20599651336669922 \n", "1305 img_32558 1 8 0.20311488211154938 \n", "1306 img_32563 1 8 0.2071550488471985 \n", "1307 img_32565 1 5 0.29367145895957947 \n", "1308 img_32573 1 8 0.199515700340271 \n", "\n", " predict_detail type \n", "0 {'0': 0.9070178270339966, '1': 0.0023874549660... train \n", "1 {'0': 0.19484134018421173, '1': 0.044997252523... train \n", "2 {'0': 0.9618675112724304, '1': 0.0010393995326... train \n", "3 {'0': 0.33044907450675964, '1': 0.033256266266... train \n", "4 {'0': 0.05851678550243378, '1': 0.075524508953... train \n", "... ... ... \n", "1304 {'0': 0.080563984811306, '1': 0.12380836158990... train \n", "1305 {'0': 0.07224143296480179, '1': 0.130610913038... train \n", "1306 {'0': 0.06843454390764236, '1': 0.129064396023... train \n", "1307 {'0': 0.05658009275794029, '1': 0.086584843695... train \n", "1308 {'0': 0.08787216246128082, '1': 0.127247273921... 
train \n", "\n", "[1309 rows x 6 columns]" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline.get_component('nn_0').get_output_data()" ] }, { "cell_type": "code", "execution_count": 41, "id": "8592212b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'best_epoch': 1,\n", " 'loss_history': [3.58235876026547, 3.4448592824914055],\n", " 'metrics_summary': {'train': {'accuracy': [0.25668449197860965,\n", " 0.4950343773873186],\n", " 'precision': [0.3708616690797323, 0.5928620913124757],\n", " 'recall': [0.21817632850241547, 0.4855654369784805]}},\n", " 'need_stop': False}" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline.get_component('nn_0').get_summary()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]" }, "vscode": { "interpreter": { "hash": "d29574a2ab71ec988cdcd4d29c58400bd2037cad632b9528d973466f7fb6f853" } } }, "nbformat": 4, "nbformat_minor": 5 }