Overview • Installation • Data_Preparation • Federated_Finetuning • Inference • Citation
Recent advancements in fine-tuning large language models (LLMs) have leveraged instructions created by humans or by APIs (such as ChatGPT and GPT-4) to revolutionize NLP research and industry applications. However, collecting instructions from a wide array of individuals raises challenges of privacy and heterogeneity. Federated Learning, a well-studied and well-developed learning approach, provides a solution to address these challenges and paves the way for designing personalized LLMs tailored to individual users.
This repository offers a foundational framework for exploring federated fine-tuning of LLMs using heterogeneous instructions across diverse categories. The framework is designed for ease of use, adaptability, and scalability to accommodate large datasets. Additionally, it facilitates seamless integration of novel algorithms and configurations, making it a convenient tool for both researchers and practitioners in the NLP community.
The code requires Python 3.8 and the dependencies specified in requirements.txt. Install them with:
pip install -r requirements.txt
If bitsandbytes doesn't work, install it from source. Windows users can follow these instructions.
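As a quick sanity check (this snippet is illustrative and not part of the repository), you can verify that the key dependencies import cleanly and that a GPU is visible before training:

```python
# Illustrative environment check: the imports fail loudly if a dependency
# (e.g. a broken bitsandbytes build) is missing or miscompiled.
import torch
import bitsandbytes  # noqa: F401
import peft          # noqa: F401
import transformers  # noqa: F401

print("CUDA available:", torch.cuda.is_available())
```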
Prior to commencing the federated fine-tuning, make sure to create a data file for each individual client.
num_client=10 # The number of clients
diff_quantity=0 # Whether clients have different amounts of data
python clients_datasets.py $num_client $diff_quantity
Running this command will save the data files in the folder ./data/{num_client} (for example, ./data/10 for 10 clients). The data file new-databricks-dolly-15k.json used to generate each client's local dataset is the first version of databricks-dolly-15k, a corpus of more than 15,000 records across 8 categories generated by thousands of Databricks employees. Please refer to their official dolly repository for the latest version of the data.
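For reference, each record follows the Dolly instruction format. The example below is a made-up record written as a Python dict; the field names (instruction, context, response, category) follow the original databricks-dolly-15k schema, so check new-databricks-dolly-15k.json if the keys used here differ:

```python
# Hypothetical Dolly-style record (field names follow databricks-dolly-15k;
# verify the exact keys against new-databricks-dolly-15k.json).
sample_record = {
    "category": "closed_qa",
    "instruction": "In what year was Databricks founded?",
    "context": "Databricks was founded in 2013 by the original creators of Apache Spark.",
    "response": "Databricks was founded in 2013.",
}
```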
The first version of databricks-dolly-15k contains 8 categories, with the distribution of each category shown in the following figure.
The following table shows the category distribution for each client, illustrating the heterogeneous nature of clients' instructions.
| Client | Open_qa | General_qa | Classification | Closed_qa | Brainstorming | Information_extraction | Summarization | Creative_writing |
|---|---|---|---|---|---|---|---|---|
| Client 0 | 0 | 0 | 149 | 598 | 0 | 0 | 746 | 0 |
| Client 1 | 747 | 0 | 747 | 0 | 0 | 0 | 0 | 0 |
| Client 2 | 377 | 747 | 0 | 0 | 0 | 370 | 0 | 0 |
| Client 3 | 985 | 0 | 0 | 0 | 0 | 0 | 507 | 0 |
| Client 4 | 0 | 0 | 0 | 747 | 0 | 747 | 0 | 0 |
| Client 5 | 746 | 747 | 0 | 0 | 0 | 0 | 0 | 0 |
| Client 6 | 0 | 362 | 0 | 0 | 747 | 385 | 0 | 0 |
| Client 7 | 746 | 0 | 483 | 0 | 264 | 0 | 0 | 0 |
| Client 8 | 0 | 325 | 0 | 468 | 0 | 0 | 0 | 701 |
| Client 9 | 0 | 0 | 747 | 0 | 747 | 0 | 0 | 0 |
You can simply modify clients_datasets.py to load your own dataset for federated training.
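For example, if your data is a single JSON list of instruction records, a minimal partitioning step could look like the sketch below. This is a hedged illustration, not the repository's code: the helper name, file-naming convention, and split logic are placeholders, and clients_datasets.py may differ.

```python
# Hedged sketch: shard a JSON list of instruction records evenly across clients.
# The helper name and per-client file names are illustrative placeholders;
# see clients_datasets.py for the repository's actual splitting logic.
import json
import os
import random


def split_into_clients(data_path, num_clients, out_dir, seed=42):
    with open(data_path, "r") as f:
        records = json.load(f)            # expects a JSON list of records
    random.Random(seed).shuffle(records)

    os.makedirs(out_dir, exist_ok=True)
    for client_id in range(num_clients):
        shard = records[client_id::num_clients]   # simple round-robin split
        out_path = os.path.join(out_dir, f"local_training_{client_id}.json")
        with open(out_path, "w") as f:
            json.dump(shard, f, indent=2)


if __name__ == "__main__":
    split_into_clients("my_dataset.json", num_clients=10, out_dir="./data/10")
```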
To fully leverage the computational resources of each participating client, our lightweight federated learning framework employs the well-established parameter-efficient fine-tuning method LoRA for local training. The local training process builds on Hugging Face's PEFT, Tim Dettmers' bitsandbytes, and Alpaca-LoRA, enabling training to complete within hours on a single NVIDIA TITAN RTX.
Example usage:
python main.py --global_model 'chavinlo/alpaca-native' \
    --data_path "./data" \
    --output_dir './lora-shepherd-7b/' \
    --num_communication_rounds 10 \
    --num_clients 10 \
    --train_on_inputs \
    --group_by_length
Within the main.py file, GeneralClient is a Python class that represents a local client and encompasses five distinct steps that facilitate local training: "prepare_local_dataset," "build_local_trainer," "initiate_local_training," "train," and "terminate_local_training." Each of these steps is easy to comprehend and can be customized by adding your own functions to meet specific requirements.
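The simplified skeleton below shows how those five steps could fit together in one communication round. It is a hedged sketch: the constructor arguments, return values, and selection logic are illustrative, and the repository's actual GeneralClient (passed here as client_cls) has richer signatures in main.py.

```python
# Hedged skeleton of one communication round driven by a GeneralClient-style
# class. Method names follow the README's description; exact signatures differ.
import random


def run_communication_round(client_cls, global_model, client_ids,
                            data_dir, output_dir, epoch, selection_frac=0.1):
    num_selected = max(1, int(selection_frac * len(client_ids)))
    selected = random.sample(client_ids, num_selected)

    local_updates = []
    for client_id in selected:
        client = client_cls(client_id, global_model, data_dir, output_dir)
        client.prepare_local_dataset()      # load this client's JSON shard
        client.build_local_trainer()        # wrap the model in a local trainer
        client.initiate_local_training()    # snapshot the incoming global weights
        client.train()                      # run the local epochs
        update = client.terminate_local_training(epoch)   # save/return LoRA deltas
        local_updates.append(update)

    # The server then aggregates local_updates (e.g. weighted averaging of the
    # LoRA parameters) to form the next round's global model.
    return local_updates
```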
We can also tweak the hyperparameters:
python main.py --global_model 'chavinlo/alpaca-native' \
    --data_path "./data" \
    --output_dir './lora-shepherd-7b/' \
    --num_communication_rounds 10 \
    --num_clients 10 \
    --client_selection_frac 0.05 \
    --local_num_epochs 2 \
    --local_batch_size 64 \
    --local_micro_batch_size 32 \
    --local_learning_rate 0.0003 \
    --lora_r 8 \
    --lora_target_modules='[q_proj,k_proj,v_proj,o_proj]' \
    --train_on_inputs \
    --group_by_length
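The LoRA-related flags above map onto a standard PEFT configuration. The snippet below is a hedged sketch of that pattern using ordinary transformers and PEFT APIs; values such as lora_alpha and lora_dropout are illustrative assumptions, and the repository's exact model-loading code lives in main.py.

```python
# Hedged sketch of the parameter-efficient setup implied by the flags above
# (8-bit base model + LoRA adapters); not the repository's exact code.
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

base_model = "chavinlo/alpaca-native"
model = LlamaForCausalLM.from_pretrained(base_model, load_in_8bit=True, device_map="auto")
tokenizer = LlamaTokenizer.from_pretrained(base_model)

model = prepare_model_for_kbit_training(model)   # make the 8-bit model trainable
lora_config = LoraConfig(
    r=8,                                                       # --lora_r
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],   # --lora_target_modules
    lora_alpha=16,                                             # illustrative value
    lora_dropout=0.05,                                         # illustrative value
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()               # only the LoRA adapters are trainable
```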
Our framework supports numerous popular LLMs, such as LLaMA, Alpaca, Vicuna, Baize, and others. We welcome any pull requests that adapt our code to support additional models or datasets.
The GlobalModel_generated.py file streamlines the inference process for the global model through a Gradio interface. It loads the foundation model from the Hugging Face Model Hub and obtains the LoRA weights and configurations from the output directory.
python GlobalModel_generated.py \
    --load_8bit \
    --base_model 'chavinlo/alpaca-native' \
    --lora_weights_path /output/path/to/lora_weights \
    --lora_config_path /output/path/to/lora_config
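If you would rather load the aggregated adapter programmatically than through the Gradio app, the loading pattern looks roughly like the sketch below. It assumes the LoRA directory contains a standard PEFT adapter config and weights; GlobalModel_generated.py accepts the config and weight paths separately, so adapt the placeholder paths to your own output directory.

```python
# Hedged sketch: attach the federated LoRA adapter to the 8-bit base model for
# inference. Paths are placeholders; the repository's own loading logic lives
# in GlobalModel_generated.py.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

base_model = "chavinlo/alpaca-native"
lora_weights_path = "/output/path/to/lora_weights"   # placeholder

tokenizer = LlamaTokenizer.from_pretrained(base_model)
model = LlamaForCausalLM.from_pretrained(base_model, load_in_8bit=True, device_map="auto")
model = PeftModel.from_pretrained(model, lora_weights_path)
model.eval()

prompt = (
    "Below is an instruction that describes a task.\n\n"
    "### Instruction:\nName three challenges of federated instruction tuning.\n\n"
    "### Response:\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```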
Please cite this repository if you find it helpful for your research.
@misc{Shepherd,
  author = {Jianyi Zhang and Martin Kuo and Ruiyi Zhang and Guoyin Wang and Saeed Vahidian and Yiran Chen},
title = {Shepherd: Large Language Models with Parameter-Efficient Federated Finetuning in the Presence of Heterogeneous Instructions},
year = {2023},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/JayZhang42/FederatedGPT-Shepherd}},
}
We are constantly working to enhance this framework by resolving bugs and extending its functionality and simulation capabilities. We welcome pull requests that adapt our code to support additional research goals, such as benchmarking of models and datasets, algorithmic enhancements, and hardware simulation.