card_handler.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import json
  2. import threading
  3. from concurrent.futures import ThreadPoolExecutor
  4. from api.img2img import img_2_img
  5. from api.segment import get_segment_mask
  6. from utils.config_io import load_config_from_json, save_config_to_json
  7. from utils.image_io import get_img_file, get_img_key_from_input_images
  8. from utils.request_api import reply_message, update_message, upload_image
  9. from utils.logger import logger
  10. from utils.template_io import load_template, save_process_card, update_process_card
  11. ACTION_TO_TITLE = {
  12. 'clothes': '换装',
  13. 'furry': 'Furry头像'
  14. }
  15. clothes_lock = threading.Lock()
  16. def handle_card(card_data):
  17. reply_message_id = card_data['open_message_id']
  18. action = card_data['action']
  19. if action['tag'] == 'select_static':
  20. # 用户进行选择
  21. config_data = load_config_from_json(reply_message_id)
  22. config_data[action['value']['type']] = action['option']
  23. save_config_to_json(config_data, reply_message_id)
  24. if config_data['action'] != "" and config_data['lora'] != "":
  25. response_data = choose_updated_card(reply_message_id, config_data)
  26. return response_data
  27. else:
  28. # 用户点击开始
  29. # TODO: 保存任务状态实现断点继续
  30. if action['value']['key'] == 'start':
  31. config_data = load_config_from_json(reply_message_id)
  32. if config_data['action'] != "" and config_data['lora'] != "":
  33. response_data = waiting_card(reply_message_id, config_data)
  34. executor = ThreadPoolExecutor()
  35. executor.submit(start_sd_process, reply_message_id, config_data)
  36. return response_data
  37. else:
  38. logger.info('用户点击了上传中按钮')
  39. def start_sd_process(reply_message_id, config_data):
  40. if config_data['action'] == 'clothes':
  41. # 获取segment mask
  42. start_clothes_img_theards(reply_message_id)
  43. else:
  44. pass
  45. # 开始换装的每个图片的线程
  46. def start_clothes_img_theards(reply_message_id):
  47. img_keys = get_img_key_from_input_images(reply_message_id)
  48. config_data = load_config_from_json(reply_message_id)
  49. config_data["img_missions"] = {}
  50. for key in img_keys:
  51. config_data["img_missions"][key] = "doing"
  52. save_config_to_json(config_data, reply_message_id)
  53. threads = []
  54. for img_key in img_keys:
  55. logger.info("开始clothes处理线程: "+img_key)
  56. with clothes_lock:
  57. # 在锁的上下文中启动线程,确保一次只启动一个
  58. t = threading.Thread(target=img_clothes, args=(reply_message_id, img_key))
  59. threads.append(t)
  60. t.start()
  61. t.join() # 等待当前线程完成再启动下一个
  62. def mark_img_fail(reply_message_id, img_key):
  63. config_data = load_config_from_json(reply_message_id)
  64. config_data["img_missions"][img_key] = "fail"
  65. save_config_to_json(config_data, reply_message_id)
  66. def mark_img_done(reply_message_id, img_key):
  67. config_data = load_config_from_json(reply_message_id)
  68. config_data["img_missions"][img_key] = "done"
  69. save_config_to_json(config_data, reply_message_id)
  70. # 对单个图片做换装操作
  71. def img_clothes(reply_message_id, img_key):
  72. img_path = f"data/input_images/{reply_message_id}/{img_key}/{img_key}.jpg"
  73. output_directory = f"data/input_images/{reply_message_id}/{img_key}"
  74. mask_path = get_segment_mask(img_key, img_path, output_directory)
  75. # 获取mask失败则上传失败图片
  76. if not mask_path:
  77. mask_path = "asset/segment_fail.png"
  78. mark_img_fail(reply_message_id, img_key)
  79. # 上传mask
  80. mask = get_img_file(mask_path)
  81. upload_mask_res = upload_image(img_key, mask, "mask")
  82. if upload_mask_res:
  83. upload_mask_key = upload_mask_res.json()['data']['image_key']
  84. # 更新card
  85. updated_mask_card ={
  86. "msg_type": "interactive",
  87. "content": json.dumps(update_process_card(reply_message_id, img_key, upload_mask_key))
  88. }
  89. update_message(reply_message_id, updated_mask_card)
  90. # 开始换装
  91. img_output_path = img_2_img(get_prompt(reply_message_id), mask, img_key, img_path, output_directory)
  92. # 获取img失败则上传失败图片
  93. if not img_output_path:
  94. img_output_path = "asset/img_fail.png"
  95. mark_img_fail(reply_message_id, img_key)
  96. # 上传图片
  97. output_img = get_img_file(img_output_path)
  98. upload_output_img_res = upload_image(img_key, output_img, "output_img")
  99. if upload_output_img_res:
  100. mark_img_done(reply_message_id, img_key)
  101. upload_output_img_key = upload_output_img_res.json()['data']['image_key']
  102. # 更新card
  103. upload_output_img_card ={
  104. "msg_type": "interactive",
  105. "content": json.dumps(update_process_card(reply_message_id, img_key, upload_output_img_key))
  106. }
  107. update_message(reply_message_id, upload_output_img_card)
  108. else:
  109. mark_img_fail(reply_message_id, img_key)
  110. else:
  111. mark_img_fail(reply_message_id, img_key)
  112. # 配置prompt
  113. def get_prompt(reply_message_id):
  114. config = load_config_from_json(reply_message_id)
  115. return "1girl, masterpiece, best quality, " + config['lora']
  116. # 用户选择卡片
  117. def choose_updated_card(message_id, config_data):
  118. card_content = load_template('start_card', message_id)
  119. card_content['elements'][3]['text']['content'] = get_prompt(message_id)
  120. image_columns = []
  121. img_keys = get_img_key_from_input_images(message_id)
  122. for img_key in img_keys:
  123. image_columns.append(
  124. {
  125. "tag": "column",
  126. "width": "weighted",
  127. "weight": 1,
  128. "vertical_align": "top",
  129. "elements":
  130. [
  131. {
  132. "tag": "img",
  133. "img_key": img_key,
  134. "alt": {
  135. "tag": "plain_text",
  136. "content": ""
  137. },
  138. "mode": "fit_horizontal",
  139. "preview": True
  140. },
  141. ]
  142. }
  143. )
  144. card_content['elements'][1]['columns'] = image_columns
  145. if not config_data['if_download']:
  146. card_content['elements'][4]['actions'][2] = {
  147. "tag": "button",
  148. "text": {
  149. "tag": "lark_md",
  150. "content": "**上传中**"
  151. },
  152. "type": "default",
  153. "value": {
  154. "type" : "uploading",
  155. "key": "uploading"
  156. }
  157. }
  158. return card_content
  159. # 用户等待卡片
  160. def waiting_card(message_id, config_data):
  161. card_content = load_template('waiting_card')
  162. card_content['elements'][3]['text']['content'] = get_prompt(message_id)
  163. image_columns = []
  164. img_keys = get_img_key_from_input_images(message_id)
  165. for img_key in img_keys:
  166. image_columns.append(
  167. {
  168. "tag": "column",
  169. "width": "weighted",
  170. "weight": 1,
  171. "vertical_align": "top",
  172. "elements":
  173. [
  174. {
  175. "tag": "img",
  176. "img_key": img_key,
  177. "alt": {
  178. "tag": "plain_text",
  179. "content": "test"
  180. },
  181. "mode": "fit_horizontal",
  182. "preview": True
  183. },
  184. ]
  185. }
  186. )
  187. card_content['elements'][1]['columns'] = image_columns
  188. save_process_card(card_content, message_id)
  189. return card_content