Browse Source

frist commit

JayZhang42 1 year ago
parent
commit
8c86fe9455

+ 225 - 0
GlobalModel_generated.py

@@ -0,0 +1,225 @@
+import os
+
+import fire
+import gradio as gr
+import torch
+import transformers
+
+from peft import (
+    PeftModel,
+    LoraConfig,
+    get_peft_model,
+    get_peft_model_state_dict,
+    prepare_model_for_int8_training,
+    set_peft_model_state_dict,
+)
+
+from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer,AutoTokenizer
+from utils.callbacks import Iteratorize, Stream
+from utils.prompter import Prompter
+if torch.cuda.is_available():
+    device = "cuda"
+else:
+    device = "cpu"
+
+try:
+    if torch.backends.mps.is_available():
+        device = "mps"
+except:
+    pass
+
+
+def main(
+    load_8bit: bool = False,
+    base_model: str = "",
+    lora_weights_path: str = "",
+    lora_config_path: str= "", # provide only the file path, excluding the file name 'adapter_config.json'
+    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
+    server_name: str = "127.0.0.1",
+    share_gradio: bool = False,
+):
+    base_model = base_model or os.environ.get("BASE_MODEL", "")
+    assert (
+        base_model
+    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
+
+    prompter = Prompter(prompt_template)
+    tokenizer = LlamaTokenizer.from_pretrained(base_model)
+    if not lora_weights_path.endswith(".bin"):
+        if device == "cuda":
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=load_8bit,
+                torch_dtype=torch.float16,
+                device_map="auto",
+            )
+            model = PeftModel.from_pretrained(
+                model,
+                lora_weights_path,
+                torch_dtype=torch.float16,
+            )
+        elif device == "mps":
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                device_map={"": device},
+                torch_dtype=torch.float16,
+            )
+            model = PeftModel.from_pretrained(
+                model,
+                lora_weights_path,
+                device_map={"": device},
+                torch_dtype=torch.float16,
+            )
+        else:
+            model = LlamaForCausalLM.from_pretrained(
+                base_model, device_map={"": device}, low_cpu_mem_usage=True
+            )
+            model = PeftModel.from_pretrained(
+                model,
+                lora_weights_path,
+                device_map={"": device},
+            )
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            base_model,
+            load_in_8bit=True,
+            torch_dtype=torch.float16,
+            device_map="auto",
+        )
+        model = prepare_model_for_int8_training(model)
+        config = LoraConfig.from_pretrained(lora_config_path)
+        lora_weights = torch.load(lora_weights_path)
+        model = PeftModel(model, config)
+        set_peft_model_state_dict(model,lora_weights,"default")
+        del lora_weights
+
+
+    # unwind broken decapoda-research config
+    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
+    model.config.bos_token_id = 1
+    model.config.eos_token_id = 2
+
+    if not load_8bit:
+        model.half()  # seems to fix bugs for some users.
+
+    model.eval()
+
+
+    def evaluate(
+        instruction,
+        input=None,
+        temperature=0.1,
+        top_p=0.75,
+        top_k=40,
+        num_beams=4,
+        max_new_tokens=128,
+        stream_output=True,
+        **kwargs,
+    ):
+        prompt = prompter.generate_prompt(instruction, input)
+        inputs = tokenizer(prompt, return_tensors="pt")
+        input_ids = inputs["input_ids"].to(device)
+        generation_config = GenerationConfig(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            num_beams=num_beams,
+            **kwargs,
+        )
+
+        generate_params = {
+            "input_ids": input_ids,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "output_scores": True,
+            "max_new_tokens": max_new_tokens,
+        }
+
+        if stream_output:
+            # Stream the reply 1 token at a time.
+            # This is based on the trick of using 'stopping_criteria' to create an iterator,
+            # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+
+            def generate_with_callback(callback=None, **kwargs):
+                kwargs.setdefault(
+                    "stopping_criteria", transformers.StoppingCriteriaList()
+                )
+                kwargs["stopping_criteria"].append(
+                    Stream(callback_func=callback)
+                )
+                with torch.no_grad():
+                    model.generate(**kwargs)
+
+            def generate_with_streaming(**kwargs):
+                return Iteratorize(
+                    generate_with_callback, kwargs, callback=None
+                )
+
+            with generate_with_streaming(**generate_params) as generator:
+                for output in generator:
+                    # new_tokens = len(output) - len(input_ids[0])
+                    decoded_output = tokenizer.decode(output)
+
+                    if output[-1] in [tokenizer.eos_token_id]:
+                        break
+
+                    yield prompter.get_response(decoded_output)
+            return  # early return for stream_output
+
+        # Without streaming
+        with torch.no_grad():
+            generation_output = model.generate(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                return_dict_in_generate=True,
+                output_scores=True,
+                max_new_tokens=max_new_tokens,
+            )
+        s = generation_output.sequences[0]
+        output = tokenizer.decode(s)
+        yield prompter.get_response(output)
+
+    sherpherd_UI=gr.Interface(
+        fn=evaluate,
+        inputs=[
+            gr.components.Textbox(
+                lines=2,
+                label="Instruction",
+                placeholder="Tell me about alpacas.",
+            ),
+            gr.components.Textbox(lines=2, label="Input", placeholder="none"),
+            gr.components.Slider(
+                minimum=0, maximum=1, value=0.1, label="Temperature"
+            ),
+            gr.components.Slider(
+                minimum=0, maximum=1, value=0.75, label="Top p"
+            ),
+            gr.components.Slider(
+                minimum=0, maximum=100, step=1, value=40, label="Top k"
+            ),
+            gr.components.Slider(
+                minimum=1, maximum=4, step=1, value=4, label="Beams"
+            ),
+            gr.components.Slider(
+                minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
+            ),
+            gr.components.Checkbox(label="Stream output"),
+        ],
+        outputs=[
+            gr.inputs.Textbox(
+                lines=5,
+                label="Output",
+            )
+        ],
+        title="FederatedGPT-shepherd",
+        description="Shepherd is a LLM that has been fine-tuned in a federated manner ",
+    ).queue()
+
+    sherpherd_UI.launch(share=True)
+
+
+
+
+
+if __name__ == "__main__":
+    fire.Fire(main)

+ 152 - 1
README.md

@@ -1 +1,152 @@
-# FederatedGPT-Shepherd
+<h1 align="center">
+  <img src="assets/shepherd&llamas.png" width="75%">
+  <br>
+  Shepherd
+  <br>
+</h1>
+<h4 align="center"><em><span style="font-size:18pt">Large Language Models with Parameter-Efficient Federated Finetuning in the Presence of Heterogeneous Instructions</span></em></h4>
+
+<p align="center">
+  <a href="#Overview">Overview</a> •
+  <a href="#Installation">Installation</a> •
+  <a href="#Data_Preparation">Data_Preparation</a> •
+  <a href="#Federated_Finetuning">Federated_Finetuning</a> •
+  <a href="#Inference">Inference</a> •
+  <a href="#Citation">Citation</a> 
+</p>
+
+
+
+
+## Overview
+
+Recent advancements in fine-tuning large language models (LLMs) have leveraged instructions created by humans or APIs (such as ChatGPT and GPT-4) to revolutionize NLP research and industry applications. However, the collection of instructions from a wide array of individuals presents challenges in privacy and heterogeneity. Federated Learning, a well-studied and well-developed learning approach, provides a solution to addresses these challenges and paves the way for designing personalized LLMs tailored to individual users.
+
+This repository offers a foundational framework for exploring federated fine-tuning of LLMs using heterogeneous instructions across diverse categories. The framework is designed for ease of use, adaptability, and scalability to accommodate large datasets. Additionally, it facilitates seamless integration of novel algorithms and configurations, making it a convenient tool for both researchers and practitioners in the NLP community.
+
+
+
+## Installation 
+
+The code requires some dependencies (Python=3.8)  as specified in `requirements.txt`. Please follow the relevant libraries to install or run:
+```bash
+pip install -r requirements.txt
+```
+If `bitsandbytes` doesn't work, [install it from source](https://github.com/TimDettmers/bitsandbytes/blob/main/compile_from_source.md). Windows users can follow [these instructions](https://github.com/tloen/alpaca-lora/issues/17).
+
+
+## Data_Preparation
+
+Prior to commencing the federated fine-tuning, make sure to create a data file for each individual client.
+```bash
+num_client=10 # The number of clients
+diff_quantity=0 # Whether clients have different amounts of data
+python clients_datasets.py $num_client $diff_quantity
+```
+Running this command will save the data files in the folder `./data/str(num_client)`. The data file `new-databricks-dolly-15k.json` for generating each client's local dataset is the first version of `databricks-dolly-15k` , which is a corpus of more than 15,000 records with 8 categeries generated by thousands of [Databricks Lab](https://www.databricks.com/learn/labs) employees. Please refer to their official repository [dolly](https://github.com/databrickslabs/dolly) for the latest version of data.
+
+### Categories distribution and Heteogeneity
+The first version of `databricks-dolly-15k` contains 8 Categories, with the distribution of each category shown in the following figure.
+
+<p align="center">
+  <img src="assets/pie_chart_viridis_style.png" width="100%">
+</p>
+
+
+
+The following table presents an illustrative depiction of the category distributions among each client, serving to exemplify the diverse nature of clients' instructions
+
+|          | Open_qa | General_qa | Classification | Closed_qa | Brainstorming | Information_extraction | Summarization | Creative_writing |
+|----------|---------|------------|----------------|-----------|---------------|------------------------|---------------|------------------|
+| Client 0 | 0       | 0          | **149**        | **598**   | 0             | 0                      | **746**       | 0                |
+| Client 1 | **747** | 0          | **747**        | 0         | 0             | 0                      | 0             | 0                |
+| Client 2 | **377** | **747**    | 0              | 0         | 0             | **370**                | 0             | 0                |
+| Client 3 | **985** | 0          | 0              | 0         | 0             | 0                      | **507**       | 0                |
+| Client 4 | 0       | 0          | 0              | **747**   | 0             | **747**                | 0             | 0                |
+| Client 5 | **746** | **747**    | 0              | 0         | 0             | 0                      | 0             | 0                |
+| Client 6 | 0       | **362**    | 0              | 0         | **747**       | **385**                | 0             | 0                |
+| Client 7 | **746** | 0          | **483**        | 0         | **264**       | 0                      | 0             | 0                |
+| Client 8 | 0       | **325**    | 0              | **468**   | 0             | 0                      | 0             | **701**          |
+| Client 9 | 0       | 0          | **747**        | 0         | **747**       | 0                      | 0             | 0                |
+
+
+
+### Use your own data
+
+You can simply modify `clients_datasets.py` to load your own  dataset for federated training.
+
+
+## Federated_Finetuning
+
+To fully leverage the computational resources of each participating client, our lightweight Federated Learning framework employs the well-established parameter-efficient method, [LoRA](https://github.com/microsoft/LoRA), for conducting local training. The local training process is built upon the implementations of Hugging Face's [PEFT](https://github.com/huggingface/peft), Tim Dettmers' [bitsandbytes](https://github.com/TimDettmers/bitsandbytes), and the [Alpaca-lora](https://github.com/tloen/alpaca-lora), enabling the training to be completed within hours on a single NVIDIA TITAN RTX.
+
+
+Example usage:
+```bash
+python main.py --global_model 'chavinlo/alpaca-native'\
+      --data_path  "./data" \
+      --output_dir  './lora-shepherd-7b/'\
+      --train_on_inputs \
+      --group_by_length
+```
+Within the `main.py` file, the GeneralClient is a Python class serves as a representation of the local client and encompasses five distinct sections that facilitate local training: "prepare_local_dataset," "build_local_trainer," "initiate_local_training," "train," and "terminate_local_training." Each of these sections is easy to comprehend and can be easily customized by adding your own functions to meet specific requirements.
+
+We can also tweak the hyperparameters:
+```bash
+python main.py --global_model 'chavinlo/alpaca-native'\
+      --data_path  "./data" \
+      --output_dir  './lora-shepherd-7b/'\
+      --num_communication_rounds 10 \
+      --num_clients  100 \
+      --client_selection_frac 0.05 \
+      --local_num_epochs  2 \
+      --local_batch_size  64 \
+      --local_micro_batch_size 32 \
+      --local_learning_rate 0.0003 \
+      --lora_r 8 \
+      --lora_target_modules='[q_proj,k_proj,v_proj,o_proj]' \
+      --train_on_inputs \
+      --group_by_length
+```
+
+Our framework supports numerous popular LLMs, such as [LLaMA](https://github.com/facebookresearch/llama), [Alpaca](https://github.com/tatsu-lab/stanford_alpaca), [Vicuna](https://vicuna.lmsys.org/), [Baize](https://github.com/project-baize/baize-chatbot), and others. We welcome any pull requests that adapt our code to support additional models or datasets.
+
+
+## Inference 
+
+The `GlobalModel_generate.py` file streamlines the inference process for the global model by utilizing a Gradio interface. This file loads the foundation model from the Hugging Face Model Hub and obtains the LoRA weights and configurations from the output directory.
+
+```bash
+python GlobalModel_generate.py \
+      --load_8bit \
+      --base_model 'chavinlo/alpaca-native' \
+      --lora_weights_path /output/path/to/lora_weights  \
+      --lora_config_path /output/path/to/lora_config   
+      
+```
+
+
+## Citation
+
+Please cite this repo if you find our repository helpful for your research.
+```
+@misc{Shepherd,
+  author = {Jianyi Zhang, Martin Kuo, Ruiyi Zhang, Guoyin Wang, Saeed Vahidian, Yiran Chen },
+  title = {Shepherd: Large Language Models with Parameter-Efficient Federated Finetuning in the Presence of Heterogeneous Instructions},
+  year = {2023},
+  publisher = {GitHub},
+  journal = {GitHub repository},
+  howpublished = {\url{https://github.com/JayZhang42/FederatedGPT-Shepherd}},
+}
+```
+
+## Note!
+
+We are constantly working to enhance this framework by resolving bugs and extending its functionality and simulation capabilities. We welcome pull requests that adapt our code to support additional research goals, such as benchmarking of models and datasets, algorithmic enhancements, and hardware simulation.
+
+
+
+
+
+
+

BIN
assets/pie_chart_viridis_style.png


BIN
assets/shepherd&llamas.png


+ 85 - 0
clients_datasets.py

@@ -0,0 +1,85 @@
+import sys
+import pandas as pd
+import numpy as np
+import random
+import os
+import json
+import pdb
+
+num_clients = int(sys.argv[1])
+diff_quantity = int(sys.argv[2])
+
+np.random.seed(42)
+random.seed(42)
+
+# Divide the entire dataset into a training set and a test set.
+
+df = pd.read_json("new-databricks-dolly-15k.json", orient='records')
+sorted_df = df.sort_values(by=['category'])
+grouped = sorted_df.groupby('category')
+sampled_df = grouped.apply(lambda x: x.sample(n=10))
+sampled_df = sampled_df.reset_index(level=0, drop=True)
+remaining_df = sorted_df.drop(index=sampled_df.index)
+
+sampled_df = sampled_df.reset_index().drop('index', axis=1)
+remaining_df = remaining_df.reset_index().drop('index', axis=1)
+data_path = os.path.join("data", str(num_clients))
+
+os.makedirs(data_path,exist_ok=True)
+
+remaining_df_dic = remaining_df.to_dict(orient='records')
+with open(os.path.join(data_path, "global_training.json"), 'w') as outfile:
+    json.dump(remaining_df_dic, outfile)
+
+sampled_df_dic = sampled_df.to_dict(orient='records')
+with open(os.path.join(data_path, "global_test.json"), 'w') as outfile:
+    json.dump(sampled_df_dic, outfile)
+
+# Partition the global training data into smaller subsets for each client's local training dataset
+
+if diff_quantity:
+    min_size = 0
+    min_require_size = 40
+    alpha = 0.5
+
+    N = len(remaining_df)
+    net_dataidx_map = {}
+    category_uniques = remaining_df['category'].unique().tolist()
+    while min_size < min_require_size:
+
+        idx_partition = [[] for _ in range(num_clients)]
+        for k in range(len(category_uniques)):
+            category_rows_k = remaining_df.loc[remaining_df['category'] == category_uniques[k]]
+            category_rows_k_index = category_rows_k.index.values
+            np.random.shuffle(category_rows_k_index)
+            proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
+            proportions = np.array([p * (len(idx_j) < N / num_clients) for p, idx_j in zip(proportions, idx_partition)])
+            proportions = proportions / proportions.sum()
+            proportions = (np.cumsum(proportions) * len(category_rows_k_index)).astype(int)[:-1]
+            idx_partition = [idx_j + idx.tolist() for idx_j, idx in
+                             zip(idx_partition, np.split(category_rows_k_index, proportions))]
+            min_size = min([len(idx_j) for idx_j in idx_partition])
+
+        print(min_size)
+
+
+else:
+    num_shards_per_clients = 2
+    remaining_df_index = remaining_df.index.values
+    shards = np.array_split(remaining_df_index, int(num_shards_per_clients * num_clients))
+    random.shuffle(shards)
+
+    shards = [shards[i:i + num_shards_per_clients] for i in range(0, len(shards), num_shards_per_clients)]
+    idx_partition = [np.concatenate(shards[n]).tolist() for n in range(num_clients)]
+
+
+for client_id, idx in enumerate(idx_partition):
+    print(
+        "\n Generating the local training dataset of Client_{}".format(client_id)
+    )
+    sub_remaining_df = remaining_df.loc[idx]
+    sub_remaining_df = sub_remaining_df.reset_index().drop('index', axis=1)
+    sub_remaining_df_dic = sub_remaining_df.to_dict(orient='records')
+
+    with open(os.path.join(data_path, "local_training_{}.json".format(client_id)), 'w') as outfile:
+        json.dump(sub_remaining_df_dic, outfile)

+ 5 - 0
fed_utils/__init__.py

@@ -0,0 +1,5 @@
+from .fed_optimizer import FedAvg
+from .client_selection import client_selection
+from .client import GeneralClient
+from .evaluation import global_evaluation
+from .other import other_function

+ 104 - 0
fed_utils/client.py

@@ -0,0 +1,104 @@
+import transformers
+import os
+from datasets import load_dataset
+import copy
+from collections import OrderedDict
+import torch
+from peft import (
+    get_peft_model_state_dict,
+    set_peft_model_state_dict,
+)
+
+
+class GeneralClient:
+    def __init__(self, client_id, model, data_path, output_dir):
+        self.client_id = client_id
+        self.model = model
+        self.local_data_path = os.path.join(data_path, "local_training_{}.json".format(self.client_id))
+        self.local_data = load_dataset("json", data_files=self.local_data_path)
+        self.output_dir = output_dir
+        self.local_output_dir = os.path.join(self.output_dir, "trainer_saved", "local_output_{}".format(self.client_id))
+
+    def preprare_local_dataset(self, generate_and_tokenize_prompt, local_val_set_size):
+        if local_val_set_size > 0:
+            local_train_val = self.local_data["train"].train_test_split(
+                test_size=local_val_set_size, shuffle=True, seed=42
+            )
+            self.local_train_dataset = (
+                local_train_val["train"].shuffle().map(generate_and_tokenize_prompt)
+            )
+            self.local_eval_dataset = (
+                local_train_val["test"].shuffle().map(generate_and_tokenize_prompt)
+            )
+        else:
+            self.local_train_dataset = self.local_data["train"].shuffle().map(generate_and_tokenize_prompt)
+            self.local_eval_dataset = None
+        self.local_val_set_size = local_val_set_size
+
+    def build_local_trainer(self,
+                            tokenizer,
+                            local_micro_batch_size,
+                            gradient_accumulation_steps,
+                            local_num_epochs,
+                            local_learning_rate,
+                            group_by_length,
+                            ddp):
+        self.train_args = transformers.TrainingArguments(
+            per_device_train_batch_size=local_micro_batch_size,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            warmup_steps=0,
+            num_train_epochs=local_num_epochs,
+            learning_rate=local_learning_rate,
+            fp16=True,
+            logging_steps=1,
+            optim="adamw_torch",
+            evaluation_strategy="steps" if self.local_val_set_size > 0 else "no",
+            save_strategy="steps",
+            eval_steps=200 if self.local_val_set_size > 0 else None,
+            save_steps=200,
+            output_dir=self.local_output_dir,
+            save_total_limit=1,
+            load_best_model_at_end=True if self.local_val_set_size > 0 else False,
+            ddp_find_unused_parameters=False if ddp else None,
+            group_by_length=group_by_length,
+            dataloader_drop_last=False
+        )
+        self.local_trainer = transformers.Trainer(model=self.model,
+                                                  train_dataset=self.local_train_dataset,
+                                                  eval_dataset=self.local_eval_dataset,
+                                                  args=self.train_args,
+                                                  data_collator=transformers.DataCollatorForSeq2Seq(
+                                                      tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+                                                  ),
+                                                  )
+
+    def initiate_local_training(self):
+        self.model.config.use_cache = False
+        self.params_dict_old = copy.deepcopy(
+            OrderedDict((name, param.detach()) for name, param in self.model.named_parameters() if
+                        "default" in name))
+        self.params_dict_new = OrderedDict((name, param.detach()) for name, param in self.model.named_parameters() if
+                                           "default" in name)
+        self.model.state_dict = (
+            lambda instance, *_, **__: get_peft_model_state_dict(
+                instance, self.params_dict_new, "default"
+            )
+        ).__get__(self.model, type(self.model))
+
+    def train(self):
+        self.local_trainer.train()
+
+    def terminate_local_training(self, epoch, local_dataset_len_dict, previously_selected_clients_set):
+
+        local_dataset_len_dict[self.client_id] = len(self.local_train_dataset)
+        new_adapter_weight = self.model.state_dict()
+        single_output_dir = os.path.join(self.output_dir, str(epoch), "local_output_{}".format(self.client_id))
+        os.makedirs(single_output_dir, exist_ok=True)
+        torch.save(new_adapter_weight, single_output_dir + "/pytorch_model.bin")
+
+        older_adapter_weight = get_peft_model_state_dict(self.model, self.params_dict_old, "default")
+        set_peft_model_state_dict(self.model, older_adapter_weight, "default")
+        previously_selected_clients_set = previously_selected_clients_set | set({self.client_id})
+        last_client_id = self.client_id
+
+        return self.model, local_dataset_len_dict, previously_selected_clients_set, last_client_id

+ 10 - 0
fed_utils/client_selection.py

@@ -0,0 +1,10 @@
+import numpy as np
+
+
+def client_selection(num_clients, client_selection_frac, client_selection_strategy, other_info=None):
+    np.random.seed(other_info)
+    if client_selection_strategy == "random":
+        num_selected = max(int(client_selection_frac * num_clients), 1)
+        selected_clients_set = set(np.random.choice(np.arange(num_clients), num_selected, replace=False))
+
+    return selected_clients_set

+ 6 - 0
fed_utils/evaluation.py

@@ -0,0 +1,6 @@
+
+
+def global_evaluation():
+
+    return print("CREATE THE NECESSARY EVALUATION ACCORDING TO YOUR REQUIREMENTS")
+

+ 29 - 0
fed_utils/fed_optimizer.py

@@ -0,0 +1,29 @@
+from peft import (
+    set_peft_model_state_dict,
+)
+import torch
+import os
+from torch.nn.functional import normalize
+
+
+def FedAvg(model, selected_clients_set, output_dir, local_dataset_len_dict, epoch):
+    weights_array = normalize(
+        torch.tensor([local_dataset_len_dict[client_id] for client_id in selected_clients_set],
+                     dtype=torch.float32),
+        p=1, dim=0)
+
+    for k, client_id in enumerate(selected_clients_set):
+        single_output_dir = os.path.join(output_dir, str(epoch), "local_output_{}".format(client_id),
+                                         "pytorch_model.bin")
+        single_weights = torch.load(single_output_dir)
+        if k == 0:
+            weighted_single_weights = {key: single_weights[key] * (weights_array[k]) for key in
+                                       single_weights.keys()}
+        else:
+            weighted_single_weights = {key: weighted_single_weights[key] + single_weights[key] * (weights_array[k])
+                                       for key in
+                                       single_weights.keys()}
+
+    set_peft_model_state_dict(model, weighted_single_weights, "default")
+
+    return model

+ 6 - 0
fed_utils/other.py

@@ -0,0 +1,6 @@
+def other_function():
+
+    return print("design the other functions you need")
+
+
+

+ 212 - 0
main.py

@@ -0,0 +1,212 @@
+import os
+from typing import List
+from tqdm import tqdm
+import fire
+import torch
+from transformers import LlamaTokenizer, LlamaForCausalLM
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    prepare_model_for_int8_training,
+)
+from fed_utils import FedAvg, client_selection, global_evaluation, GeneralClient
+import datasets
+from utils.prompter import Prompter
+
+datasets.utils.logging.set_verbosity_error()
+
+
+def fl_finetune(
+        # model/data params
+        global_model: str = '',
+        data_path: str = './data',
+        output_dir: str = './lora-shepherd/',
+        # FL hyperparamas
+        client_selection_strategy: str = 'random',
+        client_selection_frac: float = 0.1,
+        num_communication_rounds: int = 50,
+        num_clients: int = 10,
+        # Local training hyperparams
+        local_batch_size: int = 64,  # 64,
+        local_micro_batch_size: int = 8,
+        local_num_epochs: int = 10,
+        local_learning_rate: float = 3e-4,
+        local_val_set_size: int = 0,
+        local_save_steps: int = 3,
+        cutoff_len: int = 512,
+        # LoRA hyperparams
+        lora_r: int = 16,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.05,
+        lora_target_modules: List[str] = [
+            "q_proj",
+        ],
+        # llm hyperparams
+        train_on_inputs: bool = True,
+        group_by_length: bool = False,
+        resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
+        prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
+):
+    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
+        print(
+            f"Federated Finetuning LLM-LoRA with params:\n"
+            f"global_model: {global_model}\n"
+            f"data_path: {data_path}\n"
+            f"output_dir: {output_dir}\n"
+            f"client_selection_strategy: {client_selection_strategy}\n"
+            f"client_selection_frac: {client_selection_frac}\n"
+            f"num_communication_rounds: {num_communication_rounds}\n"
+            f"num_clients: {num_clients}\n"
+            f"local_batch_size: {local_batch_size}\n"
+            f"local_micro_batch_size: {local_micro_batch_size}\n"
+            f"local_num_epochs: {local_num_epochs}\n"
+            f"local_learning_rate: {local_learning_rate}\n"
+            f"local_val_set_size: {local_val_set_size}\n"
+            f"local_save_steps: {local_save_steps}\n"
+            f"cutoff_len: {cutoff_len}\n"
+            f"lora_r: {lora_r}\n"
+            f"lora_alpha: {lora_alpha}\n"
+            f"lora_dropout: {lora_dropout}\n"
+            f"lora_target_modules: {lora_target_modules}\n"
+            f"train_on_inputs: {train_on_inputs}\n"
+            f"group_by_length: {group_by_length}\n"
+            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
+            f"prompt template: {prompt_template_name}\n"
+        )
+    assert (
+        global_model
+    ), "Please specify a --global_model, e.g. --global_modell='decapoda-research/llama-7b-hf'"
+
+    data_path = os.path.join(data_path, str(num_clients))
+    assert (os.path.exists(data_path), "Please generate the data files for each client")
+
+    # set up the global model & toknizer
+    gradient_accumulation_steps = local_batch_size // local_micro_batch_size
+    prompter = Prompter(prompt_template_name)
+    device_map = "auto"
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    ddp = world_size != 1
+    if ddp:
+        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
+        gradient_accumulation_steps = gradient_accumulation_steps // world_size
+
+    model = LlamaForCausalLM.from_pretrained(
+        global_model,
+        load_in_8bit=True,
+        torch_dtype=torch.float16,
+        device_map=device_map,
+    )
+
+    tokenizer = LlamaTokenizer.from_pretrained(global_model)
+    tokenizer.pad_token_id = (
+        0
+    )
+    tokenizer.padding_side = "left"
+
+    def tokenize(prompt, add_eos_token=True):
+        result = tokenizer(
+            prompt,
+            truncation=True,
+            max_length=cutoff_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+                result["input_ids"][-1] != tokenizer.eos_token_id
+                and len(result["input_ids"]) < cutoff_len
+                and add_eos_token
+        ):
+            result["input_ids"].append(tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        result["labels"] = result["input_ids"].copy()
+
+        return result
+
+    def generate_and_tokenize_prompt(data_point):
+        full_prompt = prompter.generate_prompt(
+            data_point["instruction"],
+            data_point["context"],
+            data_point["response"],
+        )
+        tokenized_full_prompt = tokenize(full_prompt)
+        if not train_on_inputs:
+            user_prompt = prompter.generate_prompt(
+                data_point["instruction"], data_point["context"]
+            )
+            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
+            user_prompt_len = len(tokenized_user_prompt["input_ids"])
+
+            tokenized_full_prompt["labels"] = [
+                                                  -100
+                                              ] * user_prompt_len + tokenized_full_prompt["labels"][
+                                                                    user_prompt_len:
+                                                                    ]  # could be sped up, probably
+        return tokenized_full_prompt
+
+    model = prepare_model_for_int8_training(model)
+    config = LoraConfig(
+        r=lora_r,
+        lora_alpha=lora_alpha,
+        target_modules=lora_target_modules,
+        lora_dropout=lora_dropout,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    model = get_peft_model(model, config)
+    if not ddp and torch.cuda.device_count() > 1:
+        model.is_parallelizable = True
+        model.model_parallel = True
+
+    print("The process of federated instruction-tuning has started..")
+    previously_selected_clients_set = set()
+    last_client_id = None
+    local_dataset_len_dict = dict()
+    output_dir = os.path.join(output_dir, str(num_clients))
+
+    for epoch in tqdm(range(num_communication_rounds)):
+
+        print("\nConducting the client selection")
+        selected_clients_set = client_selection(num_clients, client_selection_frac, client_selection_strategy,
+                                                other_info=epoch)
+
+        for client_id in selected_clients_set:
+            client = GeneralClient(client_id, model, data_path, output_dir)
+
+            print("\nPreparing the local dataset and trainer for Client_{}".format(client_id))
+            client.preprare_local_dataset(generate_and_tokenize_prompt, local_val_set_size)
+            client.build_local_trainer(tokenizer,
+                                       local_micro_batch_size,
+                                       gradient_accumulation_steps,
+                                       local_num_epochs,
+                                       local_learning_rate,
+                                       group_by_length,
+                                       ddp)
+
+            print("Initiating the local training of Client_{}".format(client_id))
+            client.initiate_local_training()
+
+            print("Local training starts ... ")
+            client.train()
+
+            print("\nTerminating the local training of Client_{}".format(client_id))
+            model, local_dataset_len_dict, previously_selected_clients_set, last_client_id = client.terminate_local_training(
+                epoch, local_dataset_len_dict, previously_selected_clients_set)
+            del client
+
+        print("Collecting the weights of clients and performing aggregation")
+        model = FedAvg(model,
+                       selected_clients_set,
+                       output_dir,
+                       local_dataset_len_dict,
+                       epoch,
+                       )
+        torch.save(model.state_dict(), os.path.join(output_dir, str(epoch), "adapter_model.bin"))
+        config.save_pretrained(output_dir)
+
+        # Please design the evaluation method based on your specific requirements in the fed_utils/evaluation.py file.
+        global_evaluation()
+
+
+if __name__ == "__main__":
+    fire.Fire(fl_finetune)

File diff suppressed because it is too large
+ 64 - 0
new-databricks-dolly-15k.json


+ 46 - 0
templates/README.md

@@ -0,0 +1,46 @@
+# Prompt templates
+
+This directory contains template styles for the prompts used to finetune LoRA models.
+
+## Format
+
+A template is described via a JSON file with the following keys:
+
+- `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders.
+- `prompt_no_input`: The template to use when input is None. Uses `{instruction}` placeholders.
+- `description`: A short description of the template, with possible use cases.
+- `response_split`: The text to use as separator when cutting real response from the model output.
+
+No `{response}` placeholder was used, since the response is always the last element of the template and is just to be concatenated to the rest.
+
+## Example template
+
+The default template, used unless otherwise specified, is `alpaca.json`
+
+```json
+{
+    "description": "Template used by Alpaca-LoRA.",
+    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
+    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
+    "response_split": "### Response:"    
+}
+
+```
+
+## Current templates
+
+### alpaca
+
+Default template used for generic LoRA fine tunes so far.
+
+### alpaca_legacy
+
+Legacy template used by the original alpaca repo, with no `\n` after the response field. Kept for reference and experiments.
+
+### alpaca_short
+
+A trimmed down alpaca template which seems to perform just as well and spare some tokens. Models created with the default template seem to be queryable by the short tempalte as well. More experiments are welcome.
+
+### vigogne
+
+The default alpaca template, translated to french. This template was used to train the "Vigogne" LoRA and is to be used to query it, or for extra fine tuning.

+ 6 - 0
templates/alpaca.json

@@ -0,0 +1,6 @@
+{
+    "description": "Template used by Alpaca-LoRA.",
+    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
+    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
+    "response_split": "### Response:"    
+}

+ 6 - 0
templates/alpaca_legacy.json

@@ -0,0 +1,6 @@
+{
+    "description": "Legacy template, used by Original Alpaca repository.",
+    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:",
+    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
+    "response_split": "### Response:"    
+}

+ 6 - 0
templates/alpaca_short.json

@@ -0,0 +1,6 @@
+{
+    "description": "A shorter template to experiment with.",
+    "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
+    "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
+    "response_split": "### Response:"    
+}

+ 6 - 0
templates/vigogne.json

@@ -0,0 +1,6 @@
+{
+    "description": "French template, used by Vigogne for finetuning.",
+    "prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n",
+    "prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n",
+    "response_split": "### Réponse:"
+}

+ 13 - 0
utils/README.md

@@ -0,0 +1,13 @@
+# Directory for helpers modules
+
+## prompter.py
+
+Prompter class, a template manager.
+
+`from utils.prompter import Prompter`
+
+## callbacks.py
+
+Helpers to support streaming generate output.
+
+`from utils.callbacks import Iteratorize, Stream`

+ 0 - 0
utils/__init__.py


+ 75 - 0
utils/callbacks.py

@@ -0,0 +1,75 @@
+"""
+Helpers to support streaming generate output.
+Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py
+"""
+
+import gc
+import traceback
+from queue import Queue
+from threading import Thread
+
+import torch
+import transformers
+
+
+class Stream(transformers.StoppingCriteria):
+    def __init__(self, callback_func=None):
+        self.callback_func = callback_func
+
+    def __call__(self, input_ids, scores) -> bool:
+        if self.callback_func is not None:
+            self.callback_func(input_ids[0])
+        return False
+
+
+class Iteratorize:
+
+    """
+    Transforms a function that takes a callback
+    into a lazy iterator (generator).
+    """
+
+    def __init__(self, func, kwargs={}, callback=None):
+        self.mfunc = func
+        self.c_callback = callback
+        self.q = Queue()
+        self.sentinel = object()
+        self.kwargs = kwargs
+        self.stop_now = False
+
+        def _callback(val):
+            if self.stop_now:
+                raise ValueError
+            self.q.put(val)
+
+        def gentask():
+            try:
+                ret = self.mfunc(callback=_callback, **self.kwargs)
+            except ValueError:
+                pass
+            except:
+                traceback.print_exc()
+                pass
+
+            self.q.put(self.sentinel)
+            if self.c_callback:
+                self.c_callback(ret)
+
+        self.thread = Thread(target=gentask)
+        self.thread.start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        obj = self.q.get(True, None)
+        if obj is self.sentinel:
+            raise StopIteration
+        else:
+            return obj
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop_now = True

+ 51 - 0
utils/prompter.py

@@ -0,0 +1,51 @@
+"""
+A dedicated helper to manage templates and prompt building.
+"""
+
+import json
+import os.path as osp
+from typing import Union
+
+
+class Prompter(object):
+    __slots__ = ("template", "_verbose")
+
+    def __init__(self, template_name: str = "", verbose: bool = False):
+        self._verbose = verbose
+        if not template_name:
+            # Enforce the default here, so the constructor can be called with '' and will not break.
+            template_name = "alpaca"
+        file_name = osp.join("templates", f"{template_name}.json")
+        if not osp.exists(file_name):
+            raise ValueError(f"Can't read {file_name}")
+        with open(file_name) as fp:
+            self.template = json.load(fp)
+        if self._verbose:
+            print(
+                f"Using prompt template {template_name}: {self.template['description']}"
+            )
+
+    def generate_prompt(
+        self,
+        instruction: str,
+        input: Union[None, str] = None,
+        label: Union[None, str] = None,
+    ) -> str:
+        # returns the full prompt from instruction and optional input
+        # if a label (=response, =output) is provided, it's also appended.
+        if input:
+            res = self.template["prompt_input"].format(
+                instruction=instruction, input=input
+            )
+        else:
+            res = self.template["prompt_no_input"].format(
+                instruction=instruction
+            )
+        if label:
+            res = f"{res}{label}"
+        if self._verbose:
+            print(res)
+        return res
+
+    def get_response(self, output: str) -> str:
+        return output.split(self.template["response_split"])[1].strip()

Some files were not shown because too many files changed in this diff