|
@@ -0,0 +1,204 @@
|
|
|
+import json
|
|
|
+import threading
|
|
|
+from concurrent.futures import ThreadPoolExecutor
|
|
|
+from api.img2img import img_2_img
|
|
|
+from api.segment import get_segment_mask
|
|
|
+from utils.config_io import load_config_from_json, save_config_to_json
|
|
|
+from utils.image_io import get_img_file, get_img_key_from_input_images
|
|
|
+from utils.request_api import reply_message, update_message, upload_image
|
|
|
+from utils.logger import logger
|
|
|
+from utils.template_io import load_template, save_process_card, update_process_card
|
|
|
+
|
|
|
+
|
|
|
+ACTION_TO_TITLE = {
|
|
|
+ 'clothes': '换装',
|
|
|
+ 'furry': 'Furry头像'
|
|
|
+}
|
|
|
+
|
|
|
+def handle_card(card_data):
|
|
|
+ reply_message_id = card_data['open_message_id']
|
|
|
+ action = card_data['action']
|
|
|
+ if action['tag'] == 'select_static':
|
|
|
+ # 用户进行选择
|
|
|
+ config_data = load_config_from_json(reply_message_id)
|
|
|
+ config_data[action['value']['type']] = action['option']
|
|
|
+ save_config_to_json(config_data, reply_message_id)
|
|
|
+ if config_data['action'] != "" and config_data['lora'] != "":
|
|
|
+ response_data = choose_updated_card(reply_message_id, config_data)
|
|
|
+ return response_data
|
|
|
+ else:
|
|
|
+ # 用户点击开始
|
|
|
+ # TODO: 保存任务状态实现断点继续
|
|
|
+ if action['value']['key'] == 'start':
|
|
|
+ config_data = load_config_from_json(reply_message_id)
|
|
|
+ if config_data['action'] != "" and config_data['lora'] != "":
|
|
|
+ response_data = waiting_card(reply_message_id, config_data)
|
|
|
+ executor = ThreadPoolExecutor()
|
|
|
+ executor.submit(start_sd_process, reply_message_id, config_data)
|
|
|
+ return response_data
|
|
|
+ else:
|
|
|
+ logger.info('用户点击了上传中按钮')
|
|
|
+
|
|
|
+
|
|
|
+def start_sd_process(reply_message_id, config_data):
|
|
|
+ if config_data['action'] == 'clothes':
|
|
|
+ # 获取segment mask
|
|
|
+ start_clothes_img_theards(reply_message_id)
|
|
|
+ else:
|
|
|
+ pass
|
|
|
+
|
|
|
+# 开始换装的每个图片的线程
|
|
|
+def start_clothes_img_theards(reply_message_id):
|
|
|
+ img_keys = get_img_key_from_input_images(reply_message_id)
|
|
|
+
|
|
|
+ config_data = load_config_from_json(reply_message_id)
|
|
|
+ config_data["img_missions"] = {}
|
|
|
+ for key in img_keys:
|
|
|
+ config_data["img_missions"][key] = "doing"
|
|
|
+ save_config_to_json(config_data, reply_message_id)
|
|
|
+
|
|
|
+ threads = []
|
|
|
+ for img_key in img_keys:
|
|
|
+ logger.info("开始clothes处理线程: "+img_key)
|
|
|
+ # img_clothes(reply_message_id, img_key)
|
|
|
+ t = threading.Thread(target=img_clothes, args=(reply_message_id, img_key))
|
|
|
+ threads.append(t)
|
|
|
+ t.start()
|
|
|
+ for t in threads:
|
|
|
+ t.join()
|
|
|
+
|
|
|
+def mark_img_fail(reply_message_id, img_key):
|
|
|
+ config_data = load_config_from_json(reply_message_id)
|
|
|
+ config_data["img_missions"][img_key] = "fail"
|
|
|
+ save_config_to_json(config_data, reply_message_id)
|
|
|
+
|
|
|
+def mark_img_done(reply_message_id, img_key):
|
|
|
+ config_data = load_config_from_json(reply_message_id)
|
|
|
+ config_data["img_missions"][img_key] = "done"
|
|
|
+ save_config_to_json(config_data, reply_message_id)
|
|
|
+
|
|
|
+# 对单个图片做换装操作
|
|
|
+def img_clothes(reply_message_id, img_key):
|
|
|
+ img_path = f"data/input_images/{reply_message_id}/{img_key}/{img_key}.jpg"
|
|
|
+ output_directory = f"data/input_images/{reply_message_id}/{img_key}"
|
|
|
+ mask_path = get_segment_mask(img_key, img_path, output_directory)
|
|
|
+ # 获取mask失败则上传失败图片
|
|
|
+ if not mask_path:
|
|
|
+ mask_path = "asset/segment_fail.png"
|
|
|
+ mark_img_fail(reply_message_id, img_key)
|
|
|
+ # 上传mask
|
|
|
+ mask = get_img_file(mask_path)
|
|
|
+ upload_mask_res = upload_image(img_key, mask, "mask")
|
|
|
+ if upload_mask_res:
|
|
|
+ upload_mask_key = upload_mask_res.json()['data']['image_key']
|
|
|
+ # 更新card
|
|
|
+ updated_mask_card ={
|
|
|
+ "msg_type": "interactive",
|
|
|
+ "content": json.dumps(update_process_card(reply_message_id, img_key, upload_mask_key))
|
|
|
+ }
|
|
|
+ update_message(reply_message_id, updated_mask_card)
|
|
|
+ # 开始换装
|
|
|
+ img_output_path = img_2_img(get_prompt(reply_message_id), mask, img_key, img_path, output_directory)
|
|
|
+ # 获取img失败则上传失败图片
|
|
|
+ if not img_output_path:
|
|
|
+ img_output_path = "asset/img_fail.png"
|
|
|
+ mark_img_fail(reply_message_id, img_key)
|
|
|
+ # 上传图片
|
|
|
+ output_img = get_img_file(img_output_path)
|
|
|
+ upload_output_img_res = upload_image(img_key, output_img, "output_img")
|
|
|
+ if upload_output_img_res:
|
|
|
+ mark_img_done(reply_message_id, img_key)
|
|
|
+ upload_output_img_key = upload_output_img_res.json()['data']['image_key']
|
|
|
+ # 更新card
|
|
|
+ upload_output_img_card ={
|
|
|
+ "msg_type": "interactive",
|
|
|
+ "content": json.dumps(update_process_card(reply_message_id, img_key, upload_output_img_key))
|
|
|
+ }
|
|
|
+ update_message(reply_message_id, upload_output_img_card)
|
|
|
+ else:
|
|
|
+ mark_img_fail(reply_message_id, img_key)
|
|
|
+ else:
|
|
|
+ mark_img_fail(reply_message_id, img_key)
|
|
|
+
|
|
|
+
|
|
|
+# 配置prompt
|
|
|
+def get_prompt(reply_message_id):
|
|
|
+ config = load_config_from_json(reply_message_id)
|
|
|
+ return "1girl, masterpiece, best quality, " + config['lora']
|
|
|
+
|
|
|
+# 用户选择卡片
|
|
|
+def choose_updated_card(message_id, config_data):
|
|
|
+ card_content = load_template('start_card', message_id)
|
|
|
+ card_content['elements'][3]['text']['content'] = get_prompt(message_id)
|
|
|
+
|
|
|
+ image_columns = []
|
|
|
+ img_keys = get_img_key_from_input_images(message_id)
|
|
|
+ for img_key in img_keys:
|
|
|
+ image_columns.append(
|
|
|
+ {
|
|
|
+ "tag": "column",
|
|
|
+ "width": "weighted",
|
|
|
+ "weight": 1,
|
|
|
+ "vertical_align": "top",
|
|
|
+ "elements":
|
|
|
+ [
|
|
|
+ {
|
|
|
+ "tag": "img",
|
|
|
+ "img_key": img_key,
|
|
|
+ "alt": {
|
|
|
+ "tag": "plain_text",
|
|
|
+ "content": ""
|
|
|
+ },
|
|
|
+ "mode": "fit_horizontal",
|
|
|
+ "preview": True
|
|
|
+ },
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ )
|
|
|
+ card_content['elements'][1]['columns'] = image_columns
|
|
|
+ if not config_data['if_download']:
|
|
|
+ card_content['elements'][4]['actions'][2] = {
|
|
|
+ "tag": "button",
|
|
|
+ "text": {
|
|
|
+ "tag": "lark_md",
|
|
|
+ "content": "**上传中**"
|
|
|
+ },
|
|
|
+ "type": "default",
|
|
|
+ "value": {
|
|
|
+ "type" : "uploading",
|
|
|
+ "key": "uploading"
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return card_content
|
|
|
+
|
|
|
+# 用户等待卡片
|
|
|
+def waiting_card(message_id, config_data):
|
|
|
+ card_content = load_template('waiting_card')
|
|
|
+ card_content['elements'][3]['text']['content'] = get_prompt(message_id)
|
|
|
+ image_columns = []
|
|
|
+ img_keys = get_img_key_from_input_images(message_id)
|
|
|
+ for img_key in img_keys:
|
|
|
+ image_columns.append(
|
|
|
+ {
|
|
|
+ "tag": "column",
|
|
|
+ "width": "weighted",
|
|
|
+ "weight": 1,
|
|
|
+ "vertical_align": "top",
|
|
|
+ "elements":
|
|
|
+ [
|
|
|
+ {
|
|
|
+ "tag": "img",
|
|
|
+ "img_key": img_key,
|
|
|
+ "alt": {
|
|
|
+ "tag": "plain_text",
|
|
|
+ "content": "test"
|
|
|
+ },
|
|
|
+ "mode": "fit_horizontal",
|
|
|
+ "preview": True
|
|
|
+ },
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ )
|
|
|
+ card_content['elements'][1]['columns'] = image_columns
|
|
|
+ save_process_card(card_content, message_id)
|
|
|
+ return card_content
|