main.py

import os
from typing import List

from tqdm import tqdm
import fire
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
)
from fed_utils import FedAvg, client_selection, global_evaluation, GeneralClient
import datasets
from utils.prompter import Prompter

datasets.utils.logging.set_verbosity_error()


# Define the federated-learning fine-tuning function
def fl_finetune(
        # model/data params
        global_model: str = '',
        data_path: str = './data',
        output_dir: str = './lora-shepherd/',
        # FL hyperparams
        client_selection_strategy: str = 'random',
        client_selection_frac: float = 0.1,
        num_communication_rounds: int = 50,
        num_clients: int = 10,
        # Local training hyperparams
        local_batch_size: int = 64,
        local_micro_batch_size: int = 8,
        local_num_epochs: int = 10,
        local_learning_rate: float = 3e-4,
        local_val_set_size: int = 0,
        local_save_steps: int = 3,
        cutoff_len: int = 512,
        # LoRA hyperparams
        lora_r: int = 16,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
        lora_target_modules: List[str] = [
            "q_proj",
        ],
        # llm hyperparams
        train_on_inputs: bool = True,
        group_by_length: bool = False,
        resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
        prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
):
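    """Federated instruction-tuning of a LLaMA model with LoRA adapters.

    Each communication round selects a subset of clients, fine-tunes the shared
    LoRA adapter on every selected client's local data, and aggregates the
    resulting adapters with FedAvg.
    """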
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Federated Finetuning LLM-LoRA with params:\n"
            f"global_model: {global_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"client_selection_strategy: {client_selection_strategy}\n"
            f"client_selection_frac: {client_selection_frac}\n"
            f"num_communication_rounds: {num_communication_rounds}\n"
            f"num_clients: {num_clients}\n"
            f"local_batch_size: {local_batch_size}\n"
            f"local_micro_batch_size: {local_micro_batch_size}\n"
            f"local_num_epochs: {local_num_epochs}\n"
            f"local_learning_rate: {local_learning_rate}\n"
            f"local_val_set_size: {local_val_set_size}\n"
            f"local_save_steps: {local_save_steps}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"group_by_length: {group_by_length}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    assert global_model, (
        "Please specify a --global_model, e.g. --global_model='decapoda-research/llama-7b-hf'"
    )

    data_path = os.path.join(data_path, str(num_clients))
    assert os.path.exists(data_path), "Please generate the data files for each client"
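    # Client shards are read from data_path/<num_clients>/, so that directory must be
    # prepared (one data file per client) before launching federated fine-tuning.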

    # Set up the global model & tokenizer
    gradient_accumulation_steps = local_batch_size // local_micro_batch_size
    prompter = Prompter(prompt_template_name)
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size
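    # Effective batch size per optimizer step is local_micro_batch_size *
    # gradient_accumulation_steps (times world_size under DDP), i.e. roughly local_batch_size.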

    # Load the base model from the pretrained checkpoint
    model = LlamaForCausalLM.from_pretrained(
        global_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )

    # Load the tokenizer that matches the pretrained model
    tokenizer = LlamaTokenizer.from_pretrained(global_model)
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"
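    # Token id 0 (LLaMA's unk token) is used for padding so it differs from the eos token,
    # and left padding is the usual choice for batched decoder-only models.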

    def tokenize(prompt, add_eos_token=True):
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result
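    # tokenize() truncates to cutoff_len, appends eos when there is room (and requested),
    # and mirrors input_ids into labels for standard causal-LM training.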

    def generate_and_tokenize_prompt(data_point):
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["context"],
            data_point["response"],
        )
        tokenized_full_prompt = tokenize(full_prompt)
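        # Optionally mask the prompt portion with -100 (the ignore_index of the
        # cross-entropy loss) so that only response tokens contribute to the loss.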
        if not train_on_inputs:
            user_prompt = prompter.generate_prompt(
                data_point["instruction"], data_point["context"]
            )
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = (
                [-100] * user_prompt_len
                + tokenized_full_prompt["labels"][user_prompt_len:]
            )  # could be sped up, probably
        return tokenized_full_prompt

    # Prepare the 8-bit model for training
    model = prepare_model_for_int8_training(model)
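    # prepare_model_for_int8_training freezes the quantized base weights and casts a few
    # layers (e.g. layer norms) to fp32 for stable training; the exact behavior depends on
    # the installed peft version.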

    # Set up the LoRA configuration
    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # Wrap the base model with the PEFT (LoRA) adapters
    model = get_peft_model(model, config)
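    # From here on only the LoRA adapter weights are trainable; the frozen 8-bit base model
    # is shared across rounds. model.print_trainable_parameters() can be used to verify this.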
  151. # 判断是否使用数据并行
  152. if not ddp and torch.cuda.device_count() > 1:
  153. model.is_parallelizable = True
  154. model.model_parallel = True

    # Start the federated training process
    print("The process of federated instruction-tuning has started..")

    previously_selected_clients_set = set()
    last_client_id = None
    local_dataset_len_dict = dict()
    output_dir = os.path.join(output_dir, str(num_clients))
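    # Each communication round below: (1) select a fraction of clients, (2) fine-tune the
    # shared LoRA adapter locally on every selected client, (3) aggregate the client
    # adapters with FedAvg (weighted by local dataset sizes), (4) save and evaluate.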

    # Run the communication rounds of federated training
    for epoch in tqdm(range(num_communication_rounds)):

        print("\nConducting the client selection")
        selected_clients_set = client_selection(num_clients, client_selection_frac, client_selection_strategy,
                                                other_info=epoch)

        for client_id in selected_clients_set:
            client = GeneralClient(client_id, model, data_path, output_dir)

            print("\nPreparing the local dataset and trainer for Client_{}".format(client_id))
            client.preprare_local_dataset(generate_and_tokenize_prompt, local_val_set_size)
            client.build_local_trainer(tokenizer,
                                       local_micro_batch_size,
                                       gradient_accumulation_steps,
                                       local_num_epochs,
                                       local_learning_rate,
                                       group_by_length,
                                       ddp)

            print("Initiating the local training of Client_{}".format(client_id))
            client.initiate_local_training()

            print("Local training starts ... ")
            client.train()

            print("\nTerminating the local training of Client_{}".format(client_id))
            model, local_dataset_len_dict, previously_selected_clients_set, last_client_id = client.terminate_local_training(
                epoch, local_dataset_len_dict, previously_selected_clients_set)
            del client

        # Collect the client weights and aggregate them
        print("Collecting the weights of clients and performing aggregation")
        model = FedAvg(model,
                       selected_clients_set,
                       output_dir,
                       local_dataset_len_dict,
                       epoch,
                       )

        # Save the aggregated model state for this round
        torch.save(model.state_dict(), os.path.join(output_dir, str(epoch), "adapter_model.bin"))
        config.save_pretrained(output_dir)

        # Please design the evaluation method based on your specific requirements in the fed_utils/evaluation.py file.
        global_evaluation()


if __name__ == "__main__":
    fire.Fire(fl_finetune)
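
# Example launch (model name and paths are illustrative):
#   python main.py --global_model 'decapoda-research/llama-7b-hf' --data_path './data' \
#       --num_clients 10 --num_communication_rounds 50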