chat_module.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import json
  2. import os
  3. from langchain.chat_models import ChatOpenAI
  4. from langchain.prompts.chat import (
  5. ChatPromptTemplate,
  6. SystemMessagePromptTemplate,
  7. HumanMessagePromptTemplate,
  8. )
  9. from utils.configuration import Configuration
  10. from utils.format_error_data import format_error_data
  11. from utils.utils import convert_choice
  12. def chat_module(classname,id,input,prompt,output):
  13. print('--------------------------Chat--------------------------')
  14. chat = ChatOpenAI(temperature=0)
  15. system_template = prompt["system_template"]
  16. system_message_prompt = SystemMessagePromptTemplate.from_template(
  17. system_template)
  18. human_template = prompt["human_template"]
  19. human_message_prompt = HumanMessagePromptTemplate.from_template(
  20. human_template)
  21. chat_prompt = ChatPromptTemplate.from_messages(
  22. [system_message_prompt, human_message_prompt])
  23. input_prompt=chat_prompt.format_prompt(**{key: value for key, value in input.items()}).to_messages()
  24. print("输入prompt为:"+str(input_prompt))
  25. rsp = chat(input_prompt)
  26. print('--------------------------Output--------------------------')
  27. try:
  28. game_event = json.loads(rsp.content, object_hook=convert_choice)
  29. print(game_event)
  30. save_2_json(classname,id,output,game_event)
  31. except:
  32. print('JSON格式解析错误,Chat输出如下,尝试修复')
  33. print(rsp.content)
  34. game_event=format_error_data(rsp.content)
  35. if game_event is not None:
  36. save_2_json(classname,id,output,game_event)
  37. def save_2_json(classname,id,output_type,data):
  38. configuration=Configuration()
  39. file_path="Chat_"+classname+"_output.json"
  40. configuration.save_2_json(file_path,output_type,data)