card_handler.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. import json
  2. import threading
  3. from concurrent.futures import ThreadPoolExecutor
  4. from api.img2img import img_2_img
  5. from api.segment import get_segment_mask
  6. from utils.config_io import load_config_from_json, save_config_to_json
  7. from utils.image_io import get_img_file, get_img_key_from_input_images
  8. from utils.request_api import reply_message, update_message, upload_image
  9. from utils.logger import logger
  10. from utils.template_io import load_template, save_process_card, update_process_card
  11. ACTION_TO_TITLE = {
  12. 'clothes': '换装',
  13. 'furry': 'Furry头像'
  14. }
  15. def handle_card(card_data):
  16. reply_message_id = card_data['open_message_id']
  17. action = card_data['action']
  18. if action['tag'] == 'select_static':
  19. # 用户进行选择
  20. config_data = load_config_from_json(reply_message_id)
  21. config_data[action['value']['type']] = action['option']
  22. save_config_to_json(config_data, reply_message_id)
  23. if config_data['action'] != "" and config_data['lora'] != "":
  24. response_data = choose_updated_card(reply_message_id, config_data)
  25. return response_data
  26. else:
  27. # 用户点击开始
  28. # TODO: 保存任务状态实现断点继续
  29. if action['value']['key'] == 'start':
  30. config_data = load_config_from_json(reply_message_id)
  31. if config_data['action'] != "" and config_data['lora'] != "":
  32. response_data = waiting_card(reply_message_id, config_data)
  33. executor = ThreadPoolExecutor()
  34. executor.submit(start_sd_process, reply_message_id, config_data)
  35. return response_data
  36. else:
  37. logger.info('用户点击了上传中按钮')
  38. def start_sd_process(reply_message_id, config_data):
  39. if config_data['action'] == 'clothes':
  40. # 获取segment mask
  41. start_clothes_img_theards(reply_message_id)
  42. else:
  43. pass
  44. # 开始换装的每个图片的线程
  45. def start_clothes_img_theards(reply_message_id):
  46. img_keys = get_img_key_from_input_images(reply_message_id)
  47. config_data = load_config_from_json(reply_message_id)
  48. config_data["img_missions"] = {}
  49. for key in img_keys:
  50. config_data["img_missions"][key] = "doing"
  51. save_config_to_json(config_data, reply_message_id)
  52. threads = []
  53. for img_key in img_keys:
  54. logger.info("开始clothes处理线程: "+img_key)
  55. # img_clothes(reply_message_id, img_key)
  56. t = threading.Thread(target=img_clothes, args=(reply_message_id, img_key))
  57. threads.append(t)
  58. t.start()
  59. for t in threads:
  60. t.join()
  61. def mark_img_fail(reply_message_id, img_key):
  62. config_data = load_config_from_json(reply_message_id)
  63. config_data["img_missions"][img_key] = "fail"
  64. save_config_to_json(config_data, reply_message_id)
  65. def mark_img_done(reply_message_id, img_key):
  66. config_data = load_config_from_json(reply_message_id)
  67. config_data["img_missions"][img_key] = "done"
  68. save_config_to_json(config_data, reply_message_id)
  69. # 对单个图片做换装操作
  70. def img_clothes(reply_message_id, img_key):
  71. img_path = f"data/input_images/{reply_message_id}/{img_key}/{img_key}.jpg"
  72. output_directory = f"data/input_images/{reply_message_id}/{img_key}"
  73. mask_path = get_segment_mask(img_key, img_path, output_directory)
  74. # 获取mask失败则上传失败图片
  75. if not mask_path:
  76. mask_path = "asset/segment_fail.png"
  77. mark_img_fail(reply_message_id, img_key)
  78. # 上传mask
  79. mask = get_img_file(mask_path)
  80. upload_mask_res = upload_image(img_key, mask, "mask")
  81. if upload_mask_res:
  82. upload_mask_key = upload_mask_res.json()['data']['image_key']
  83. # 更新card
  84. updated_mask_card ={
  85. "msg_type": "interactive",
  86. "content": json.dumps(update_process_card(reply_message_id, img_key, upload_mask_key))
  87. }
  88. update_message(reply_message_id, updated_mask_card)
  89. # 开始换装
  90. img_output_path = img_2_img(get_prompt(reply_message_id), mask, img_key, img_path, output_directory)
  91. # 获取img失败则上传失败图片
  92. if not img_output_path:
  93. img_output_path = "asset/img_fail.png"
  94. mark_img_fail(reply_message_id, img_key)
  95. # 上传图片
  96. output_img = get_img_file(img_output_path)
  97. upload_output_img_res = upload_image(img_key, output_img, "output_img")
  98. if upload_output_img_res:
  99. mark_img_done(reply_message_id, img_key)
  100. upload_output_img_key = upload_output_img_res.json()['data']['image_key']
  101. # 更新card
  102. upload_output_img_card ={
  103. "msg_type": "interactive",
  104. "content": json.dumps(update_process_card(reply_message_id, img_key, upload_output_img_key))
  105. }
  106. update_message(reply_message_id, upload_output_img_card)
  107. else:
  108. mark_img_fail(reply_message_id, img_key)
  109. else:
  110. mark_img_fail(reply_message_id, img_key)
  111. # 配置prompt
  112. def get_prompt(reply_message_id):
  113. config = load_config_from_json(reply_message_id)
  114. return "1girl, masterpiece, best quality, " + config['lora']
  115. # 用户选择卡片
  116. def choose_updated_card(message_id, config_data):
  117. card_content = load_template('start_card', message_id)
  118. card_content['elements'][3]['text']['content'] = get_prompt(message_id)
  119. image_columns = []
  120. img_keys = get_img_key_from_input_images(message_id)
  121. for img_key in img_keys:
  122. image_columns.append(
  123. {
  124. "tag": "column",
  125. "width": "weighted",
  126. "weight": 1,
  127. "vertical_align": "top",
  128. "elements":
  129. [
  130. {
  131. "tag": "img",
  132. "img_key": img_key,
  133. "alt": {
  134. "tag": "plain_text",
  135. "content": ""
  136. },
  137. "mode": "fit_horizontal",
  138. "preview": True
  139. },
  140. ]
  141. }
  142. )
  143. card_content['elements'][1]['columns'] = image_columns
  144. if not config_data['if_download']:
  145. card_content['elements'][4]['actions'][2] = {
  146. "tag": "button",
  147. "text": {
  148. "tag": "lark_md",
  149. "content": "**上传中**"
  150. },
  151. "type": "default",
  152. "value": {
  153. "type" : "uploading",
  154. "key": "uploading"
  155. }
  156. }
  157. return card_content
  158. # 用户等待卡片
  159. def waiting_card(message_id, config_data):
  160. card_content = load_template('waiting_card')
  161. card_content['elements'][3]['text']['content'] = get_prompt(message_id)
  162. image_columns = []
  163. img_keys = get_img_key_from_input_images(message_id)
  164. for img_key in img_keys:
  165. image_columns.append(
  166. {
  167. "tag": "column",
  168. "width": "weighted",
  169. "weight": 1,
  170. "vertical_align": "top",
  171. "elements":
  172. [
  173. {
  174. "tag": "img",
  175. "img_key": img_key,
  176. "alt": {
  177. "tag": "plain_text",
  178. "content": "test"
  179. },
  180. "mode": "fit_horizontal",
  181. "preview": True
  182. },
  183. ]
  184. }
  185. )
  186. card_content['elements'][1]['columns'] = image_columns
  187. save_process_card(card_content, message_id)
  188. return card_content