main.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import gradio as gr
  2. from utils.openai_api import get_reply, get_tokens_length
  3. from utils.read_file import read_xlsx_file, process_data
  4. sessions = {}
  5. def get_session(key):
  6. if key not in sessions:
  7. sessions[key] = {"count_now": 0, "reply": ""}
  8. return sessions[key]
  9. def submit(action, data_source, file, max_length, message_first, message_after,
  10. session_key):
  11. session = get_session(session_key)
  12. count_now = session["count_now"]
  13. reply = session["reply"]
  14. if action == "预览消息":
  15. data = read_xlsx_file(data_source, file.name)
  16. if count_now == 0:
  17. message, count = process_data(data_source, max_length, data,
  18. message_first, len(data))
  19. else:
  20. message, count = process_data(data_source,
  21. max_length,
  22. data,
  23. message_after,
  24. len(data),
  25. count_now,
  26. reply=reply)
  27. return get_tokens_length(message), message, ""
  28. elif action == "获取最终回复(耗时较长)":
  29. while True:
  30. data = read_xlsx_file(data_source, file.name)
  31. if count_now == 0:
  32. message, count = process_data(data_source, max_length, data,
  33. message_first, len(data))
  34. else:
  35. message, count = process_data(data_source,
  36. max_length,
  37. data,
  38. message_after,
  39. len(data),
  40. count_now,
  41. reply=reply)
  42. count_now += count
  43. if count == 0:
  44. break
  45. else:
  46. print(message)
  47. reply = get_reply(message)
  48. print(reply)
  49. session["count_now"] = count_now
  50. session["reply"] = reply
  51. return get_tokens_length(message), "", reply
  52. elif action == "获取一轮回复":
  53. data = read_xlsx_file(data_source, file.name)
  54. if count_now == 0:
  55. message, count = process_data(data_source, max_length, data,
  56. message_first, len(data))
  57. else:
  58. message, count = process_data(data_source,
  59. max_length,
  60. data,
  61. message_after,
  62. len(data),
  63. count_now,
  64. reply=reply)
  65. count_now += count
  66. reply = get_reply(message)
  67. session["count_now"] = count_now
  68. session["reply"] = reply
  69. return get_tokens_length(message), message, reply
  70. iface = gr.Interface(
  71. fn=submit,
  72. inputs=[
  73. gr.inputs.Dropdown(choices=["预览消息", "获取最终回复(耗时较长)", "获取一轮回复"],
  74. label="操作",
  75. default="预览消息"),
  76. gr.inputs.Dropdown(choices=["七麦", "点点"], label="数据来源", default="七麦"),
  77. gr.inputs.File(label="上传 xlsx 文件"),
  78. gr.inputs.Number(default=2048, label="长度限制(中文建议最大4096)"),
  79. gr.inputs.Textbox(
  80. lines=2,
  81. label="第一轮预定义消息",
  82. default=
  83. "接下来输入第{comment_num_start}条-第{comment_num_end}条app store中对一款app的评论(共{all_num}条),请分条总结这款app的好评论与坏评论(各十条)(并在每条后面按照百分比给出这个观点的在好/坏评论中的当前占比),{data_string}"
  84. ),
  85. gr.inputs.Textbox(
  86. lines=2,
  87. label="后续预定义消息",
  88. default=
  89. "{reply_before},以上是对前{comment_num}条的分析结果,接下来分段输入第{comment_num_start}条-第{comment_num_end}条app store中对一款app的评论(共{all_num}条),请分条总结这款app的好评论与坏评论(各十条)(并在每条后面按照百分比给出这个观点的在好/坏评论中的当前占比),{data_string}"
  90. ),
  91. gr.inputs.Textbox(lines=1,
  92. label="session_key",
  93. default="new_session_id(用于记录会话))"),
  94. ],
  95. outputs=[
  96. gr.outputs.Textbox(label="Token长度"),
  97. gr.outputs.Textbox(label="GPT输入的字符串(Message)"),
  98. gr.outputs.Textbox(label="GPT输出的字符串(Reply)"),
  99. ],
  100. layout="vertical",
  101. title="GPT4Comment",
  102. description="上传xlsx文件并输入预定义的消息,然后点击发送使用GPT",
  103. )
  104. iface.launch(debug=True)
  105. # iface.launch(server_name="0.0.0.0", server_port=7881)