{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Renset Classification on CIFAR 10\n", "\n", "In this example, we show you how to use torchvision model to make a federated classification task" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset: CIFAR 10\n", "\n", "You can download the CIFAR-10 dataset through this link: \n", "- https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/cifar-10.zip\n", " \n", "The origin CIFAR-10 is from: \n", "- https://www.cs.toronto.edu/~kriz/cifar.html\n", " \n", "For the convinence of demonstrate, our clients will use same dataset" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Local Test\n", "\n", "Firstly we locally test our model and dataset. If it works, we can submit a federated task." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from pipeline.component.nn import save_to_fate" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "%%save_to_fate model resnet.py\n", "\n", "# model\n", "import torch as t\n", "from torch import nn\n", "from torchvision.models import resnet18, ResNet18_Weights\n", "\n", "class Resnet(nn.Module):\n", "\n", " def __init__(self, ):\n", " super(Resnet, self).__init__()\n", " self.resnet = resnet18()\n", " self.classifier = t.nn.Linear(1000, 10)\n", " self.softmax = nn.Softmax(dim=1)\n", "\n", " def forward(self, x):\n", " if self.training:\n", " return self.classifier(self.resnet(x))\n", " else:\n", " return self.softmax(self.classifier(self.resnet(x)))\n", " " ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Resnet(\n", " (resnet): ResNet(\n", " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (layer1): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer2): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " 
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer3): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer4): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", 
" )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", " (fc): Linear(in_features=512, out_features=1000, bias=True)\n", " )\n", " (classifier): Linear(in_features=1000, out_features=10, bias=True)\n", " (softmax): Softmax(dim=1)\n", ")\n" ] } ], "source": [ "model = Resnet()\n", "print(model)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# read dataset\n", "from federatedml.nn.dataset.image import ImageDataset\n", "\n", "ds = ImageDataset()\n", "ds.load('../../../../examples/data/cifar10/train/')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([3, 32, 32])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ds[0][0].shape" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# local test\n", "import torch as t\n", "from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer\n", "\n", "trainer = FedAVGTrainer(epochs=1, batch_size=1024, data_loader_worker=4)\n", "trainer.set_model(model)\n", "\n", "optimizer = t.optim.Adam(model.parameters(), lr=0.001)\n", "loss = t.nn.CrossEntropyLoss()\n", "\n", "trainer.local_mode() # set local mode\n", "trainer.train(ds, None, optimizer, loss)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Submit a federated task" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"{'namespace': 'experiment', 'table_name': 'cifar10'}" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch as t\n", "from torch import nn\n", "from pipeline import fate_torch_hook\n", "from pipeline.component import HomoNN\n", "from pipeline.backend.pipeline import PipeLine\n", "from pipeline.component import Reader, Evaluation, DataTransform\n", "from pipeline.interface import Data, Model\n", "\n", "fate_torch_hook(t)\n", "\n", "import os\n", "fate_project_path = os.path.abspath('../../../../')\n", "guest = 10000\n", "host = 9999\n", "pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,\n", " arbiter=host)\n", "data_0 = {\"name\": \"cifar10\", \"namespace\": \"experiment\"}\n", "data_path = fate_project_path + '/examples/data/cifar10/train'\n", "# bind the local path to a FATE table so readers can reference it by name/namespace\n", "pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "reader_0 = Reader(name=\"reader_0\")\n", "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)\n", "reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_0)\n", "\n", "reader_1 = Reader(name=\"reader_1\")\n", "reader_1.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)\n", "reader_1.get_party_instance(role='host', party_id=host).component_param(table=data_0)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "from pipeline.component.homo_nn import DatasetParam, TrainerParam\n", "\n", "model = t.nn.Sequential(\n", " t.nn.CustModel(module_name='resnet', class_name='Resnet')\n", ")\n", "\n", "nn_component = HomoNN(name='nn_0',\n", " model=model, \n", " loss=t.nn.CrossEntropyLoss(),\n", " optimizer = 
t.optim.Adam(lr=0.001, weight_decay=0.001),\n", " dataset=DatasetParam(dataset_name='image'), # use the custom 'image' dataset\n", " trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=10, batch_size=1024, data_loader_worker=8),\n", " torch_seed=100\n", " )" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline.add_component(reader_0)\n", "pipeline.add_component(reader_1)\n", "pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data, validate_data=reader_1.output.data))\n", "pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=nn_component.output.data))" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline.compile()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pipeline.fit() # submit pipeline here" ] } ], "metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "d29574a2ab71ec988cdcd4d29c58400bd2037cad632b9528d973466f7fb6f853" } } }, "nbformat": 4, "nbformat_minor": 2 }