"""Command-line entry point for federated LoRA instruction-tuning of a LLaMA model."""

import os
from typing import List, Optional

import datasets
import fire
import torch
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
)
from tqdm import tqdm
from transformers import LlamaForCausalLM, LlamaTokenizer

from fed_utils import FedAvg, GeneralClient, client_selection, global_evaluation
from utils.prompter import Prompter

datasets.utils.logging.set_verbosity_error()


# Federated-learning fine-tuning routine.
def fl_finetune(
    # model/data params
    global_model: str = "",
    data_path: str = "./data",
    output_dir: str = "./lora-shepherd/",
    # FL hyperparams
    client_selection_strategy: str = "random",
    client_selection_frac: float = 0.1,
    num_communication_rounds: int = 50,
    num_clients: int = 10,
    # Local training hyperparams
    local_batch_size: int = 64,
    local_micro_batch_size: int = 8,
    local_num_epochs: int = 10,
    local_learning_rate: float = 3e-4,
    local_val_set_size: int = 0,
    local_save_steps: int = 3,
    cutoff_len: int = 512,
    # LoRA hyperparams
    lora_r: int = 16,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: Optional[List[str]] = None,
    # llm hyperparams
    train_on_inputs: bool = True,
    group_by_length: bool = False,
    resume_from_checkpoint: Optional[str] = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
):
    """Run federated LoRA fine-tuning of a LLaMA causal language model.

    Each communication round selects a subset of clients, trains the shared
    PEFT model on every selected client in turn, aggregates the client
    results with ``FedAvg`` and saves the aggregated state dict under
    ``<output_dir>/<num_clients>/<round>/adapter_model.bin``.

    Args:
        global_model: Name or path passed to ``from_pretrained`` for both the
            model and the tokenizer. Must be non-empty.
        data_path: Root data directory; the directory actually used is
            ``<data_path>/<num_clients>``.
        output_dir: Root output directory; the directory actually used is
            ``<output_dir>/<num_clients>``.
        client_selection_strategy: Forwarded to ``client_selection``.
        client_selection_frac: Forwarded to ``client_selection``.
        num_communication_rounds: Number of federated rounds to run.
        num_clients: Number of clients; also the name of the data and output
            sub-directories.
        local_batch_size: Per-client batch size; divided by
            ``local_micro_batch_size`` to get the gradient accumulation steps.
        local_micro_batch_size: Micro batch size forwarded to the local trainer.
        local_num_epochs: Local epochs forwarded to the local trainer.
        local_learning_rate: Learning rate forwarded to the local trainer.
        local_val_set_size: Forwarded to the client's dataset preparation.
        local_save_steps: Only echoed in the startup banner by this function.
        cutoff_len: Maximum token length used when tokenizing prompts.
        lora_r: LoRA ``r`` passed to ``LoraConfig``.
        lora_alpha: LoRA ``lora_alpha`` passed to ``LoraConfig``.
        lora_dropout: LoRA ``lora_dropout`` passed to ``LoraConfig``.
        lora_target_modules: Module names passed to ``LoraConfig`` as
            ``target_modules``. ``None`` (the default) means ``["q_proj"]``.
        train_on_inputs: When false, label positions covering the prompt
            (instruction + context) are set to ``-100``.
        group_by_length: Forwarded to the local trainer.
        resume_from_checkpoint: Only echoed in the startup banner by this
            function.
        prompt_template_name: Template name passed to ``Prompter``.

    Raises:
        AssertionError: If ``global_model`` is empty.
        FileNotFoundError: If ``<data_path>/<num_clients>`` does not exist.
    """
    # Resolve the default here so no list object is shared between calls.
    if lora_target_modules is None:
        lora_target_modules = ["q_proj"]

    # Only the rank-0 process prints the configuration banner.
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Federated Finetuning LLM-LoRA with params:\n"
            f"global_model: {global_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"client_selection_strategy: {client_selection_strategy}\n"
            f"client_selection_frac: {client_selection_frac}\n"
            f"num_communication_rounds: {num_communication_rounds}\n"
            f"num_clients: {num_clients}\n"
            f"local_batch_size: {local_batch_size}\n"
            f"local_micro_batch_size: {local_micro_batch_size}\n"
            f"local_num_epochs: {local_num_epochs}\n"
            f"local_learning_rate: {local_learning_rate}\n"
            f"local_val_set_size: {local_val_set_size}\n"
            f"local_save_steps: {local_save_steps}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"group_by_length: {group_by_length}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    # Kept as an assert so the exception type raised for a missing model name
    # is unchanged; only the misspelled flag in the message was corrected.
    assert (
        global_model
    ), "Please specify a --global_model, e.g. --global_model='decapoda-research/llama-7b-hf'"

    data_path = os.path.join(data_path, str(num_clients))
    # The previous `assert (cond, msg)` tested a non-empty tuple and therefore
    # could never fail; check the path explicitly instead.
    if not os.path.exists(data_path):
        raise FileNotFoundError("Please generate the data files for each client")

    # Set up the global model & tokenizer.
    gradient_accumulation_steps = local_batch_size // local_micro_batch_size
    prompter = Prompter(prompt_template_name)
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        # Under DDP each process maps the whole model onto its own local rank
        # and the accumulation steps are split across processes.
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        # NOTE(review): this becomes 0 when world_size exceeds the
        # single-process accumulation steps — confirm that cannot happen.
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    # Load the pretrained model.
    model = LlamaForCausalLM.from_pretrained(
        global_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )

    # Load the matching tokenizer.
    tokenizer = LlamaTokenizer.from_pretrained(global_model)
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    def tokenize(prompt, add_eos_token=True):
        """Tokenize ``prompt`` (truncated to ``cutoff_len``) and add ``labels``.

        An EOS token is appended when requested, when the sequence does not
        already end with one and when there is still room below ``cutoff_len``.
        ``labels`` is a copy of ``input_ids``.
        """
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()
        return result

    def generate_and_tokenize_prompt(data_point):
        """Build the full prompt for ``data_point`` and return its tokenization.

        Reads the ``instruction``, ``context`` and ``response`` keys. When
        ``train_on_inputs`` is false, the labels covering the prompt without
        the response are replaced by ``-100``.
        """
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["context"],
            data_point["response"],
        )
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = prompter.generate_prompt(
                data_point["instruction"], data_point["context"]
            )
            # No EOS here so the prompt length lines up with the prefix of the
            # full prompt's tokenization.
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt

    # Prepare the model for int8 training.
    model = prepare_model_for_int8_training(model)
    # Build the LoRA configuration.
    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    # Wrap the base model as a PEFT model.
    model = get_peft_model(model, config)
    # Single process with several GPUs: flag the model as model-parallel.
    if not ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True

    # Start the federated training process.
    print("The process of federated instruction-tuning has started..")
    previously_selected_clients_set = set()
    last_client_id = None
    local_dataset_len_dict = dict()
    output_dir = os.path.join(output_dir, str(num_clients))

    # Run the communication rounds.
    for epoch in tqdm(range(num_communication_rounds)):

        print("\nConducting the client selection")
        selected_clients_set = client_selection(
            num_clients,
            client_selection_frac,
            client_selection_strategy,
            other_info=epoch,
        )

        for client_id in selected_clients_set:
            client = GeneralClient(client_id, model, data_path, output_dir)

            print("\nPreparing the local dataset and trainer for Client_{}".format(client_id))
            client.preprare_local_dataset(generate_and_tokenize_prompt, local_val_set_size)
            client.build_local_trainer(
                tokenizer,
                local_micro_batch_size,
                gradient_accumulation_steps,
                local_num_epochs,
                local_learning_rate,
                group_by_length,
                ddp,
            )

            print("Initiating the local training of Client_{}".format(client_id))
            client.initiate_local_training()

            print("Local training starts ... ")
            client.train()

            print("\nTerminating the local training of Client_{}".format(client_id))
            (
                model,
                local_dataset_len_dict,
                previously_selected_clients_set,
                last_client_id,
            ) = client.terminate_local_training(
                epoch, local_dataset_len_dict, previously_selected_clients_set
            )
            del client

        # Collect the client weights and aggregate them.
        print("Collecting the weights of clients and performing aggregation")
        model = FedAvg(
            model,
            selected_clients_set,
            output_dir,
            local_dataset_len_dict,
            epoch,
        )
        # Save the aggregated model state for this round; make sure the round
        # directory exists rather than relying on a helper having created it.
        round_dir = os.path.join(output_dir, str(epoch))
        os.makedirs(round_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(round_dir, "adapter_model.bin"))
        config.save_pretrained(output_dir)

        # Please design the evaluation method based on your specific requirements in the fed_utils/evaluation.py file.
        global_evaluation()


if __name__ == "__main__":
    fire.Fire(fl_finetune)