{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "1c8b9b57", "metadata": {}, "source": [ "# Using a Frozen-Parameter BERT for Sentiment Classification\n", "\n", "In this example, we construct a text classifier on top of a BERT model with frozen parameters and train it on the IMDB sentiment classification dataset.\n", "\n", "## Dataset: IMDB Sentiment\n", "\n", "This is a binary classification dataset. You can download our processed version from the link below and place it in the examples/data folder:\n", "- https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/IMDB.csv\n", "\n", "The original data is from:\n", "- https://ai.stanford.edu/~amaas/data/sentiment/" ] }, { "attachments": {}, "cell_type": "markdown", "id": "d0b7757b", "metadata": {}, "source": [ "## Check dataset" ] }, { "cell_type": "code", "execution_count": 9, "id": "5b87e19e", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "df = pd.read_csv('../../../../examples/data/IMDB.csv')" ] }, { "cell_type": "code", "execution_count": 10, "id": "718f82a0", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>id</th><th>text</th><th>label</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>0</td><td>One of the other reviewers has mentioned that ...</td><td>1</td></tr>\n",
"    <tr><th>1</th><td>1</td><td>A wonderful little production. &lt;br /&gt;&lt;br /&gt;The...</td><td>1</td></tr>\n",
"    <tr><th>2</th><td>2</td><td>I thought this was a wonderful way to spend ti...</td><td>1</td></tr>\n",
"    <tr><th>3</th><td>3</td><td>Basically there's a family where a little boy ...</td><td>0</td></tr>\n",
"    <tr><th>4</th><td>4</td><td>Petter Mattei's \"Love in the Time of Money\" is...</td><td>1</td></tr>\n",
"    <tr><th>...</th><td>...</td><td>...</td><td>...</td></tr>\n",
"    <tr><th>1996</th><td>1996</td><td>THE CELL (2000) Rating: 8/10&lt;br /&gt;&lt;br /&gt;The Ce...</td><td>1</td></tr>\n",
"    <tr><th>1997</th><td>1997</td><td>This movie, despite its list of B, C, and D li...</td><td>0</td></tr>\n",
"    <tr><th>1998</th><td>1998</td><td>I loved this movie! It was all I could do not ...</td><td>1</td></tr>\n",
"    <tr><th>1999</th><td>1999</td><td>This was the worst movie I have ever seen Bill...</td><td>0</td></tr>\n",
"    <tr><th>2000</th><td>2000</td><td>Stranded in Space (1972) MST3K version - a ver...</td><td>0</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>2001 rows × 3 columns</p>\n",
"</div>
\n" ], "text/plain": [
"        id                                               text  label\n",
"0        0  One of the other reviewers has mentioned that ...      1\n",
"1        1  A wonderful little production. <br /><br />The...      1\n",
"2        2  I thought this was a wonderful way to spend ti...      1\n",
"3        3  Basically there's a family where a little boy ...      0\n",
"4        4  Petter Mattei's \"Love in the Time of Money\" is...      1\n",
"...    ...                                                ...    ...\n",
"1996  1996  THE CELL (2000) Rating: 8/10<br /><br />The Ce...      1\n",
"1997  1997  This movie, despite its list of B, C, and D li...      0\n",
"1998  1998  I loved this movie! It was all I could do not ...      1\n",
"1999  1999  This was the worst movie I have ever seen Bill...      0\n",
"2000  2000  Stranded in Space (1972) MST3K version - a ver...      0\n",
"\n",
"[2001 rows x 3 columns]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] },
{ "cell_type": "code", "execution_count": 11, "id": "a22c8db4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-12-25 23:19:45.537897: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", "2022-12-25 23:19:45.537936: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" ] } ], "source": [ "from federatedml.nn.dataset.nlp_tokenizer import TokenizerDataset" ] },
{ "cell_type": "code", "execution_count": 13, "id": "5217d58c", "metadata": {}, "outputs": [], "source": [ "ds = TokenizerDataset(tokenizer_name_or_path=\"bert-base-uncased\")\n", "ds.load('../../../../examples/data/IMDB.csv')" ] },
{ "cell_type": "code", "execution_count": 14, "id": "4f54db71", "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "# Fetch a single batch to check that the dataset works with a DataLoader\n", "dl = DataLoader(ds, batch_size=16)\n", "for i in dl:\n", "    break" ] },
{ "attachments": {}, "cell_type": "markdown", "id": "cc803046", "metadata": {}, "source": [ "## Build a BERT Classifier" ] },
{ "cell_type": "code", "execution_count": 15, "id": "6369d0d0", "metadata": {}, "outputs": [], "source": [ "from pipeline.component.nn import save_to_fate" ] },
{ "cell_type": "code", "execution_count": 19, "id": "870e24eb", "metadata": {}, "outputs": [], "source": [
"%%save_to_fate model bert_.py\n",
"\n",
"import torch as t\n",
"from federatedml.nn.model_zoo.pretrained_bert import PretrainedBert\n",
"\n",
"\n",
"class BertClassifier(t.nn.Module):\n",
"\n",
"    def __init__(self):\n",
"        super(BertClassifier, self).__init__()\n",
"        # Pretrained BERT encoder with frozen weights\n",
"        self.bert = PretrainedBert(pretrained_model_name_or_path='bert-base-uncased', freeze_weight=True)\n",
"        # Trainable classification head on the 768-dimensional pooled output\n",
"        self.classifier = t.nn.Sequential(\n",
"            t.nn.Linear(768, 128),\n",
"            t.nn.ReLU(),\n",
"            t.nn.Linear(128, 64),\n",
"            t.nn.ReLU(),\n",
"            t.nn.Linear(64, 1),\n",
"            t.nn.Sigmoid()\n",
"        )\n",
"\n",
"    def parameters(self):\n",
"        # Expose only the classifier head's parameters; BERT stays frozen\n",
"        return self.classifier.parameters()\n",
"\n",
"    def forward(self, x):\n",
"        x = self.bert(x)\n",
"        return self.classifier(x.pooler_output)" ] },
{ "cell_type": "code", "execution_count": 1, "id": "66d039bb", "metadata": {}, "outputs": [], "source": [ "model = BertClassifier()" ] },
{ "cell_type": "code", "execution_count": 21, "id": "23fd5e83", "metadata": {}, "outputs": [], "source": [ "import torch as t\n", "from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer\n", "\n", "trainer = FedAVGTrainer(epochs=3, batch_size=16, shuffle=True, data_loader_worker=4)\n", "trainer.local_mode()\n", "trainer.set_model(model)" ] },
{ "cell_type": "code", "execution_count": 22, "id": "e2b56177", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "epoch is 0\n", "100%|██████████| 126/126 [01:21<00:00, 1.55it/s]\n", "epoch loss is 0.6995822169195706\n", "epoch is 1\n", "100%|██████████| 126/126 [01:17<00:00, 1.63it/s]\n", "epoch loss is 0.6738948538445163\n", "epoch is 2\n", "100%|██████████| 126/126 [01:16<00:00, 1.64it/s]\n", "epoch loss is 0.6501996349180299\n" ] } ], "source": [ "opt = t.optim.Adam(model.parameters(), lr=0.005)\n", "loss = t.nn.BCELoss()\n", "# local test\n", "trainer.train(ds, None, opt, loss)" ] },
{ "attachments": {}, "cell_type": "markdown", "id": "18f3b417", "metadata": {}, "source": [ "## Submit a pipeline" ] },
{ "cell_type": "code", "execution_count": 28, "id": "0f3f9f1d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'namespace': 'experiment', 'table_name': 'imdb'}" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [
"import os\n",
"\n",
"import torch as t\n",
"from torch import nn\n",
"from pipeline import fate_torch_hook\n",
"from pipeline.component import HomoNN\n",
"from pipeline.backend.pipeline import PipeLine\n",
"from pipeline.component import Reader, Evaluation, DataTransform\n",
"from pipeline.interface import Data, Model\n",
"\n",
"fate_torch_hook(t)\n",
"\n",
"fate_project_path = os.path.abspath('../../../../')\n",
"guest_0 = 10000\n",
"host_1 = 9999\n",
"pipeline = PipeLine().set_initiator(role='guest', party_id=guest_0).set_roles(guest=guest_0, host=host_1, arbiter=guest_0)\n",
"data_0 = {\"name\": \"imdb\", \"namespace\": \"experiment\"}\n",
"data_path = fate_project_path + '/examples/data/IMDB.csv'\n",
"pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path)" ] },
{ "cell_type": "code", "execution_count": 29, "id": "ba23ac18", "metadata": {}, "outputs": [], "source": [ "reader_0 = Reader(name=\"reader_0\")\n", "reader_0.get_party_instance(role='guest', party_id=guest_0).component_param(table=data_0)\n", "reader_0.get_party_instance(role='host', party_id=host_1).component_param(table=data_0)\n", "\n", "reader_1 = Reader(name=\"reader_1\")\n", "reader_1.get_party_instance(role='guest', party_id=guest_0).component_param(table=data_0)\n", "reader_1.get_party_instance(role='host', party_id=host_1).component_param(table=data_0)" ] },
{ "cell_type": "code", "execution_count": 30, "id": "4732d629", "metadata": {}, "outputs": [], "source": [
"from pipeline.component.homo_nn import DatasetParam, TrainerParam\n",
"\n",
"model = t.nn.Sequential(\n",
"    t.nn.CustModel(module_name='bert_', class_name='BertClassifier')\n",
")\n",
"\n",
"nn_component = HomoNN(name='nn_0',\n",
"                      model=model,\n",
"                      loss=t.nn.BCELoss(),\n",
"                      optimizer=t.optim.Adam(lr=0.001, weight_decay=0.001),\n",
"                      dataset=DatasetParam(dataset_name='nlp_tokenizer', tokenizer_name_or_path=\"bert-base-uncased\"),  # use a custom dataset\n",
"                      trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=2, batch_size=16, data_loader_worker=8, cuda=True),\n",
"                      torch_seed=100\n",
"                      )" ] },
{ "cell_type": "code", "execution_count": null, "id": "41329be5", "metadata": {}, "outputs": [], "source": [
"# Both readers must be added before the component that consumes their output\n",
"pipeline.add_component(reader_0)\n",
"pipeline.add_component(reader_1)\n",
"pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data, validate_data=reader_1.output.data))\n",
"pipeline.add_component(Evaluation(name='eval_0', eval_type='binary'), data=Data(data=nn_component.output.data))\n",
"pipeline.compile()\n",
"pipeline.fit()" ] } ],
"metadata": { "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13 (default, Mar 28 2022, 11:38:47) \n[GCC 7.5.0]" }, "vscode": { "interpreter": { "hash": "d29574a2ab71ec988cdcd4d29c58400bd2037cad632b9528d973466f7fb6f853" } } }, "nbformat": 4, "nbformat_minor": 5 }