main.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import gradio as gr
  2. from utils.openai_api import get_reply, get_tokens_length
  3. from utils.read_file import read_xlsx_file, process_data
  4. sessions = {}
  5. def get_session(key):
  6. if key not in sessions:
  7. sessions[key] = {"count_now": 0, "reply": ""}
  8. return sessions[key]
  9. def submit(action, file, max_length, message_first, message_after,
  10. session_key):
  11. session = get_session(session_key)
  12. count_now = session["count_now"]
  13. reply = session["reply"]
  14. if action == "预览消息":
  15. data = read_xlsx_file(file.name)
  16. if count_now == 0:
  17. message, count = process_data(max_length, data, message_first,
  18. len(data))
  19. else:
  20. message, count = process_data(max_length,
  21. data,
  22. message_after,
  23. len(data),
  24. count_now,
  25. reply=reply)
  26. return get_tokens_length(message), message, ""
  27. elif action == "获取最终回复(耗时较长)":
  28. while True:
  29. data = read_xlsx_file(file.name)
  30. if count_now == 0:
  31. message, count = process_data(max_length, data, message_first,
  32. len(data))
  33. else:
  34. message, count = process_data(max_length,
  35. data,
  36. message_after,
  37. len(data),
  38. count_now,
  39. reply=reply)
  40. count_now += count
  41. if count == 0:
  42. break
  43. else:
  44. print(message)
  45. reply = get_reply(message)
  46. print(reply)
  47. session["count_now"] = count_now
  48. session["reply"] = reply
  49. return get_tokens_length(message), "", reply
  50. elif action == "获取一轮回复":
  51. data = read_xlsx_file(file.name)
  52. if count_now == 0:
  53. message, count = process_data(max_length, data, message_first,
  54. len(data))
  55. else:
  56. message, count = process_data(max_length,
  57. data,
  58. message_after,
  59. len(data),
  60. count_now,
  61. reply=reply)
  62. count_now += count
  63. reply = get_reply(message)
  64. session["count_now"] = count_now
  65. session["reply"] = reply
  66. return get_tokens_length(message), message, reply
  67. iface = gr.Interface(
  68. fn=submit,
  69. inputs=[
  70. gr.inputs.Dropdown(choices=["预览消息", "获取最终回复(耗时较长)", "获取一轮回复"],
  71. label="操作",
  72. default="预览消息"),
  73. gr.inputs.File(label="上传 xlsx 文件"),
  74. gr.inputs.Number(default=2048, label="长度限制(最大4096)"),
  75. gr.inputs.Textbox(
  76. lines=2,
  77. label="第一轮预定义消息",
  78. default=
  79. "接下来输入第{comment_num_start}条-第{comment_num_end}条app store中对一款app的评论(共{all_num}条),格式为[(标题,内容),(标题,内容)...],请分条总结这款app的好评论与坏评论(各十条)(并在每条后面按照百分比给出这个观点的在好/坏评论中的当前占比),{data_string}"
  80. ),
  81. gr.inputs.Textbox(
  82. lines=2,
  83. label="后续预定义消息",
  84. default=
  85. "{reply_before},以上是对前{comment_num}条的分析结果,接下来分段输入第{comment_num_start}条-第{comment_num_end}条app store中对一款app的评论(共{all_num}条),格式为[(标题,内容),(标题,内容)...],请分条总结这款app的好评论与坏评论(各十条)(并在每条后面按照百分比给出这个观点的在好/坏评论中的当前占比),{data_string}"
  86. ),
  87. gr.inputs.Textbox(lines=1,
  88. label="session_key",
  89. default="new_session_id(用于记录会话))"),
  90. ],
  91. outputs=[
  92. gr.outputs.Textbox(label="Token长度"),
  93. gr.outputs.Textbox(label="GPT输入的字符串(Message)"),
  94. gr.outputs.Textbox(label="GPT输出的字符串(Reply)"),
  95. ],
  96. layout="vertical",
  97. title="GPT4Comment",
  98. description="上传xlsx文件并输入预定义的消息,然后点击发送使用GPT",
  99. )
  100. iface.launch(debug=True)