card_handler.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. import json
  2. import threading
  3. from concurrent.futures import ThreadPoolExecutor
  4. from api.img2img import img_2_img, img_2_furry
  5. from api.segment import get_segment_mask
  6. from utils.config_io import load_config_from_json, save_config_to_json
  7. from utils.image_io import get_img_file, get_img_key_from_input_images
  8. from utils.request_api import reply_message, update_message, upload_image
  9. from utils.logger import logger
  10. from utils.template_io import load_template, save_process_card, update_process_card
  11. ACTION_TO_TITLE = {
  12. 'clothes': '换装',
  13. 'furry': 'Furry头像'
  14. }
  15. clothes_lock = threading.Lock()
  16. furry_lock = threading.Lock()
  17. def handle_card(card_data):
  18. reply_message_id = card_data['open_message_id']
  19. action = card_data['action']
  20. if action['tag'] == 'select_static':
  21. # 用户进行选择
  22. config_data = load_config_from_json(reply_message_id)
  23. config_data[action['value']['type']] = action['option']
  24. save_config_to_json(config_data, reply_message_id)
  25. if config_data['action'] != "" and config_data['lora'] != "":
  26. response_data = choose_updated_card(reply_message_id, config_data)
  27. return response_data
  28. else:
  29. # 用户点击开始
  30. # TODO: 保存任务状态实现断点继续
  31. if action['value']['key'] == 'start':
  32. config_data = load_config_from_json(reply_message_id)
  33. logger.info('用户点击了开始按钮')
  34. if config_data['action'] != "" and config_data['lora'] != "":
  35. response_data = waiting_card(reply_message_id, config_data)
  36. executor = ThreadPoolExecutor()
  37. executor.submit(start_sd_process, reply_message_id, config_data)
  38. return response_data
  39. if config_data['action'] != "" and config_data['animal'] != "":
  40. response_data = waiting_card(reply_message_id, config_data)
  41. executor = ThreadPoolExecutor()
  42. executor.submit(start_sd_process, reply_message_id, config_data)
  43. return response_data
  44. else:
  45. logger.info('用户点击了上传中按钮')
  46. def start_sd_process(reply_message_id, config_data):
  47. if config_data['action'] == 'clothes':
  48. # 获取segment mask
  49. start_clothes_img_theards(reply_message_id)
  50. else:
  51. start_furry_img_theards(reply_message_id)
  52. pass
  53. # 开始furry的每个图片的线程
  54. def start_furry_img_theards(reply_message_id):
  55. img_keys = get_img_key_from_input_images(reply_message_id)
  56. config_data = load_config_from_json(reply_message_id)
  57. config_data["img_missions"] = {}
  58. for key in img_keys:
  59. config_data["img_missions"][key] = "doing"
  60. save_config_to_json(config_data, reply_message_id)
  61. threads = []
  62. for img_key in img_keys:
  63. logger.info("开始furry处理线程: "+img_key)
  64. with furry_lock:
  65. # 在锁的上下文中启动线程,确保一次只启动一个
  66. t = threading.Thread(target=img_furry, args=(reply_message_id, img_key))
  67. threads.append(t)
  68. t.start()
  69. t.join() # 等待当前线程完成再启动下一个
  70. # 开始换装的每个图片的线程
  71. def start_clothes_img_theards(reply_message_id):
  72. img_keys = get_img_key_from_input_images(reply_message_id)
  73. config_data = load_config_from_json(reply_message_id)
  74. config_data["img_missions"] = {}
  75. for key in img_keys:
  76. config_data["img_missions"][key] = "doing"
  77. save_config_to_json(config_data, reply_message_id)
  78. threads = []
  79. for img_key in img_keys:
  80. logger.info("开始clothes处理线程: "+img_key)
  81. with clothes_lock:
  82. # 在锁的上下文中启动线程,确保一次只启动一个
  83. t = threading.Thread(target=img_clothes, args=(reply_message_id, img_key))
  84. threads.append(t)
  85. t.start()
  86. t.join() # 等待当前线程完成再启动下一个
  87. def mark_img_fail(reply_message_id, img_key):
  88. config_data = load_config_from_json(reply_message_id)
  89. config_data["img_missions"][img_key] = "fail"
  90. save_config_to_json(config_data, reply_message_id)
  91. def mark_img_done(reply_message_id, img_key):
  92. config_data = load_config_from_json(reply_message_id)
  93. config_data["img_missions"][img_key] = "done"
  94. save_config_to_json(config_data, reply_message_id)
  95. # 对单个图片做换装操作
  96. def img_clothes(reply_message_id, img_key):
  97. img_path = f"data/input_images/{reply_message_id}/{img_key}/{img_key}.jpg"
  98. output_directory = f"data/input_images/{reply_message_id}/{img_key}"
  99. mask_path = get_segment_mask(img_key, img_path, output_directory)
  100. # 获取mask失败则上传失败图片
  101. if not mask_path:
  102. mask_path = "asset/segment_fail.png"
  103. mark_img_fail(reply_message_id, img_key)
  104. # 上传mask
  105. mask = get_img_file(mask_path)
  106. upload_mask_res = upload_image(img_key, mask, "mask")
  107. if upload_mask_res:
  108. upload_mask_key = upload_mask_res.json()['data']['image_key']
  109. # 更新card
  110. updated_mask_card ={
  111. "msg_type": "interactive",
  112. "content": json.dumps(update_process_card(reply_message_id, img_key, upload_mask_key))
  113. }
  114. update_message(reply_message_id, updated_mask_card)
  115. # 开始换装
  116. img_output_path = img_2_img(get_prompt(reply_message_id), mask, img_key, img_path, output_directory)
  117. # 获取img失败则上传失败图片
  118. if not img_output_path:
  119. img_output_path = "asset/img_fail.png"
  120. mark_img_fail(reply_message_id, img_key)
  121. # 上传图片
  122. output_img = get_img_file(img_output_path)
  123. upload_output_img_res = upload_image(img_key, output_img, "output_img")
  124. if upload_output_img_res:
  125. mark_img_done(reply_message_id, img_key)
  126. upload_output_img_key = upload_output_img_res.json()['data']['image_key']
  127. # 更新card
  128. upload_output_img_card ={
  129. "msg_type": "interactive",
  130. "content": json.dumps(update_process_card(reply_message_id, img_key, upload_output_img_key))
  131. }
  132. update_message(reply_message_id, upload_output_img_card)
  133. else:
  134. mark_img_fail(reply_message_id, img_key)
  135. else:
  136. mark_img_fail(reply_message_id, img_key)
  137. # 对单个图片做furry操作
  138. def img_furry(reply_message_id, img_key):
  139. img_path = f"data/input_images/{reply_message_id}/{img_key}/{img_key}.jpg"
  140. output_directory = f"data/input_images/{reply_message_id}/{img_key}"
  141. # 开始furry
  142. img_output_path = img_2_furry(get_prompt(reply_message_id), img_key, img_path, output_directory)
  143. # 获取img失败则上传失败图片
  144. if not img_output_path:
  145. img_output_path = "asset/img_fail.png"
  146. mark_img_fail(reply_message_id, img_key)
  147. # 上传图片
  148. output_img = get_img_file(img_output_path)
  149. upload_output_img_res = upload_image(img_key, output_img, "output_img")
  150. if upload_output_img_res:
  151. mark_img_done(reply_message_id, img_key)
  152. upload_output_img_key = upload_output_img_res.json()['data']['image_key']
  153. # 更新card
  154. upload_output_img_card ={
  155. "msg_type": "interactive",
  156. "content": json.dumps(update_process_card(reply_message_id, img_key, upload_output_img_key))
  157. }
  158. update_message(reply_message_id, upload_output_img_card)
  159. else:
  160. mark_img_fail(reply_message_id, img_key)
  161. # 配置prompt
  162. def get_prompt(reply_message_id):
  163. config = load_config_from_json(reply_message_id)
  164. return "niji_style,1girl,simple background,furry," + config['animal'] + ",Disney style,movie lighting,vhibi,benhance,Octane Render,cute anime,chinoiserie painting,8K,3D,C4D,super details,best quality,"
  165. # 用户选择卡片
  166. def choose_updated_card(message_id, config_data):
  167. card_content = load_template('start_card', message_id)
  168. card_content['elements'][3]['text']['content'] = get_prompt(message_id)
  169. image_columns = []
  170. img_keys = get_img_key_from_input_images(message_id)
  171. for img_key in img_keys:
  172. image_columns.append(
  173. {
  174. "tag": "column",
  175. "width": "weighted",
  176. "weight": 1,
  177. "vertical_align": "top",
  178. "elements":
  179. [
  180. {
  181. "tag": "img",
  182. "img_key": img_key,
  183. "alt": {
  184. "tag": "plain_text",
  185. "content": ""
  186. },
  187. "mode": "fit_horizontal",
  188. "preview": True
  189. },
  190. ]
  191. }
  192. )
  193. card_content['elements'][1]['columns'] = image_columns
  194. if not config_data['if_download']:
  195. card_content['elements'][4]['actions'][2] = {
  196. "tag": "button",
  197. "text": {
  198. "tag": "lark_md",
  199. "content": "**上传中**"
  200. },
  201. "type": "default",
  202. "value": {
  203. "type" : "uploading",
  204. "key": "uploading"
  205. }
  206. }
  207. return card_content
  208. # 用户等待卡片
  209. def waiting_card(message_id, config_data):
  210. card_content = load_template('waiting_card')
  211. card_content['elements'][3]['text']['content'] = get_prompt(message_id)
  212. image_columns = []
  213. img_keys = get_img_key_from_input_images(message_id)
  214. for img_key in img_keys:
  215. image_columns.append(
  216. {
  217. "tag": "column",
  218. "width": "weighted",
  219. "weight": 1,
  220. "vertical_align": "top",
  221. "elements":
  222. [
  223. {
  224. "tag": "img",
  225. "img_key": img_key,
  226. "alt": {
  227. "tag": "plain_text",
  228. "content": "test"
  229. },
  230. "mode": "fit_horizontal",
  231. "preview": True
  232. },
  233. ]
  234. }
  235. )
  236. card_content['elements'][1]['columns'] = image_columns
  237. save_process_card(card_content, message_id)
  238. return card_content