input_main.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import json
  2. from langchain.chat_models import ChatOpenAI
  3. from langchain.prompts.chat import (
  4. ChatPromptTemplate,
  5. SystemMessagePromptTemplate,
  6. HumanMessagePromptTemplate,
  7. )
  8. from Evaluate_langchain import evaluate_langchain
  9. from event_langchain import event_langchain
  10. from format_langchain import format_langchain
  11. from global_langchain import global_model
  12. from value_langchain import value_langchain
  13. def input_langchain():
  14. print('[初始化]--------------------------初始化中--------------------------')
  15. chat = ChatOpenAI(temperature=0)
  16. model = global_model()
  17. # output_type = ""
  18. # for key, value in output_declare.items():
  19. # output_type += key + ":" + value + ","
  20. output_type = str(model.intro_declare)
  21. system_template = model.input_system_template
  22. system_message_prompt = SystemMessagePromptTemplate.from_template(
  23. system_template)
  24. human_template = model.input_human_template
  25. human_message_prompt = HumanMessagePromptTemplate.from_template(
  26. human_template)
  27. chat_prompt = ChatPromptTemplate.from_messages(
  28. [system_message_prompt, human_message_prompt])
  29. rsp = chat(
  30. chat_prompt.format_prompt(output_type=output_type,
  31. style=model.style,
  32. story=model.story).to_messages())
  33. print('[初始化]--------------------------GPT得到--------------------------')
  34. print(rsp.content)
  35. try:
  36. game_intro = json.loads(rsp.content)
  37. except:
  38. print('格式错误,修复中')
  39. game_intro = format_langchain(rsp.content, output_type)
  40. print(
  41. '[初始化]--------------------------json.loads得到--------------------------'
  42. )
  43. print(game_intro)
  44. model.input_init(game_intro["故事简介"], game_intro["角色设定"],
  45. game_intro["数值系统"], game_intro["游戏通关所需条件"])
  46. value_langchain()
  47. return game_intro