# main.py — federated LoRA instruction-tuning entry point (run via fire.Fire).
import os
from typing import List

import datasets
import fire
import torch
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
)
from tqdm import tqdm
from transformers import LlamaTokenizer, LlamaForCausalLM

from fed_utils import FedAvg, client_selection, global_evaluation, GeneralClient
from utils.prompter import Prompter
  15. datasets.utils.logging.set_verbosity_error()
  16. def fl_finetune(
  17. # model/data params
  18. global_model: str = '',
  19. data_path: str = './data',
  20. output_dir: str = './lora-shepherd/',
  21. # FL hyperparamas
  22. client_selection_strategy: str = 'random',
  23. client_selection_frac: float = 0.1,
  24. num_communication_rounds: int = 50,
  25. num_clients: int = 10,
  26. # Local training hyperparams
  27. local_batch_size: int = 64, # 64,
  28. local_micro_batch_size: int = 8,
  29. local_num_epochs: int = 10,
  30. local_learning_rate: float = 3e-4,
  31. local_val_set_size: int = 0,
  32. local_save_steps: int = 3,
  33. cutoff_len: int = 512,
  34. # LoRA hyperparams
  35. lora_r: int = 16,
  36. lora_alpha: int = 16,
  37. lora_dropout: float = 0.05,
  38. lora_target_modules: List[str] = [
  39. "q_proj",
  40. ],
  41. # llm hyperparams
  42. train_on_inputs: bool = True,
  43. group_by_length: bool = False,
  44. resume_from_checkpoint: str = None, # either training checkpoint or final adapter
  45. prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
  46. ):
  47. if int(os.environ.get("LOCAL_RANK", 0)) == 0:
  48. print(
  49. f"Federated Finetuning LLM-LoRA with params:\n"
  50. f"global_model: {global_model}\n"
  51. f"data_path: {data_path}\n"
  52. f"output_dir: {output_dir}\n"
  53. f"client_selection_strategy: {client_selection_strategy}\n"
  54. f"client_selection_frac: {client_selection_frac}\n"
  55. f"num_communication_rounds: {num_communication_rounds}\n"
  56. f"num_clients: {num_clients}\n"
  57. f"local_batch_size: {local_batch_size}\n"
  58. f"local_micro_batch_size: {local_micro_batch_size}\n"
  59. f"local_num_epochs: {local_num_epochs}\n"
  60. f"local_learning_rate: {local_learning_rate}\n"
  61. f"local_val_set_size: {local_val_set_size}\n"
  62. f"local_save_steps: {local_save_steps}\n"
  63. f"cutoff_len: {cutoff_len}\n"
  64. f"lora_r: {lora_r}\n"
  65. f"lora_alpha: {lora_alpha}\n"
  66. f"lora_dropout: {lora_dropout}\n"
  67. f"lora_target_modules: {lora_target_modules}\n"
  68. f"train_on_inputs: {train_on_inputs}\n"
  69. f"group_by_length: {group_by_length}\n"
  70. f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
  71. f"prompt template: {prompt_template_name}\n"
  72. )
  73. assert (
  74. global_model
  75. ), "Please specify a --global_model, e.g. --global_modell='decapoda-research/llama-7b-hf'"
  76. data_path = os.path.join(data_path, str(num_clients))
  77. assert (os.path.exists(data_path), "Please generate the data files for each client")
  78. # set up the global model & toknizer
  79. gradient_accumulation_steps = local_batch_size // local_micro_batch_size
  80. prompter = Prompter(prompt_template_name)
  81. device_map = "auto"
  82. world_size = int(os.environ.get("WORLD_SIZE", 1))
  83. ddp = world_size != 1
  84. if ddp:
  85. device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
  86. gradient_accumulation_steps = gradient_accumulation_steps // world_size
  87. model = LlamaForCausalLM.from_pretrained(
  88. global_model,
  89. load_in_8bit=True,
  90. torch_dtype=torch.float16,
  91. device_map=device_map,
  92. )
  93. tokenizer = LlamaTokenizer.from_pretrained(global_model)
  94. tokenizer.pad_token_id = (
  95. 0
  96. )
  97. tokenizer.padding_side = "left"
  98. def tokenize(prompt, add_eos_token=True):
  99. result = tokenizer(
  100. prompt,
  101. truncation=True,
  102. max_length=cutoff_len,
  103. padding=False,
  104. return_tensors=None,
  105. )
  106. if (
  107. result["input_ids"][-1] != tokenizer.eos_token_id
  108. and len(result["input_ids"]) < cutoff_len
  109. and add_eos_token
  110. ):
  111. result["input_ids"].append(tokenizer.eos_token_id)
  112. result["attention_mask"].append(1)
  113. result["labels"] = result["input_ids"].copy()
  114. return result
  115. def generate_and_tokenize_prompt(data_point):
  116. full_prompt = prompter.generate_prompt(
  117. data_point["instruction"],
  118. data_point["context"],
  119. data_point["response"],
  120. )
  121. tokenized_full_prompt = tokenize(full_prompt)
  122. if not train_on_inputs:
  123. user_prompt = prompter.generate_prompt(
  124. data_point["instruction"], data_point["context"]
  125. )
  126. tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
  127. user_prompt_len = len(tokenized_user_prompt["input_ids"])
  128. tokenized_full_prompt["labels"] = [
  129. -100
  130. ] * user_prompt_len + tokenized_full_prompt["labels"][
  131. user_prompt_len:
  132. ] # could be sped up, probably
  133. return tokenized_full_prompt
  134. model = prepare_model_for_int8_training(model)
  135. config = LoraConfig(
  136. r=lora_r,
  137. lora_alpha=lora_alpha,
  138. target_modules=lora_target_modules,
  139. lora_dropout=lora_dropout,
  140. bias="none",
  141. task_type="CAUSAL_LM",
  142. )
  143. model = get_peft_model(model, config)
  144. if not ddp and torch.cuda.device_count() > 1:
  145. model.is_parallelizable = True
  146. model.model_parallel = True
  147. print("The process of federated instruction-tuning has started..")
  148. previously_selected_clients_set = set()
  149. last_client_id = None
  150. local_dataset_len_dict = dict()
  151. output_dir = os.path.join(output_dir, str(num_clients))
  152. for epoch in tqdm(range(num_communication_rounds)):
  153. print("\nConducting the client selection")
  154. selected_clients_set = client_selection(num_clients, client_selection_frac, client_selection_strategy,
  155. other_info=epoch)
  156. for client_id in selected_clients_set:
  157. client = GeneralClient(client_id, model, data_path, output_dir)
  158. print("\nPreparing the local dataset and trainer for Client_{}".format(client_id))
  159. client.preprare_local_dataset(generate_and_tokenize_prompt, local_val_set_size)
  160. client.build_local_trainer(tokenizer,
  161. local_micro_batch_size,
  162. gradient_accumulation_steps,
  163. local_num_epochs,
  164. local_learning_rate,
  165. group_by_length,
  166. ddp)
  167. print("Initiating the local training of Client_{}".format(client_id))
  168. client.initiate_local_training()
  169. print("Local training starts ... ")
  170. client.train()
  171. print("\nTerminating the local training of Client_{}".format(client_id))
  172. model, local_dataset_len_dict, previously_selected_clients_set, last_client_id = client.terminate_local_training(
  173. epoch, local_dataset_len_dict, previously_selected_clients_set)
  174. del client
  175. print("Collecting the weights of clients and performing aggregation")
  176. model = FedAvg(model,
  177. selected_clients_set,
  178. output_dir,
  179. local_dataset_len_dict,
  180. epoch,
  181. )
  182. torch.save(model.state_dict(), os.path.join(output_dir, str(epoch), "adapter_model.bin"))
  183. config.save_pretrained(output_dir)
  184. # Please design the evaluation method based on your specific requirements in the fed_utils/evaluation.py file.
  185. global_evaluation()
  186. if __name__ == "__main__":
  187. fire.Fire(fl_finetune)