{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "0d19012c", "metadata": {}, "source": [ "# Hetero NN: A federated task with guest using image data and host using text data\n", "\n", "In this tutorial, we show how to build a federated task under Hetero-NN in which the participating parties hold differently structured data: the guest party has image data and labels, the host party has text, and together they complete a binary classification task. The tutorial dataset is built from Flickr 8k; labels 0 and 1 indicate whether the image was taken in the wilderness or in the city. You can download the processed dataset from the link in the next section and put it under examples/data. The link to the complete dataset is listed there as well. (Please note that the data in this example differs from the original dataset: it is a small, annotated portion of the complete dataset, prepared for demonstration purposes.)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "08010311", "metadata": {}, "source": [ "## Get the example dataset:\n", "\n", "Please download the dataset from:\n", "- https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/flicker_toy_data.zip\n", "and put it under the /examples/data folder.\n", "\n", "This dataset is derived from the Flickr 8k dataset:\n", "- https://www.kaggle.com/datasets/adityajn105/flickr8k" ] }, { "cell_type": "code", "execution_count": 2, "id": "15b81e88", "metadata": {}, "outputs": [], "source": [ "from pipeline.component.nn import save_to_fate" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7bdde3a2", "metadata": {}, "source": [ "### Guest Bottom Model" ] }, { "cell_type": "code", "execution_count": 3, "id": "98e9a771", "metadata": {}, "outputs": [], "source": [ "%%save_to_fate model guest_bottom_image.py\n", "from torch import nn\n", "import torch as t\n", "from torch.nn import functional as F\n", "\n", "class ImgBottomNet(nn.Module):\n", " def __init__(self):\n", " super(ImgBottomNet, 
self).__init__()\n", " self.seq = t.nn.Sequential(\n", " nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),\n", " nn.MaxPool2d(kernel_size=3),\n", " nn.Conv2d(in_channels=6, out_channels=6, kernel_size=3),\n", " nn.AvgPool2d(kernel_size=5)\n", " )\n", " \n", " self.fc = t.nn.Sequential(\n", " nn.Linear(1176, 32),\n", " nn.ReLU(),\n", " nn.Linear(32, 8)\n", " )\n", "\n", " def forward(self, x):\n", " x = self.seq(x)\n", " x = x.flatten(start_dim=1)\n", " x = self.fc(x)\n", " return x\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "b55fce67", "metadata": {}, "source": [ "### Guest Top Model" ] }, { "cell_type": "code", "execution_count": 4, "id": "f2a81854", "metadata": {}, "outputs": [], "source": [ "%%save_to_fate model guest_top_image.py\n", "from torch import nn\n", "import torch as t\n", "from torch.nn import functional as F\n", "\n", "class ImgTopNet(nn.Module):\n", " def __init__(self):\n", " super(ImgTopNet, self).__init__()\n", " \n", " self.fc = t.nn.Sequential(\n", " nn.Linear(4, 1),\n", " nn.Sigmoid()\n", " )\n", "\n", " def forward(self, x):\n", " x = self.fc(x)\n", " return x.flatten()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "abea2f8e", "metadata": {}, "source": [ "### Host Bottom Model" ] }, { "cell_type": "code", "execution_count": 5, "id": "84d92fa9", "metadata": {}, "outputs": [], "source": [ "%%save_to_fate model host_bottom_lstm.py\n", "from torch import nn\n", "import torch as t\n", "from torch.nn import functional as F\n", "\n", "class LSTMBottom(nn.Module):\n", " \n", " def __init__(self, vocab_size):\n", " super(LSTMBottom, self).__init__()\n", " self.word_embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=16, padding_idx=0)\n", " self.lstm = t.nn.Sequential(\n", " nn.LSTM(input_size=16, hidden_size=16, num_layers=2, batch_first=True)\n", " )\n", " self.act = nn.ReLU()\n", " self.linear = nn.Linear(16, 8)\n", "\n", " def forward(self, x):\n", " embeddings = self.word_embed(x)\n", " lstm_fw, _ = 
self.lstm(embeddings)\n", " \n", " return self.act(self.linear(lstm_fw.sum(dim=1))) " ] }, { "attachments": {}, "cell_type": "markdown", "id": "b67515a5", "metadata": {}, "source": [ "### Locally test dataset and bottom model" ] }, { "cell_type": "code", "execution_count": 6, "id": "e6bca16b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-26 20:45:42.535744: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", "2022-12-26 20:45:42.535777: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" ] } ], "source": [ "from federatedml.nn.dataset.image import ImageDataset\n", "from federatedml.nn.dataset.nlp_tokenizer import TokenizerDataset" ] }, { "cell_type": "code", "execution_count": 7, "id": "f3495837", "metadata": {}, "outputs": [], "source": [ "# Flickr images (guest side, with labels)\n", "img_ds = ImageDataset(center_crop=True, center_crop_shape=(224, 224), return_label=True) # return_label=True: each sample is (image, label)\n", "img_ds.load('../../../../examples/data/flicker_toy_data/flicker/images/')\n", "# text (host side, no labels)\n", "txt_ds = TokenizerDataset(return_label=False) \n", "txt_ds.load('../../../../examples/data/flicker_toy_data/text.csv')" ] }, { "cell_type": "code", "execution_count": 8, "id": "7542d6d5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "215\n", "(tensor([[[0.5059, 0.5176, 0.5137, ..., 0.4941, 0.5020, 0.5059],\n", " [0.4980, 0.5020, 0.4980, ..., 0.4824, 0.5020, 0.5059],\n", " [0.5059, 0.4863, 0.4902, ..., 0.4980, 0.4980, 0.5137],\n", " ...,\n", " [0.7843, 0.7922, 0.7529, ..., 0.1412, 0.2078, 0.2196],\n", " [0.9922, 0.9922, 0.9647, ..., 0.1176, 0.0941, 0.1333],\n", " [0.9961, 0.9922, 1.0000, ..., 0.1647, 0.1294, 0.1373]],\n", "\n", " [[0.5765, 0.5882, 0.5843, ..., 0.5490, 0.5569, 0.5608],\n", " 
[0.5686, 0.5804, 0.5765, ..., 0.5490, 0.5529, 0.5529],\n", " [0.5608, 0.5569, 0.5647, ..., 0.5569, 0.5490, 0.5529],\n", " ...,\n", " [0.7961, 0.8039, 0.7490, ..., 0.1373, 0.1882, 0.2000],\n", " [0.9961, 0.9961, 0.9608, ..., 0.1137, 0.1137, 0.1529],\n", " [0.9922, 0.9922, 1.0000, ..., 0.1608, 0.1059, 0.1216]],\n", "\n", " [[0.6235, 0.6353, 0.6314, ..., 0.5922, 0.6000, 0.6118],\n", " [0.6078, 0.6235, 0.6196, ..., 0.5804, 0.5882, 0.6000],\n", " [0.6039, 0.6118, 0.6196, ..., 0.5843, 0.5843, 0.6000],\n", " ...,\n", " [0.5882, 0.5961, 0.5686, ..., 0.1216, 0.1765, 0.1882],\n", " [0.7294, 0.7373, 0.7373, ..., 0.0980, 0.0980, 0.1294],\n", " [0.8745, 0.8431, 0.8627, ..., 0.1451, 0.1059, 0.1176]]]), tensor(0))\n", "[0, 1]\n", "['1022454428_b6b660a67b', '103195344_5d2dc613a3', '1055753357_4fa3d8d693', '1124448967_2221af8dc5', '1131804997_177c3c0640', '1138784872_69ade3f2ab', '1142847777_2a0c1c2551', '1143373711_2e90b7b799', '1143882946_1898d2eeb9', '1144288288_e5c9558b6a']\n" ] } ], "source": [ "print(len(img_ds))\n", "print(img_ds[0])\n", "print(img_ds.get_classes())\n", "print(img_ds.get_sample_ids()[0: 10])" ] }, { "cell_type": "code", "execution_count": 9, "id": "6fae43a6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "215\n", "tensor([ 101, 1037, 2158, 1998, 2450, 2729, 2005, 2019, 10527, 2247,\n", " 1996, 2217, 1997, 1037, 2303, 1997, 2300, 1012, 102, 0,\n", " 0, 0, 0, 0, 0, 0])\n", "30522\n" ] } ], "source": [ "print(len(txt_ds))\n", "print(txt_ds[0]) # word idx\n", "print(txt_ds.get_vocab_size()) # vocab size" ] }, { "cell_type": "code", "execution_count": 10, "id": "c857db34", "metadata": {}, "outputs": [], "source": [ "img_bottom = ImgBottomNet()\n", "lstm_bottom = LSTMBottom(vocab_size=txt_ds.get_vocab_size())" ] }, { "cell_type": "code", "execution_count": 11, "id": "c4fd9b82", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.0000, 1.8284, 0.0000, 0.0000, 2.3009, 0.0626, 0.0678, 0.0000],\n", " [0.0369, 
1.8046, 0.0000, 0.0000, 2.4555, 0.0000, 0.0000, 0.0000]],\n", " grad_fn=)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lstm_bottom(t.vstack([txt_ds[0], txt_ds[1]])) # test forward" ] }, { "cell_type": "code", "execution_count": 12, "id": "9add8cc6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.0987, 0.0808, -0.0140, -0.0718, -0.1381, 0.2642, -0.1874, -0.0494],\n", " [ 0.0856, 0.0948, -0.0362, -0.0702, -0.0695, 0.2293, -0.1768, -0.0638]],\n", " grad_fn=)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img_bottom(t.vstack([img_ds[0][0].unsqueeze(dim=0), img_ds[1][0].unsqueeze(dim=0)])) " ] }, { "attachments": {}, "cell_type": "markdown", "id": "9950c6bd", "metadata": {}, "source": [ "### Pipeline" ] }, { "cell_type": "code", "execution_count": 13, "id": "d4a51493", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'namespace': 'experiment', 'table_name': 'flicker_host'}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os\n", "import torch as t\n", "from torch import nn\n", "from pipeline import fate_torch_hook\n", "from pipeline.component import HeteroNN\n", "from pipeline.component.hetero_nn import DatasetParam\n", "from pipeline.backend.pipeline import PipeLine\n", "from pipeline.component import Reader, Evaluation, DataTransform\n", "from pipeline.interface import Data, Model\n", "from pipeline.component.nn import save_to_fate\n", "\n", "fate_torch_hook(t)\n", "\n", "fate_project_path = os.path.abspath('../../../../')\n", "guest = 10000\n", "host = 9999\n", "\n", "pipeline_mix = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host)\n", "\n", "guest_data = {\"name\": \"flicker_guest\", \"namespace\": \"experiment\"}\n", "host_data = {\"name\": \"flicker_host\", \"namespace\": \"experiment\"}\n", "\n", "guest_data_path = fate_project_path + 
'/examples/data/flicker_toy_data/flicker/images'\n", "host_data_path = fate_project_path + '/examples/data/flicker_toy_data/text.csv'\n", "\n", "pipeline_mix.bind_table(name='flicker_guest', namespace='experiment', path=guest_data_path)\n", "pipeline_mix.bind_table(name='flicker_host', namespace='experiment', path=host_data_path)" ] }, { "cell_type": "code", "execution_count": 14, "id": "50efe200", "metadata": {}, "outputs": [], "source": [ "reader_0 = Reader(name=\"reader_0\")\n", "reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=guest_data)\n", "reader_0.get_party_instance(role='host', party_id=host).component_param(table=host_data)" ] }, { "cell_type": "code", "execution_count": 15, "id": "4475a968", "metadata": {}, "outputs": [], "source": [ "hetero_nn_0 = HeteroNN(name=\"hetero_nn_0\", epochs=5,\n", " interactive_layer_lr=0.001, batch_size=64, validation_freqs=1, task_type='classification')\n", "guest_nn_0 = hetero_nn_0.get_party_instance(role='guest', party_id=guest)\n", "host_nn_0 = hetero_nn_0.get_party_instance(role='host', party_id=host)" ] }, { "cell_type": "code", "execution_count": 16, "id": "a2a591e4", "metadata": {}, "outputs": [], "source": [ "\n", "guest_bottom = t.nn.Sequential(\n", " nn.CustModel(module_name='guest_bottom_image', class_name='ImgBottomNet')\n", ")\n", "\n", "guest_top = t.nn.Sequential(\n", " nn.CustModel(module_name='guest_top_image', class_name='ImgTopNet')\n", ")\n", "# bottom model\n", "host_bottom = nn.CustModel(module_name='host_bottom_lstm', class_name='LSTMBottom', vocab_size=txt_ds.get_vocab_size())\n", "\n", "interactive_layer = t.nn.InteractiveLayer(out_dim=4, guest_dim=8, host_dim=8, host_num=1)" ] }, { "cell_type": "code", "execution_count": 18, "id": "b8799751", "metadata": {}, "outputs": [], "source": [ "guest_nn_0.add_top_model(guest_top)\n", "guest_nn_0.add_bottom_model(guest_bottom)\n", "host_nn_0.add_bottom_model(host_bottom)\n", "optimizer = t.optim.Adam(lr=0.001)\n", "loss = 
t.nn.BCELoss()\n", "\n", "hetero_nn_0.set_interactive_layer(interactive_layer)\n", "hetero_nn_0.compile(optimizer=optimizer, loss=loss)" ] }, { "cell_type": "code", "execution_count": 19, "id": "ea5dae1c", "metadata": {}, "outputs": [], "source": [ "# add dataset\n", "guest_nn_0.add_dataset(DatasetParam(dataset_name='image', return_label=True, center_crop=True, center_crop_shape=(224, 224), label_dtype='float'))\n", "host_nn_0.add_dataset(DatasetParam(dataset_name='nlp_tokenizer', return_label=False))" ] }, { "cell_type": "code", "execution_count": 20, "id": "aa06e1bf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipeline_mix.add_component(reader_0)\n", "pipeline_mix.add_component(hetero_nn_0, data=Data(train_data=reader_0.output.data))\n", "pipeline_mix.compile()" ] }, { "cell_type": "code", "execution_count": 22, "id": "bdc60285", "metadata": {}, "outputs": [], "source": [ "pipeline_mix.fit()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('venv': venv)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]" }, "vscode": { "interpreter": { "hash": "d29574a2ab71ec988cdcd4d29c58400bd2037cad632b9528d973466f7fb6f853" } } }, "nbformat": 4, "nbformat_minor": 5 }