format_langchain.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import json
  2. from langchain.chat_models import ChatOpenAI
  3. from langchain.prompts.chat import (
  4. ChatPromptTemplate,
  5. SystemMessagePromptTemplate,
  6. HumanMessagePromptTemplate,
  7. )
  8. from global_langchain import global_model
  9. from utils.utils import convert_choice
  def format_langchain(text, output_type, max_attempts=3):
      """Ask the chat model to reformat ``text`` and parse its reply as JSON.

      The system and human prompt templates are read from ``global_model()``
      (``format_system_template`` / ``format_human_template``) and are filled
      with the ``output_type`` and ``text`` variables. The model reply is
      decoded with ``json.loads`` using ``convert_choice`` as ``object_hook``.

      If decoding fails, the model is queried again, up to ``max_attempts``
      queries in total, so a model that keeps returning unparseable output
      cannot trigger an unbounded series of API calls.

      Args:
          text: Value substituted into the prompt's ``text`` variable.
          output_type: Value substituted into the prompt's ``output_type``
              variable.
          max_attempts: Maximum number of model queries before giving up;
              must be >= 1.

      Returns:
          The object produced by
          ``json.loads(reply, object_hook=convert_choice)``.

      Raises:
          ValueError: If ``max_attempts`` < 1, or if no reply could be parsed
              within ``max_attempts`` queries. In the latter case the last
              parsing error is chained as ``__cause__``.
      """
      if max_attempts < 1:
          raise ValueError(f"max_attempts must be >= 1, got {max_attempts!r}")

      print('[修复]--------------------------修复格式中--------------------------')
      # The client and the formatted prompt do not change between attempts,
      # so build them once instead of once per retry.
      chat = ChatOpenAI(temperature=0)
      model = global_model()
      system_message_prompt = SystemMessagePromptTemplate.from_template(
          model.format_system_template)
      human_message_prompt = HumanMessagePromptTemplate.from_template(
          model.format_human_template)
      chat_prompt = ChatPromptTemplate.from_messages(
          [system_message_prompt, human_message_prompt])
      messages = chat_prompt.format_prompt(
          output_type=output_type, text=text).to_messages()

      last_error = None
      for _ in range(max_attempts):
          rsp = chat(messages)
          print('[修复]--------------------------GPT得到--------------------------')
          print(rsp.content)
          try:
              result = json.loads(rsp.content, object_hook=convert_choice)
          except Exception as err:
              # Both JSON decode errors and anything raised by the
              # convert_choice hook lead to a retry. KeyboardInterrupt and
              # SystemExit are not Exception subclasses, so they propagate.
              # With temperature=0 a retry may well reproduce the same reply,
              # which is why the number of attempts is bounded.
              last_error = err
              continue
          print(
              '[修复]--------------------------json.loads得到--------------------------')
          print(result)
          return result

      raise ValueError(
          f"format_langchain: model reply could not be parsed as JSON after "
          f"{max_attempts} attempt(s)") from last_error