import os
from typing import List

import datasets
import fire
import torch
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
)
from tqdm import tqdm
from transformers import LlamaForCausalLM, LlamaTokenizer

from fed_utils import FedAvg, GeneralClient, client_selection, global_evaluation
from utils.prompter import Prompter

# Silence per-map/per-load progress logging from the `datasets` library so the
# federated-round progress output below stays readable.
datasets.utils.logging.set_verbosity_error()


# The federated-learning fine-tuning entry point is defined below.
def fl_finetune(
        # model/data params
        global_model: str = '',
        data_path: str = './data',
        output_dir: str = './lora-shepherd/',
        # FL hyperparams
        client_selection_strategy: str = 'random',
        client_selection_frac: float = 0.1,
        num_communication_rounds: int = 50,
        num_clients: int = 10,
        # local training hyperparams
        local_batch_size: int = 64,
        local_micro_batch_size: int = 8,
        local_num_epochs: int = 10,
        local_learning_rate: float = 3e-4,
        local_val_set_size: int = 0,
        local_save_steps: int = 3,
        cutoff_len: int = 512,
        # LoRA hyperparams
        lora_r: int = 16,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        lora_target_modules: List[str] = [
            "q_proj",
        ],
        # llm hyperparams
        train_on_inputs: bool = True,
        group_by_length: bool = False,
        resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
        prompt_template_name: str = "alpaca",  # the prompt template to use, defaults to alpaca
):
    """Federated instruction-tuning of a LLaMA causal LM with LoRA adapters.

    Each communication round selects a fraction of clients, lets every
    selected client fine-tune the shared LoRA adapter on its local data
    shard, aggregates the client adapters with FedAvg, and saves the
    aggregated adapter to ``<output_dir>/<num_clients>/<round>/adapter_model.bin``.

    Args:
        global_model: HF model id or local path of the base LLaMA model (required).
        data_path: Root data directory; client shards are expected under
            ``<data_path>/<num_clients>/``.
        output_dir: Root directory for per-round aggregated adapters.
        client_selection_strategy: Strategy name forwarded to ``client_selection``.
        client_selection_frac: Fraction of clients selected each round.
        num_communication_rounds: Number of FedAvg rounds.
        num_clients: Total number of federated clients.
        local_batch_size: Effective local batch size (micro batch x accumulation).
        local_micro_batch_size: Per-step micro batch size.
        local_num_epochs: Epochs each client trains per round.
        local_learning_rate: Learning rate for local training.
        local_val_set_size: Size of each client's validation split.
        local_save_steps: Checkpoint interval.
            NOTE(review): not used anywhere in this function — confirm whether
            ``GeneralClient`` was meant to receive it.
        cutoff_len: Maximum tokenized sequence length.
        lora_r: LoRA rank.
        lora_alpha: LoRA scaling alpha.
        lora_dropout: LoRA dropout probability.
        lora_target_modules: Module names to wrap with LoRA (never mutated here,
            so the mutable default is harmless).
        train_on_inputs: If False, the prompt portion of each example is masked
            out of the loss with -100 labels.
        group_by_length: Forwarded to the local trainer.
        resume_from_checkpoint: Currently only echoed in the startup banner.
        prompt_template_name: Template name for ``Prompter``.
    """
    # Print the full run configuration once (rank 0 only under DDP).
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Federated Finetuning LLM-LoRA with params:\n"
            f"global_model: {global_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"client_selection_strategy: {client_selection_strategy}\n"
            f"client_selection_frac: {client_selection_frac}\n"
            f"num_communication_rounds: {num_communication_rounds}\n"
            f"num_clients: {num_clients}\n"
            f"local_batch_size: {local_batch_size}\n"
            f"local_micro_batch_size: {local_micro_batch_size}\n"
            f"local_num_epochs: {local_num_epochs}\n"
            f"local_learning_rate: {local_learning_rate}\n"
            f"local_val_set_size: {local_val_set_size}\n"
            f"local_save_steps: {local_save_steps}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"group_by_length: {group_by_length}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    # (message fixed: the example flag was misspelled "--global_modell")
    assert global_model, \
        "Please specify a --global_model, e.g. --global_model='decapoda-research/llama-7b-hf'"

    # Client shards live under <data_path>/<num_clients>/.
    data_path = os.path.join(data_path, str(num_clients))
    # BUG FIX: the original wrote `assert (cond, msg)`, which asserts a
    # non-empty tuple and is therefore always true — the existence check
    # never fired.
    assert os.path.exists(data_path), "Please generate the data files for each client"

    # Set up the global model & tokenizer.
    gradient_accumulation_steps = local_batch_size // local_micro_batch_size
    prompter = Prompter(prompt_template_name)
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        # Under DDP, pin the whole model to this rank's device and split
        # gradient accumulation across ranks.
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    # Load the base model in 8-bit / fp16 to cut GPU memory usage.
    model = LlamaForCausalLM.from_pretrained(
        global_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )

    tokenizer = LlamaTokenizer.from_pretrained(global_model)
    # Use token id 0 for padding — presumably LLaMA's unk token, chosen so
    # padding is distinct from eos; confirm against the tokenizer's vocab.
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    def tokenize(prompt, add_eos_token=True):
        # Tokenize without padding; batching/padding is handled downstream.
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        # Append eos unless truncation already used the full length budget.
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        # Causal-LM training: labels mirror the input ids.
        result["labels"] = result["input_ids"].copy()
        return result

    def generate_and_tokenize_prompt(data_point):
        # Build and tokenize the full prompt (instruction + context + response).
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["context"],
            data_point["response"],
        )
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            # Mask the prompt portion with -100 so only the response tokens
            # contribute to the loss.
            user_prompt = prompter.generate_prompt(
                data_point["instruction"], data_point["context"]
            )
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = (
                [-100] * user_prompt_len
                + tokenized_full_prompt["labels"][user_prompt_len:]
            )  # could be sped up, probably
        return tokenized_full_prompt

    # Prepare the 8-bit model for training and wrap it with the LoRA adapter.
    model = prepare_model_for_int8_training(model)
    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    # Without DDP but with several GPUs, fall back to naive model parallelism.
    if not ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True

    print("The process of federated instruction-tuning has started..")
    previously_selected_clients_set = set()
    last_client_id = None
    local_dataset_len_dict = dict()
    output_dir = os.path.join(output_dir, str(num_clients))

    # One iteration per communication round.
    for epoch in tqdm(range(num_communication_rounds)):

        print("\nConducting the client selection")
        selected_clients_set = client_selection(num_clients, client_selection_frac, client_selection_strategy,
                                                other_info=epoch)

        for client_id in selected_clients_set:
            client = GeneralClient(client_id, model, data_path, output_dir)

            print("\nPreparing the local dataset and trainer for Client_{}".format(client_id))
            # NOTE(review): "preprare" is a typo in GeneralClient's public API;
            # kept as-is to match the project code.
            client.preprare_local_dataset(generate_and_tokenize_prompt, local_val_set_size)
            client.build_local_trainer(tokenizer,
                                       local_micro_batch_size,
                                       gradient_accumulation_steps,
                                       local_num_epochs,
                                       local_learning_rate,
                                       group_by_length,
                                       ddp)

            print("Initiating the local training of Client_{}".format(client_id))
            client.initiate_local_training()

            print("Local training starts ... ")
            client.train()

            print("\nTerminating the local training of Client_{}".format(client_id))
            model, local_dataset_len_dict, previously_selected_clients_set, last_client_id = client.terminate_local_training(
                epoch, local_dataset_len_dict, previously_selected_clients_set)
            del client

        # Collect the client adapter weights and aggregate them with FedAvg.
        print("Collecting the weights of clients and performing aggregation")
        model = FedAvg(model,
                       selected_clients_set,
                       output_dir,
                       local_dataset_len_dict,
                       epoch,
                       )
        # Persist the aggregated adapter state for this round.
        torch.save(model.state_dict(), os.path.join(output_dir, str(epoch), "adapter_model.bin"))
        config.save_pretrained(output_dir)

        # Please design the evaluation method based on your specific
        # requirements in the fed_utils/evaluation.py file.
        global_evaluation()
if __name__ == "__main__":
    # Expose every keyword argument of fl_finetune as a CLI flag via python-fire,
    # e.g. `python main.py --global_model=... --num_clients=10`.
    fire.Fire(fl_finetune)