import json
import threading
from concurrent.futures import ThreadPoolExecutor

from api.img2img import img_2_img
from api.segment import get_segment_mask
from utils.config_io import load_config_from_json, save_config_to_json
from utils.image_io import get_img_file, get_img_key_from_input_images
from utils.request_api import reply_message, update_message, upload_image
from utils.logger import logger
from utils.template_io import load_template, save_process_card, update_process_card

# Display title (UI text) for each supported action value.
ACTION_TO_TITLE = {
    'clothes': '换装',
    'furry': 'Furry头像'
}

# Module-wide lock: start_clothes_img_theards holds it while one image thread
# runs, so image processing is serialized across every caller in this process.
clothes_lock = threading.Lock()


def handle_card(card_data):
    """Handle an interactive-card callback and return the card to render.

    card_data must contain 'open_message_id' and 'action'; the exact payload
    schema comes from the card platform (NOTE(review): confirm against caller).

    - 'select_static' action: store action['option'] in the message config under
      the key action['value']['type']; once both 'action' and 'lora' are set,
      return a refreshed selection card.
    - button with value key 'start': if 'action' and 'lora' are set, return a
      waiting card and run the SD pipeline in a background thread.
    - any other button (the "uploading" placeholder): log and do nothing.

    Returns:
        The card content dict, or None when the card should not change.
    """
    reply_message_id = card_data['open_message_id']
    action = card_data['action']

    if action['tag'] == 'select_static':
        # The user picked an option; persist it before deciding what to show.
        config_data = load_config_from_json(reply_message_id)
        config_data[action['value']['type']] = action['option']
        save_config_to_json(config_data, reply_message_id)
        if config_data['action'] != "" and config_data['lora'] != "":
            return choose_updated_card(reply_message_id, config_data)
        return None

    # A button was clicked.
    # TODO: persist task state so an interrupted run can resume.
    if action['value']['key'] != 'start':
        logger.info('用户点击了上传中按钮')
        return None

    config_data = load_config_from_json(reply_message_id)
    if config_data['action'] == "" or config_data['lora'] == "":
        return None

    response_data = waiting_card(reply_message_id, config_data)

    # Run the slow pipeline off the callback thread so the reply returns at
    # once. shutdown(wait=False) still lets the submitted job run to completion,
    # then lets the worker thread exit instead of leaking one pool per click.
    executor = ThreadPoolExecutor(max_workers=1)
    future = executor.submit(start_sd_process, reply_message_id, config_data)
    future.add_done_callback(
        lambda done: _log_background_failure(reply_message_id, done))
    executor.shutdown(wait=False)
    return response_data


def _log_background_failure(reply_message_id, future):
    """Log an exception raised by a background SD job.

    An exception inside an executor job is only stored on its Future; without
    this callback nothing would ever report it.
    """
    if future.cancelled():
        return
    exc = future.exception()
    if exc is not None:
        logger.error(f"SD process for message {reply_message_id} failed: {exc!r}")


def start_sd_process(reply_message_id, config_data):
    """Dispatch the configured action; only 'clothes' is implemented so far."""
    if config_data['action'] == 'clothes':
        # Each image gets a segment mask, then an img2img pass.
        start_clothes_img_theards(reply_message_id)
    else:
        pass


def start_clothes_img_theards(reply_message_id):
    """Run the clothes pipeline for every input image of a message, one at a time.

    First records every image as "doing" under config["img_missions"], then
    processes each image in its own thread while holding clothes_lock and
    joining before starting the next.
    """
    img_keys = get_img_key_from_input_images(reply_message_id)
    config_data = load_config_from_json(reply_message_id)
    config_data["img_missions"] = {key: "doing" for key in img_keys}
    save_config_to_json(config_data, reply_message_id)

    for img_key in img_keys:
        logger.info("开始clothes处理线程: " + img_key)
        with clothes_lock:
            # Start under the lock and wait, so only one image runs at a time.
            worker = threading.Thread(
                target=_run_img_clothes, args=(reply_message_id, img_key))
            worker.start()
            worker.join()


def _run_img_clothes(reply_message_id, img_key):
    """Run img_clothes for one image; on an exception, log it and mark the image failed."""
    try:
        img_clothes(reply_message_id, img_key)
    except Exception as exc:  # thread boundary: don't leave the mission stuck at "doing"
        logger.error(f"clothes processing for {img_key} failed: {exc!r}")
        mark_img_fail(reply_message_id, img_key)


def _mark_img_status(reply_message_id, img_key, status):
    """Set config["img_missions"][img_key] to status and save the config."""
    config_data = load_config_from_json(reply_message_id)
    config_data["img_missions"][img_key] = status
    save_config_to_json(config_data, reply_message_id)


def mark_img_fail(reply_message_id, img_key):
    """Record img_key's mission as "fail"."""
    _mark_img_status(reply_message_id, img_key, "fail")


def mark_img_done(reply_message_id, img_key):
    """Record img_key's mission as "done"."""
    _mark_img_status(reply_message_id, img_key, "done")


def _push_card_image(reply_message_id, img_key, image_key):
    """Rebuild the process card with an uploaded image key and push it to the message."""
    updated_card = {
        "msg_type": "interactive",
        "content": json.dumps(update_process_card(reply_message_id, img_key, image_key))
    }
    update_message(reply_message_id, updated_card)


def img_clothes(reply_message_id, img_key):
    """Run the clothes pipeline for a single image.

    Steps: segment mask -> upload mask and update card -> img2img -> upload
    result and update card. A failed segmentation or img2img step substitutes
    a placeholder image and marks the mission "fail". A mission is marked
    "done" only when every step succeeded, so a placeholder upload never
    overwrites a recorded failure.
    """
    img_path = f"data/input_images/{reply_message_id}/{img_key}/{img_key}.jpg"
    output_directory = f"data/input_images/{reply_message_id}/{img_key}"
    failed = False

    mask_path = get_segment_mask(img_key, img_path, output_directory)
    if not mask_path:
        # Segmentation failed: show the placeholder image instead of a mask.
        mask_path = "asset/segment_fail.png"
        failed = True
        mark_img_fail(reply_message_id, img_key)

    mask = get_img_file(mask_path)
    upload_mask_res = upload_image(img_key, mask, "mask")
    if not upload_mask_res:
        mark_img_fail(reply_message_id, img_key)
        return
    _push_card_image(reply_message_id, img_key, upload_mask_res.json()['data']['image_key'])

    # NOTE(review): `mask` was already passed to upload_image; if get_img_file
    # returns a file handle, confirm img_2_img does not receive it exhausted.
    img_output_path = img_2_img(get_prompt(reply_message_id), mask, img_key, img_path, output_directory)
    if not img_output_path:
        # img2img failed: upload the placeholder image instead of a result.
        img_output_path = "asset/img_fail.png"
        failed = True
        mark_img_fail(reply_message_id, img_key)

    output_img = get_img_file(img_output_path)
    upload_output_img_res = upload_image(img_key, output_img, "output_img")
    if not upload_output_img_res:
        mark_img_fail(reply_message_id, img_key)
        return
    if not failed:
        mark_img_done(reply_message_id, img_key)
    _push_card_image(reply_message_id, img_key, upload_output_img_res.json()['data']['image_key'])


def get_prompt(reply_message_id):
    """Build the img2img prompt: a fixed quality prefix plus the configured lora."""
    config = load_config_from_json(reply_message_id)
    return "1girl, masterpiece, best quality, " + config['lora']


def _build_image_columns(img_keys, alt_content):
    """Build one weighted card column per image key, each showing that image."""
    return [
        {
            "tag": "column",
            "width": "weighted",
            "weight": 1,
            "vertical_align": "top",
            "elements": [
                {
                    "tag": "img",
                    "img_key": img_key,
                    "alt": {
                        "tag": "plain_text",
                        "content": alt_content
                    },
                    "mode": "fit_horizontal",
                    "preview": True
                },
            ]
        }
        for img_key in img_keys
    ]


def choose_updated_card(message_id, config_data):
    """Build the selection card showing the current prompt and input images.

    Element indices refer to the 'start_card' template layout. While
    config_data['if_download'] is falsy, the third action button is replaced
    with an "uploading" placeholder.
    """
    card_content = load_template('start_card', message_id)
    card_content['elements'][3]['text']['content'] = get_prompt(message_id)
    img_keys = get_img_key_from_input_images(message_id)
    card_content['elements'][1]['columns'] = _build_image_columns(img_keys, "")
    if not config_data['if_download']:
        card_content['elements'][4]['actions'][2] = {
            "tag": "button",
            "text": {
                "tag": "lark_md",
                "content": "**上传中**"
            },
            "type": "default",
            "value": {
                "type": "uploading",
                "key": "uploading"
            }
        }
    return card_content


def waiting_card(message_id, config_data):
    """Build the waiting card, save it as the message's process card and return it.

    Element indices refer to the 'waiting_card' template layout. config_data
    is accepted for signature parity but is not read.
    """
    card_content = load_template('waiting_card')
    card_content['elements'][3]['text']['content'] = get_prompt(message_id)
    img_keys = get_img_key_from_input_images(message_id)
    card_content['elements'][1]['columns'] = _build_image_columns(img_keys, "test")
    save_process_card(card_content, message_id)
    return card_content