# GlobalModel_generated.py

import os

import fire
import gradio as gr
import torch
import transformers
from peft import (
    PeftModel,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, AutoTokenizer

from utils.callbacks import Iteratorize, Stream
from utils.prompter import Prompter
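
# Note: prepare_model_for_int8_training is the name used by older peft releases
# (newer releases rename it to prepare_model_for_kbit_training), so this script
# assumes the peft/transformers versions pinned by the repository.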
# Check which device is available: prefer a CUDA GPU, fall back to CPU,
# and use MPS (Apple Silicon) if it is present.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except:
    pass


def main(
    load_8bit: bool = False,
    base_model: str = "",
    lora_weights_path: str = "",
    lora_config_path: str = "",  # provide only the file path, excluding the file name 'adapter_config.json'
    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
    server_name: str = "0.0.0.0",
    share_gradio: bool = False,
):
    # Read the base model name from the CLI argument or the BASE_MODEL environment variable.
    base_model = base_model or os.environ.get("BASE_MODEL", "")
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"

    prompter = Prompter(prompt_template)
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    if not lora_weights_path.endswith(".bin"):
        if device == "cuda":
            model = LlamaForCausalLM.from_pretrained(
                base_model,
                load_in_8bit=load_8bit,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            model = PeftModel.from_pretrained(
                model,
                lora_weights_path,
                torch_dtype=torch.float16,
            )
        elif device == "mps":
            model = LlamaForCausalLM.from_pretrained(
                base_model,
                device_map={"": device},
                torch_dtype=torch.float16,
            )
            model = PeftModel.from_pretrained(
                model,
                lora_weights_path,
                device_map={"": device},
                torch_dtype=torch.float16,
            )
        else:
            model = LlamaForCausalLM.from_pretrained(
                base_model, device_map={"": device}, low_cpu_mem_usage=True
            )
            model = PeftModel.from_pretrained(
                model,
                lora_weights_path,
                device_map={"": device},
            )
    else:
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = prepare_model_for_int8_training(model)
        config = LoraConfig.from_pretrained(lora_config_path)
        lora_weights = torch.load(lora_weights_path)
        model = PeftModel(model, config)
        set_peft_model_state_dict(model, lora_weights, "default")
        del lora_weights

    # unwind broken decapoda-research config
    # Patch the special token ids in the model config.
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    if not load_8bit:
        model.half()  # seems to fix bugs for some users.

    # Put the model into evaluation mode.
    model.eval()

    # Evaluation function: turn the instruction (plus optional input) into a
    # prompt the model understands, then generate and return a response.
    def evaluate(
        instruction,
        input=None,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=4,
        max_new_tokens=128,
        stream_output=True,
        **kwargs,
    ):
        # Build the prompt and convert it into the tensor inputs the model expects.
        prompt = prompter.generate_prompt(instruction, input)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)

        # Configure the generation parameters.
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            do_sample=True,
            **kwargs,
        )

        generate_params = {
            "input_ids": input_ids,
            "generation_config": generation_config,
            "return_dict_in_generate": True,
            "output_scores": True,
            "max_new_tokens": max_new_tokens,
        }

        # If stream_output is True, generate the response and stream it back.
        if stream_output:
            # Stream the reply 1 token at a time.
            # This is based on the trick of using 'stopping_criteria' to create an iterator,
            # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.

            def generate_with_callback(callback=None, **kwargs):
                kwargs.setdefault(
                    "stopping_criteria", transformers.StoppingCriteriaList()
                )
                kwargs["stopping_criteria"].append(
                    Stream(callback_func=callback)
                )
                with torch.no_grad():
                    model.generate(**kwargs)

            def generate_with_streaming(**kwargs):
                return Iteratorize(
                    generate_with_callback, kwargs, callback=None
                )

            # Generate the response through the iterator.
            with generate_with_streaming(**generate_params) as generator:
                for output in generator:
                    # new_tokens = len(output) - len(input_ids[0])
                    decoded_output = tokenizer.decode(output)

                    if output[-1] in [tokenizer.eos_token_id]:
                        break

                    yield prompter.get_response(decoded_output)
            return  # early return for stream_output

        # Without streaming
        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )
        s = generation_output.sequences[0]
        output = tokenizer.decode(s)
        yield prompter.get_response(output)
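
    # Build the Gradio UI; the textboxes, sliders, and checkbox below map directly
    # onto evaluate()'s arguments (instruction, input, temperature, top_p, top_k,
    # num_beams, max_new_tokens, stream_output).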
    sherpherd_UI = gr.Interface(
        fn=evaluate,
        inputs=[
            gr.components.Textbox(
                lines=2,
                label="Instruction",
                placeholder="Tell me about alpacas.",
            ),
            gr.components.Textbox(lines=2, label="Input", placeholder="none"),
            gr.components.Slider(
                minimum=0, maximum=1, value=0.01, label="Temperature"
            ),
            gr.components.Slider(
                minimum=0, maximum=1, value=0.03, label="Top p"
            ),
            gr.components.Slider(
                minimum=0, maximum=100, step=1, value=1, label="Top k"
            ),
            gr.components.Slider(
                minimum=1, maximum=4, step=1, value=1, label="Beams"
            ),
            gr.components.Slider(
                minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
            ),
            gr.components.Checkbox(label="Stream output"),
        ],
        outputs=[
            gr.components.Textbox(lines=5, label="Output")
        ],
        title="FederatedGPT-shepherd",
        description="Shepherd is an LLM that has been fine-tuned in a federated manner",
    ).queue()
    sherpherd_UI.launch(server_name=server_name, share=share_gradio)


if __name__ == "__main__":
    fire.Fire(main)
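
# Example invocation (illustrative only: the model name matches the assertion
# message above, and the LoRA paths are placeholders for your own checkpoint):
#   python GlobalModel_generated.py \
#       --base_model 'huggyllama/llama-7b' \
#       --lora_weights_path ./lora-shepherd/adapter_model.bin \
#       --lora_config_path ./lora-shepherd/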