123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206 |
- import json
- import threading
- from concurrent.futures import ThreadPoolExecutor
- from api.img2img import img_2_img
- from api.segment import get_segment_mask
- from utils.config_io import load_config_from_json, save_config_to_json
- from utils.image_io import get_img_file, get_img_key_from_input_images
- from utils.request_api import reply_message, update_message, upload_image
- from utils.logger import logger
- from utils.template_io import load_template, save_process_card, update_process_card
# Maps an action name to the title text shown for it.
ACTION_TO_TITLE = {
    'clothes': '换装',
    'furry': 'Furry头像',
}

# Module-wide lock held while an outfit-swap worker thread is started.
clothes_lock = threading.Lock()
def handle_card(card_data):
    """Handle an interactive-card callback and return the next card, if any.

    Args:
        card_data: Callback payload. ``open_message_id`` identifies the
            message whose stored config is used; ``action`` carries ``tag``,
            ``value`` and, for dropdown selections, ``option``.

    Returns:
        The card content to display next, or ``None`` when nothing should be
        replaced (the config is still incomplete, or a button other than
        "start" was clicked).
    """
    reply_message_id = card_data['open_message_id']
    action = card_data['action']

    if action['tag'] == 'select_static':
        # The user picked a dropdown option: store it under the field the
        # widget names, then refresh the card once both settings are chosen.
        config_data = load_config_from_json(reply_message_id)
        config_data[action['value']['type']] = action['option']
        save_config_to_json(config_data, reply_message_id)
        if config_data['action'] != "" and config_data['lora'] != "":
            return choose_updated_card(reply_message_id, config_data)
        return None

    # Any other tag is treated as a button click.
    if action['value']['key'] != 'start':
        logger.info('用户点击了上传中按钮')
        return None

    # The user clicked "start".
    # TODO: persist task state so an interrupted job can be resumed.
    config_data = load_config_from_json(reply_message_id)
    if config_data['action'] == "" or config_data['lora'] == "":
        return None

    response_data = waiting_card(reply_message_id, config_data)
    executor = ThreadPoolExecutor()
    executor.submit(start_sd_process, reply_message_id, config_data)
    # Do not wait for the job, but let the executor release its worker as
    # soon as the submitted task finishes instead of relying on garbage
    # collection of an abandoned pool.
    executor.shutdown(wait=False)
    return response_data
def start_sd_process(reply_message_id, config_data):
    """Run the background job selected by ``config_data['action']``.

    Only the ``'clothes'`` action triggers any work; every other action
    currently does nothing.
    """
    if config_data['action'] != 'clothes':
        return
    # Outfit swap: hand over to the per-image workers (segment mask first).
    start_clothes_img_theards(reply_message_id)
# Start the outfit-swap worker thread for each image
def start_clothes_img_theards(reply_message_id):
    """Mark every input image as in progress, then process them one at a time.

    The stored config for ``reply_message_id`` gets a fresh ``img_missions``
    table with each image key set to ``"doing"``. Each image is then handled
    by ``img_clothes`` in its own thread; the thread is started and joined
    while ``clothes_lock`` is held, so the next image only begins once the
    current one has finished.

    Args:
        reply_message_id: Id of the message whose input images are processed.
    """
    img_keys = get_img_key_from_input_images(reply_message_id)

    config_data = load_config_from_json(reply_message_id)
    config_data["img_missions"] = {key: "doing" for key in img_keys}
    save_config_to_json(config_data, reply_message_id)

    for img_key in img_keys:
        logger.info("开始clothes处理线程: "+img_key)
        with clothes_lock:
            # Start the thread inside the lock so only one is launched at a
            # time, and wait for it before moving on to the next image.
            worker = threading.Thread(target=img_clothes, args=(reply_message_id, img_key))
            worker.start()
            worker.join()
def mark_img_fail(reply_message_id, img_key):
    """Record ``"fail"`` for ``img_key`` in the stored ``img_missions`` table."""
    stored = load_config_from_json(reply_message_id)
    missions = stored["img_missions"]
    missions[img_key] = "fail"
    save_config_to_json(stored, reply_message_id)
def mark_img_done(reply_message_id, img_key):
    """Record ``"done"`` for ``img_key`` in the stored ``img_missions`` table."""
    stored = load_config_from_json(reply_message_id)
    missions = stored["img_missions"]
    missions[img_key] = "done"
    save_config_to_json(stored, reply_message_id)
# Run the outfit swap for a single image
def img_clothes(reply_message_id, img_key):
    """Segment one image, run img2img on it, and push both results to the card.

    Steps:
      1. Get a segment mask for the image. On failure a placeholder picture
         is used instead and the image is marked as failed.
      2. Upload the mask and update the process card with it.
      3. Run img2img. On failure a placeholder picture is used instead and
         the image is marked as failed.
      4. Upload the output and update the process card with it.

    The image is marked ``"done"`` only when no step failed. Previously a
    successful upload of a placeholder picture overwrote the earlier
    ``"fail"`` status with ``"done"``.

    Args:
        reply_message_id: Id of the message the image belongs to.
        img_key: Key of the image; also names its directory and file under
            ``data/input_images/<reply_message_id>/``.
    """
    img_path = f"data/input_images/{reply_message_id}/{img_key}/{img_key}.jpg"
    output_directory = f"data/input_images/{reply_message_id}/{img_key}"
    failed = False

    mask_path = get_segment_mask(img_key, img_path, output_directory)
    if not mask_path:
        # Segmentation failed: show the placeholder picture instead.
        # NOTE(review): processing still continues to img2img with the
        # placeholder as mask, as before — confirm this is intended.
        mask_path = "asset/segment_fail.png"
        mark_img_fail(reply_message_id, img_key)
        failed = True

    # Upload the mask.
    mask = get_img_file(mask_path)
    upload_mask_res = upload_image(img_key, mask, "mask")
    if not upload_mask_res:
        mark_img_fail(reply_message_id, img_key)
        return

    upload_mask_key = upload_mask_res.json()['data']['image_key']
    # Update the card with the uploaded mask.
    updated_mask_card = {
        "msg_type": "interactive",
        "content": json.dumps(update_process_card(reply_message_id, img_key, upload_mask_key))
    }
    update_message(reply_message_id, updated_mask_card)

    # Run the outfit swap itself.
    img_output_path = img_2_img(get_prompt(reply_message_id), mask, img_key, img_path, output_directory)
    if not img_output_path:
        # img2img failed: show the placeholder picture instead.
        img_output_path = "asset/img_fail.png"
        mark_img_fail(reply_message_id, img_key)
        failed = True

    # Upload the output picture.
    output_img = get_img_file(img_output_path)
    upload_output_img_res = upload_image(img_key, output_img, "output_img")
    if not upload_output_img_res:
        mark_img_fail(reply_message_id, img_key)
        return

    # Keep an earlier "fail" status: the upload that just succeeded may have
    # been the placeholder picture rather than a real result.
    if not failed:
        mark_img_done(reply_message_id, img_key)
    upload_output_img_key = upload_output_img_res.json()['data']['image_key']
    # Update the card with the uploaded output picture.
    upload_output_img_card = {
        "msg_type": "interactive",
        "content": json.dumps(update_process_card(reply_message_id, img_key, upload_output_img_key))
    }
    update_message(reply_message_id, upload_output_img_card)
# Build the prompt
def get_prompt(reply_message_id):
    """Return the fixed prompt prefix followed by the stored ``lora`` value."""
    base_prompt = "1girl, masterpiece, best quality, "
    lora = load_config_from_json(reply_message_id)['lora']
    return base_prompt + lora
# Card shown while the user is choosing options
def choose_updated_card(message_id, config_data):
    """Build the selection card for ``message_id`` from the ``start_card`` template.

    The prompt text goes into element 3 and one image column per input image
    into element 1. When ``config_data['if_download']`` is falsy, the third
    action of element 4 is replaced by an "uploading" button.

    Returns:
        The filled-in card content.
    """
    card = load_template('start_card', message_id)
    prompt_text = get_prompt(message_id)
    card['elements'][3]['text']['content'] = prompt_text

    columns = [
        {
            "tag": "column",
            "width": "weighted",
            "weight": 1,
            "vertical_align": "top",
            "elements": [
                {
                    "tag": "img",
                    "img_key": key,
                    "alt": {
                        "tag": "plain_text",
                        "content": ""
                    },
                    "mode": "fit_horizontal",
                    "preview": True
                },
            ]
        }
        for key in get_img_key_from_input_images(message_id)
    ]
    card['elements'][1]['columns'] = columns

    if config_data['if_download']:
        return card

    # Images are not downloaded yet: swap in the "uploading" button.
    uploading_button = {
        "tag": "button",
        "text": {
            "tag": "lark_md",
            "content": "**上传中**"
        },
        "type": "default",
        "value": {
            "type": "uploading",
            "key": "uploading"
        }
    }
    card['elements'][4]['actions'][2] = uploading_button
    return card
# Card shown while the user waits for processing
def waiting_card(message_id, config_data):
    """Build, save and return the waiting card for ``message_id``.

    The ``waiting_card`` template gets the prompt text in element 3 and one
    image column per input image in element 1. The result is stored with
    ``save_process_card`` before being returned.

    Args:
        message_id: Id of the message whose images and prompt are shown.
        config_data: Not read by this function.

    Returns:
        The filled-in card content.
    """
    card = load_template('waiting_card')
    prompt_text = get_prompt(message_id)
    card['elements'][3]['text']['content'] = prompt_text

    columns = [
        {
            "tag": "column",
            "width": "weighted",
            "weight": 1,
            "vertical_align": "top",
            "elements": [
                {
                    "tag": "img",
                    "img_key": key,
                    "alt": {
                        "tag": "plain_text",
                        "content": "test"
                    },
                    "mode": "fit_horizontal",
                    "preview": True
                },
            ]
        }
        for key in get_img_key_from_input_images(message_id)
    ]
    card['elements'][1]['columns'] = columns

    save_process_card(card, message_id)
    return card
|