Forráskód Böngészése

feat: first commit

Shellmiao 1 éve
commit
c10f054bdd

+ 5 - 0
.env

@@ -0,0 +1,5 @@
+BOT_AES_KEY=ti9qtV3qofRTYjIrN0clnhHcoQvpHXIP
+app_id=cli_a4352ec26db89013
+app_secret=EBzvyQmEx6WlqS8FkUimUhpNbFyhnwlF
+segment_url=https://u166586-afba-837b5bd8.westa.seetacloud.com:8443
+img_url=https://u166586-a3a7-76a835bb.westb.seetacloud.com:8443

+ 4 - 0
.gitignore

@@ -0,0 +1,4 @@
+*.log
+__pycache__
+*/__pycache__
+data

+ 2 - 0
README.md

@@ -0,0 +1,2 @@
+# Stable Diffusion 飞书Bot
+详见飞书开放平台文档

+ 64 - 0
api/img2img.py

@@ -0,0 +1,64 @@
+import json
+import os
+import requests
+import io
+import base64
+from PIL import Image, PngImagePlugin
+from dotenv import load_dotenv
+from utils.logger import logger
+
+load_dotenv() # 加载环境变量
+img_url = os.getenv('img_url')
+
+def img_2_img(prompt, mask_img, img_key, img_path, output_directory):
+    with open(img_path, 'rb') as f:
+        image_data = f.read()
+    input_image = base64.b64encode(image_data).decode('utf-8')
+    mask = base64.b64encode(mask_img).decode('utf-8')
+    payload = {
+        "init_images": [
+            input_image
+        ],
+        "denoising_strength": 0.96,
+        "mask": mask,
+        "prompt": prompt,
+        "negative_prompt": "EasyNegativeV2,(badhandv4:1.2)",
+        "batch_size": 1,
+        "inpainting_mask_invert": 1,
+        "steps": 30,
+        "cfg_scale": 7,
+        "sampler_index": "DPM++ 2M Karras",
+        "alwayson_scripts": {
+            "ADetailer": {
+                "args": [
+                    True,
+                    {
+                        "ad_model": "face_yolov8n.pt",
+                        "ad_prompt": prompt,
+                        "ad_negative_prompt": "EasyNegativeV2,(badhandv4:1.2)",
+                    }
+                ]
+            },
+            "controlnet": {
+                "args": [
+                    {
+                        "input_image": input_image,
+                        "module": "openpose_full",
+                        "model": "control_v11p_sd15_openpose [cab727d4]",
+                    }
+                ]
+            }
+        }
+    }
+    response = requests.post(url=f'{img_url}/sdapi/v1/img2img', json=payload)
+
+    r = response.json()
+    output_path = output_directory + f"/{img_key}-output.jpg"
+    try:
+        result = r['images'][0]
+        image = Image.open(io.BytesIO(base64.b64decode(result.split(",", 1)[0])))
+        image.save(output_path)
+        return output_path
+    except:
+        logger.error('img2img error' + str(r))
+        return None

+ 37 - 0
api/segment.py

@@ -0,0 +1,37 @@
+import json
+import os
+import requests
+import io
+import base64
+from PIL import Image, PngImagePlugin
+from dotenv import load_dotenv
+from utils.logger import logger
+
+load_dotenv() # 加载环境变量
+segment_url = os.getenv('segment_url')
+
+def get_segment_mask(img_key, img_path, output_directory):
+    with open(img_path, 'rb') as f:
+        image_data = f.read()
+    image_base64 = base64.b64encode(image_data).decode('utf-8')
+
+    payload = {
+    "sam_model_name": "sam_vit_h_4b8939.pth",
+    "input_image": image_base64,
+    "sam_positive_points": [],
+    "sam_negative_points": [],
+    "dino_enabled": True,
+    "dino_model_name": "GroundingDINO_SwinB (938MB)",
+    "dino_text_prompt": "clothes",
+    "dino_box_threshold": 0.2
+    }
+    response = requests.post(url=f'{segment_url}/sam/sam-predict', json=payload)
+    output_path = output_directory + f"/{img_key}-mask.jpg"
+    r=response.json()
+    try:
+        image = Image.open(io.BytesIO(base64.b64decode(r['masks'][2].split(",", 1)[0])))
+        image.save(output_path)
+        return output_path
+    except:
+        logger.error('segment error' + str(r))
+        return None

BIN
asset/img_fail.png


BIN
asset/segment_fail.png


+ 204 - 0
handler/card_handler.py

@@ -0,0 +1,204 @@
+import json
+import threading
+from concurrent.futures import ThreadPoolExecutor
+from api.img2img import img_2_img
+from api.segment import get_segment_mask
+from utils.config_io import load_config_from_json, save_config_to_json
+from utils.image_io import get_img_file, get_img_key_from_input_images
+from utils.request_api import reply_message, update_message, upload_image
+from utils.logger import logger
+from utils.template_io import load_template, save_process_card, update_process_card
+
+
+ACTION_TO_TITLE = {
+    'clothes': '换装',
+    'furry': 'Furry头像'
+}
+
+def handle_card(card_data):
+    reply_message_id = card_data['open_message_id']
+    action = card_data['action']
+    if action['tag'] == 'select_static':
+        # 用户进行选择
+        config_data = load_config_from_json(reply_message_id)
+        config_data[action['value']['type']] = action['option']
+        save_config_to_json(config_data, reply_message_id)
+        if config_data['action'] != "" and config_data['lora'] != "":
+            response_data = choose_updated_card(reply_message_id, config_data)
+            return response_data
+    else:
+        # 用户点击开始
+        # TODO: 保存任务状态实现断点继续
+        if action['value']['key'] == 'start':
+            config_data = load_config_from_json(reply_message_id)
+            if config_data['action'] != "" and config_data['lora'] != "":
+                response_data = waiting_card(reply_message_id, config_data)
+                executor = ThreadPoolExecutor()
+                executor.submit(start_sd_process, reply_message_id, config_data)
+                return response_data
+        else:
+            logger.info('用户点击了上传中按钮')
+
+
+def start_sd_process(reply_message_id, config_data):
+    if config_data['action'] == 'clothes':
+        # 获取segment mask
+        start_clothes_img_theards(reply_message_id)
+    else:
+        pass
+
+# 开始换装的每个图片的线程
+def start_clothes_img_theards(reply_message_id):
+    img_keys = get_img_key_from_input_images(reply_message_id)
+
+    config_data = load_config_from_json(reply_message_id)
+    config_data["img_missions"] = {}
+    for key in img_keys:
+        config_data["img_missions"][key] = "doing"
+    save_config_to_json(config_data, reply_message_id)
+
+    threads = []
+    for img_key in img_keys:
+        logger.info("开始clothes处理线程: "+img_key)
+        # img_clothes(reply_message_id, img_key)
+        t = threading.Thread(target=img_clothes, args=(reply_message_id, img_key))
+        threads.append(t)
+        t.start()
+    for t in threads:
+        t.join()
+
+def mark_img_fail(reply_message_id, img_key):
+    config_data = load_config_from_json(reply_message_id)
+    config_data["img_missions"][img_key] = "fail"
+    save_config_to_json(config_data, reply_message_id)
+
+def mark_img_done(reply_message_id, img_key):
+    config_data = load_config_from_json(reply_message_id)
+    config_data["img_missions"][img_key] = "done"
+    save_config_to_json(config_data, reply_message_id)
+
+# 对单个图片做换装操作
+def img_clothes(reply_message_id, img_key):
+    img_path = f"data/input_images/{reply_message_id}/{img_key}/{img_key}.jpg"
+    output_directory = f"data/input_images/{reply_message_id}/{img_key}"
+    mask_path = get_segment_mask(img_key, img_path, output_directory)
+    # 获取mask失败则上传失败图片
+    if not mask_path:
+        mask_path = "asset/segment_fail.png"
+        mark_img_fail(reply_message_id, img_key)
+    # 上传mask
+    mask = get_img_file(mask_path)
+    upload_mask_res = upload_image(img_key, mask, "mask")
+    if upload_mask_res:
+        upload_mask_key = upload_mask_res.json()['data']['image_key']
+        # 更新card
+        updated_mask_card ={
+            "msg_type": "interactive",
+            "content": json.dumps(update_process_card(reply_message_id, img_key, upload_mask_key))
+        }
+        update_message(reply_message_id, updated_mask_card)
+        # 开始换装
+        img_output_path = img_2_img(get_prompt(reply_message_id), mask, img_key, img_path, output_directory)
+        # 获取img失败则上传失败图片
+        if not img_output_path:
+            img_output_path = "asset/img_fail.png"
+            mark_img_fail(reply_message_id, img_key)
+        # 上传图片
+        output_img = get_img_file(img_output_path)
+        upload_output_img_res = upload_image(img_key, output_img, "output_img")
+        if upload_output_img_res:
+            mark_img_done(reply_message_id, img_key)
+            upload_output_img_key = upload_output_img_res.json()['data']['image_key']
+            # 更新card
+            upload_output_img_card ={
+                "msg_type": "interactive",
+                "content": json.dumps(update_process_card(reply_message_id, img_key, upload_output_img_key))
+            }
+            update_message(reply_message_id, upload_output_img_card)
+        else:
+            mark_img_fail(reply_message_id, img_key)
+    else:
+        mark_img_fail(reply_message_id, img_key)
+
+
+# 配置prompt
+def get_prompt(reply_message_id):
+    config = load_config_from_json(reply_message_id)
+    return "1girl, masterpiece, best quality, " + config['lora']
+
+# 用户选择卡片
+def choose_updated_card(message_id, config_data):
+    card_content = load_template('start_card', message_id)
+    card_content['elements'][3]['text']['content'] = get_prompt(message_id)
+
+    image_columns = []
+    img_keys = get_img_key_from_input_images(message_id)
+    for img_key in img_keys:
+        image_columns.append(
+            {
+                "tag": "column",
+                "width": "weighted",
+                "weight": 1,
+                "vertical_align": "top",
+                "elements": 
+                [
+                    {
+                        "tag": "img",
+                        "img_key": img_key,
+                        "alt": {
+                            "tag": "plain_text",
+                            "content": ""
+                        },
+                        "mode": "fit_horizontal",
+                        "preview": True
+                    },
+                ]
+            }
+        )
+    card_content['elements'][1]['columns'] = image_columns
+    if not config_data['if_download']:
+        card_content['elements'][4]['actions'][2] = {
+            "tag": "button",
+            "text": {
+                "tag": "lark_md",
+                "content": "**上传中**"
+            },
+            "type": "default",
+            "value": {
+                "type" : "uploading",
+                "key": "uploading"
+            }
+        }
+    return card_content
+
+# 用户等待卡片
+def waiting_card(message_id, config_data):
+    card_content = load_template('waiting_card')
+    card_content['elements'][3]['text']['content'] = get_prompt(message_id)
+    image_columns = []
+    img_keys = get_img_key_from_input_images(message_id)
+    for img_key in img_keys:
+        image_columns.append(
+            {
+                "tag": "column",
+                "width": "weighted",
+                "weight": 1,
+                "vertical_align": "top",
+                "elements": 
+                [
+                    {
+                        "tag": "img",
+                        "img_key": img_key,
+                        "alt": {
+                            "tag": "plain_text",
+                            "content": "test"
+                        },
+                        "mode": "fit_horizontal",
+                        "preview": True
+                    },
+                ]
+            }
+        )
+    card_content['elements'][1]['columns'] = image_columns
+    save_process_card(card_content, message_id)
+    return card_content

+ 125 - 0
handler/meg_handler.py

@@ -0,0 +1,125 @@
+import json
+from handler.card_handler import choose_updated_card
+from utils.config_io import add_to_user_reply, if_has_processed, load_config_from_json, save_config_to_json
+from utils.image_io import create_image_path, save_image
+from utils.request_api import get_img, reply_message, update_message
+from utils.logger import logger
+from utils.template_io import load_template
+
+
+def handle_meg(meg, event_id, user_open_id):
+    if 'message_id' in meg and 'content' in meg:
+        # 正常消息,保存message_id可用于回复
+        message_id = meg['message_id']
+        if not if_has_processed(event_id):
+            content = json.loads(meg['content'])
+            if 'content' in content:
+                # 为富文本消息
+                rich_text_content = content['content']
+                if if_at_bot_rich_text(rich_text_content):
+                    # @了机器人
+                    if if_has_img(rich_text_content):
+                        # 有图片
+                        response_data = img_in_meg_card(rich_text_content, if_has_download = False)
+                        res = reply_message(message_id, response_data)
+                        if res:
+                            reply_msg_id = res.json()['data']['message_id']
+                            # 保存到全局config防止重复回复
+                            add_to_user_reply(event_id, message_id, reply_msg_id)
+                            # 保存user信息到msg_config
+                            config_data = load_config_from_json(reply_msg_id)
+                            config_data['user_message'] = message_id
+                            config_data['user_open_id'] = user_open_id
+                            save_config_to_json(config_data, reply_msg_id)
+                            # 下载图片
+                            download_img_in_content(message_id, reply_msg_id, rich_text_content)
+                            # 更新卡片
+                            config_data = load_config_from_json(reply_msg_id)
+                            config_data['if_download'] = True
+                            save_config_to_json(config_data, reply_msg_id)
+                            if config_data['action'] != "" and config_data['lora'] != "":
+                                update_card_data = choose_updated_card(reply_msg_id, config_data)
+                                update_data ={
+                                    "msg_type": "interactive",
+                                }
+                                update_data['content'] = json.dumps(update_card_data)
+                            else:
+                                update_data = img_in_meg_card(rich_text_content, if_has_download = True, reply_message_id = reply_msg_id)
+                            update_message(reply_msg_id, update_data)
+                else:
+                    # 未@机器人
+                    return
+        else:
+            # 处理过的消息
+            logger.info(message_id + '已处理过')
+            return
+    else:
+        # 非正常消息
+        return
+
+# 判断消息中是否@了机器人
+def if_at_bot_rich_text(content):
+    for section in content:
+        for item in section:
+            if item.get('tag') == 'at' and item.get('user_name') == 'Stable Diffusion Of HOXI':
+                return True
+    return False
+
+# 判断消息中是否有图片
+def if_has_img(content):
+    for section in content:
+        for item in section:
+            if item.get('tag') == 'img':
+                return True
+    return False
+
+# 下载用户信息中的图片
+def download_img_in_content(message_id, reply_message_id, content):
+    for section in content:
+        for item in section:
+            if item.get('tag') == 'img':
+                img_key = item.get('image_key')
+                create_image_path(reply_message_id, img_key)
+    for section in content:
+        for item in section:
+            if item.get('tag') == 'img':
+                img_key = item.get('image_key')
+                response = get_img(message_id, img_key)
+                if response.status_code == 200:
+                    save_image(response.content, reply_message_id, img_key)
+                    logger.info(img_key + "保存成功")
+
+# 生成交互卡片消息
+def img_in_meg_card(content, if_has_download = False, reply_message_id = None):
+    response_data ={
+        "msg_type": "interactive",
+    }
+    card_content = load_template('choose_card' if if_has_download else 'upload_card', reply_message_id)
+    image_columns = []
+    for section in content:
+        for item in section:
+            if item.get('tag') == 'img':
+                image_columns.append(
+                    {
+                        "tag": "column",
+                        "width": "weighted",
+                        "weight": 1,
+                        "vertical_align": "top",
+                        "elements": 
+                        [
+                            {
+                                "tag": "img",
+                                "img_key": item.get('image_key'),
+                                "alt": {
+                                    "tag": "plain_text",
+                                    "content": ""
+                                },
+                                "mode": "fit_horizontal",
+                                "preview": True
+                            },
+                        ]
+                    }
+                )
+    card_content['elements'][1]['columns'] = image_columns
+    response_data['content'] = json.dumps(card_content)
+    return response_data

+ 5 - 0
lora_options.json

@@ -0,0 +1,5 @@
+{
+    "刘亦菲": "<lora:liuyifei_reg-000018:0.8>",
+    "INS-1": "<lora:ins_test_2:1>",
+    "CHN_002": " <lora:CNH_002:0.8>"
+}

+ 65 - 0
main.py

@@ -0,0 +1,65 @@
+import json
+import os
+from flask import Flask, request, jsonify
+from handler.card_handler import handle_card
+from handler.meg_handler import handle_meg
+from utils.logger import logger
+from utils.decrypt import AESCipher
+from dotenv import load_dotenv
+from functools import wraps
+
+
+load_dotenv()  # 加载环境变量
+
+app = Flask(__name__)
+cipher = AESCipher(os.getenv('BOT_AES_KEY')) 
+
+def decrypt_request(f):
+    @wraps(f)
+    def decorated_function(*args, **kwargs):
+        if request.json and 'encrypt' in request.json:
+            encrypt = request.json['encrypt']
+            decrypted_request = json.loads(cipher.decrypt_string(encrypt))
+            return f(decrypted_request, *args, **kwargs)
+        else:
+            logger.error("/enc_req无法解密: " + request.json)
+    return decorated_function
+
+@app.route('/enc_req', methods=['POST'])
+@decrypt_request
+def encrypt_challenge(decrypted_request):
+    if 'challenge' in decrypted_request:
+        response = {
+            "challenge": decrypted_request['challenge']
+        }
+        return jsonify(response), 200
+    elif 'event_type' in decrypted_request['header']:
+        if decrypted_request['header']['event_type'] == 'im.message.receive_v1':
+            meg = decrypted_request['event']['message']
+            event_id = decrypted_request['header']['event_id']
+            user_open_id = decrypted_request['event']['sender']['sender_id']['open_id']
+            handle_meg(meg, event_id, user_open_id)
+        elif decrypted_request['header']['event_type'] == 'application.bot.menu_v6':
+            logger.info(decrypted_request)
+        return "", 200
+    else:
+        logger.warn("未知事件: " + decrypted_request)
+        return "", 200
+
+@app.route('/req', methods=['POST'])
+def challenge():
+    if 'challenge' in request.json:
+        response = {
+            "challenge": request.json['challenge']
+        }
+        return jsonify(response), 200
+    elif 'token' in request.json and 'action' in request.json:
+        # 卡片事件
+        card_response = handle_card(request.json)
+        return jsonify(card_response), 200
+    else:
+        logger.info(request.json)
+        return "", 200
+
+if __name__ == '__main__':
+    app.run(host='0.0.0.0', port=7863, debug=True)

+ 14 - 0
model/global_model.py

@@ -0,0 +1,14 @@
+from utils.singleton import Singleton
+
+@Singleton
+class GlobalModel:
+    def __init__(self):
+        self._tenant_token = ""
+
+    @property
+    def tenant_token(self):
+        return self._tenant_token
+    
+    @tenant_token.setter
+    def tenant_token(self, value):
+        self._tenant_token = value

+ 4 - 0
requirement.txt

@@ -0,0 +1,4 @@
+pycryptodome
+flask
+python-dotenv
+requests_toolbelt

+ 97 - 0
template/choose_card.json

@@ -0,0 +1,97 @@
+{
+    "header": {
+        "title": {
+            "tag": "plain_text",
+            "content": "上传图片成功,请选择你的操作后点击开始"
+        },
+        "template": "blue"
+    },
+    "elements": [
+        {
+            "tag": "div",
+            "text": {
+                "content": "ヾ(≧▽≦*)o",
+                "tag": "plain_text"
+            }
+        },
+        {
+            "tag": "column_set",
+            "flex_mode": "none",
+            "background_style": "grey",
+            "columns": []
+        },
+        {
+            "tag": "hr"
+        },
+        {
+            "tag": "action",
+            "actions": [
+                {
+                    "tag": "select_static",
+                    "placeholder": {
+                        "tag": "plain_text",
+                        "content": "选择操作"
+                    },
+                    "initial_option": "",
+                    "value": {
+                        "type": "action"
+                    },
+                    "options": [
+                        {
+                            "text": {
+                                "tag": "lark_md",
+                                "content": "试衣间"
+                            },
+                            "value": "clothes"
+                        },
+                        {
+                            "text": {
+                                "tag": "lark_md",
+                                "content": "Furry头像"
+                            },
+                            "value": "furry"
+                        }
+                        
+                    ]
+                },
+                {
+                    "tag": "select_static",
+                    "placeholder": {
+                        "tag": "plain_text",
+                        "content": "选择Lora人物"
+                    },
+                    "initial_option": "",
+                    "value": {
+                        "type": "lora"
+                    },
+                    "options": []
+                },
+                {
+                    "tag": "button",
+                    "text": {
+                        "tag": "lark_md",
+                        "content": "**开始**"
+                    },
+                    "type": "primary",
+                    "value": {
+                        "type" : "action",
+                        "key": "start"
+                    }
+                }
+            ]
+        },
+        {
+            "tag": "note",
+            "elements": [
+                {
+                    "tag": "plain_text",
+                    "content": "图片需要有模特( •̀ ω •́ )✧"
+                }
+            ]
+        }
+    ],
+    "config": {
+        "enable_forward": true,
+        "update_multi":true
+    }
+}

+ 92 - 0
template/start_card.json

@@ -0,0 +1,92 @@
+{
+    "header": {
+        "title": {
+            "tag": "plain_text",
+            "content": "操作选择成功,请确认Prompt后点击开始"
+        },
+        "template": "blue"
+    },
+    "elements": [
+        {
+            "tag": "div",
+            "text": {
+                "content": "ヾ(≧▽≦*)o",
+                "tag": "plain_text"
+            }
+        },
+        {
+            "tag": "column_set",
+            "flex_mode": "none",
+            "background_style": "grey",
+            "columns": []
+        },
+        {
+            "tag": "hr"
+        },
+        {
+            "tag": "div",
+            "text": {
+                "content": "",
+                "tag": "plain_text"
+            }
+        },
+        {
+            "tag": "action",
+            "actions": [
+                {
+                    "tag": "select_static",
+                    "initial_option": "",
+                    "value": {
+                        "type": "action"
+                    },
+                    "options": [
+                        {
+                            "text": {
+                                "tag": "lark_md",
+                                "content": "试衣间"
+                            },
+                            "value": "clothes"
+                        },
+                        {
+                            "text": {
+                                "tag": "lark_md",
+                                "content": "Furry头像"
+                            },
+                            "value": "furry"
+                        }
+                        
+                    ]
+                },
+                {
+                    "tag": "select_static",
+                    "initial_option": "",
+                    "value": {
+                        "type": "lora"
+                    },
+                    "options": []
+                },
+                {
+                    "tag": "button",
+                    "text": {
+                        "tag": "lark_md",
+                        "content": "**开始**"
+                    },
+                    "type": "primary",
+                    "value": {
+                        "type" : "action",
+                        "key": "start"
+                    }
+                }
+            ]
+        },
+        {
+            "tag": "note",
+            "elements": [
+                {
+                    "tag": "plain_text",
+                    "content": "图片需要有模特( •̀ ω •́ )✧"
+                }
+            ]
+        }
+    ]
+}

+ 95 - 0
template/upload_card.json

@@ -0,0 +1,95 @@
+{
+    "header": {
+        "title": {
+            "tag": "plain_text",
+            "content": "上传图片中,请先选择你的操作"
+        },
+        "template": "blue"
+    },
+    "elements": [
+        {
+            "tag": "div",
+            "text": {
+                "content": "------------上传中(o゚v゚)ノ",
+                "tag": "plain_text"
+            }
+        },
+        {
+            "tag": "column_set",
+            "flex_mode": "none",
+            "background_style": "grey",
+            "columns": []
+        },
+        {
+            "tag": "hr"
+        },
+        {
+            "tag": "action",
+            "actions": [
+                {
+                    "tag": "select_static",
+                    "placeholder": {
+                        "tag": "plain_text",
+                        "content": "选择操作"
+                    },
+                    "value": {
+                        "type": "action"
+                    },
+                    "options": [
+                        {
+                            "text": {
+                                "tag": "lark_md",
+                                "content": "试衣间"
+                            },
+                            "value": "clothes"
+                        },
+                        {
+                            "text": {
+                                "tag": "lark_md",
+                                "content": "Furry头像"
+                            },
+                            "value": "furry"
+                        }
+                        
+                    ]
+                },
+                {
+                    "tag": "select_static",
+                    "placeholder": {
+                        "tag": "plain_text",
+                        "content": "选择Lora人物"
+                    },
+                    "value": {
+                        "type": "lora"
+                    },
+                    "options": []
+                },
+                {
+                    "tag": "button",
+                    "text": {
+                        "tag": "lark_md",
+                        "content": "**上传中**"
+                    },
+                    "type": "default",
+                    "value": {
+                        "type" : "uploading",
+                        "key": "uploading"
+                    }
+                }
+            ]
+        },
+        {
+            "tag": "note",
+            "elements": [
+                {
+                    "tag": "plain_text",
+                    "content": "图片需要有模特( •̀ ω •́ )✧"
+                }
+            ]
+        }
+    ],
+    "config": {
+        "enable_forward": true,
+        "update_multi":true
+    }
+}

+ 50 - 0
template/waiting_card.json

@@ -0,0 +1,50 @@
+{
+    "header": {
+        "title": {
+            "tag": "plain_text",
+            "content": "任务已加入到队列,请等待..."
+        },
+        "template": "blue"
+    },
+    "elements": [
+        {
+            "tag": "div",
+            "text": {
+                "content": "---------------(o゜▽゜)o☆",
+                "tag": "plain_text"
+            }
+        },
+        {
+            "tag": "column_set",
+            "flex_mode": "none",
+            "background_style": "grey",
+            "columns": []
+        },
+        {
+            "tag": "hr"
+        },
+        {
+            "tag": "div",
+            "text": {
+                "content": "",
+                "tag": "plain_text"
+            }
+        },
+        {
+            "tag": "div",
+            "text": {
+                "content": "出图中,请等待...",
+                "tag": "plain_text"
+            }
+        },
+        {
+            "tag": "note",
+            "elements": [
+                {
+                    "tag": "plain_text",
+                    "content": "一般一张图5秒左右( •̀ ω •́ )✧"
+                }
+            ]
+        }
+    ]
+}

+ 96 - 0
utils/config_io.py

@@ -0,0 +1,96 @@
+import json
+import os
+from utils.logger import logger
+
+
+def save_config_to_json(config_data, message_id):
+    # 确保目录存在,如果不存在则创建
+    output_directory = f"data/input_images/{message_id}"
+    os.makedirs(output_directory, exist_ok=True)
+
+    # 生成JSON文件路径
+    json_path = os.path.join(output_directory, "config.json")
+
+    # 将字典转换为JSON并保存到文件中
+    with open(json_path, 'w') as f:
+        json.dump(config_data, f)
+
+def load_config_from_json(message_id):
+    # 生成JSON文件路径
+    json_path = os.path.join("data/input_images", message_id, "config.json")
+
+    # 检查文件是否存在,如果不存在则返回默认字典
+    if not os.path.exists(json_path):
+        return {"user_message": "", "action": "", "lora": "", "if_download": False, "img_missions": {}}
+
+    # 从文件中读取JSON数据并转换为字典
+    with open(json_path, 'r') as f:
+        config_data = json.load(f)
+
+    return config_data
+
+def add_to_user_reply(event_id, user_message_id, reply_message_id):
+    output_directory = "data"
+    os.makedirs(output_directory, exist_ok=True)
+    json_path = os.path.join(output_directory, "user_reply.json")
+
+    # 检查文件是否存在,如果不存在则创建
+    if os.path.exists(json_path):
+        # 从文件中读取JSON数据并转换为字典
+        with open(json_path, 'r') as f:
+            config_data = json.load(f)
+        if event_id in config_data:
+            return
+        else:
+            config_data[event_id] = [user_message_id, reply_message_id]
+            with open(json_path, 'w') as f:
+                json.dump(config_data, f)
+            return
+    else:
+        config_data={
+            event_id : [user_message_id, reply_message_id]
+        }
+        with open(json_path, 'w') as f:
+            json.dump(config_data, f)
+        return
+
+def if_has_processed(event_id):
+    output_directory = "data"
+    os.makedirs(output_directory, exist_ok=True)
+    json_path = os.path.join(output_directory, "user_reply.json")
+
+    # 检查文件是否存在,如果不存在则创建
+    if os.path.exists(json_path):
+        # 从文件中读取JSON数据并转换为字典
+        with open(json_path, 'r') as f:
+            config_data = json.load(f)
+        if event_id in config_data:
+            return True
+        else:
+            return False
+    else:
+        return False
+    
+# 从文件中读取lora选项
+def load_lora_options():
+    json_path = "lora_options.json"
+    with open(json_path, 'r', encoding='utf-8') as f:
+        lora_options = json.load(f)
+    return lora_options
+
+
+# 检测任务是否完成
+def if_mission_done(reply_message_id):
+    config_data = load_config_from_json(reply_message_id)
+    for key, value in config_data['img_missions'].items():
+        if value != 'done':
+            return False
+    return True
+
+# 检测任务是否存在失败
+def if_mission_fail(reply_message_id):
+    config_data = load_config_from_json(reply_message_id)
+    for key, value in config_data['img_missions'].items():
+        if value == 'fail':
+            return True
+    return False

+ 23 - 0
utils/decrypt.py

@@ -0,0 +1,23 @@
+import hashlib
+import base64
+from Crypto.Cipher import AES
+class  AESCipher(object):
+    def __init__(self, key):
+        self.bs = AES.block_size
+        self.key=hashlib.sha256(AESCipher.str_to_bytes(key)).digest()
+    @staticmethod
+    def str_to_bytes(data):
+        u_type = type(b"".decode('utf8'))
+        if isinstance(data, u_type):
+            return data.encode('utf8')
+        return data
+    @staticmethod
+    def _unpad(s):
+        return s[:-ord(s[len(s) - 1:])]
+    def decrypt(self, enc):
+        iv = enc[:AES.block_size]
+        cipher = AES.new(self.key, AES.MODE_CBC, iv)
+        return  self._unpad(cipher.decrypt(enc[AES.block_size:]))
+    def decrypt_string(self, enc):
+        enc = base64.b64decode(enc)
+        return  self.decrypt(enc).decode('utf8')

+ 33 - 0
utils/image_io.py

@@ -0,0 +1,33 @@
+import json
+import os
+from utils.logger import logger
+
+def save_image(img, message_id, img_key):
+    # 确保目录存在,如果不存在则创建
+    output_directory = f"data/input_images/{message_id}/{img_key}"
+    os.makedirs(output_directory, exist_ok=True)
+
+    # 生成图片文件路径
+    image_path = os.path.join(output_directory, f"{img_key}.jpg")
+
+    # 保存图片
+    with open(image_path, 'wb') as f:
+        f.write(img)
+
+def get_img_key_from_input_images(message_id):
+    input_directory = f"data/input_images/{message_id}"
+    img_key = []
+    for file in os.listdir(input_directory):
+        if not file.endswith(".json"):
+            img_key.append(file)
+    return img_key
+
+def create_image_path(message_id, img_key):
+    # 确保目录存在,如果不存在则创建
+    output_directory = f"data/input_images/{message_id}/{img_key}"
+    os.makedirs(output_directory, exist_ok=True)
+
+def get_img_file(img_path):
+    with open(img_path, 'rb') as f:
+        img = f.read()
+    return img

+ 11 - 0
utils/logger.py

@@ -0,0 +1,11 @@
+import logging
+
+
+logger = logging.getLogger("bot")
+logger.setLevel(logging.INFO)
+fh = logging.FileHandler('stable_diffusion_bot.log')
+fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+logger.addHandler(fh)
+ch = logging.StreamHandler()
+ch.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) 
+logger.addHandler(ch)

+ 109 - 0
utils/request_api.py

@@ -0,0 +1,109 @@
+import json
+import requests
+import os
+from functools import wraps
+from dotenv import load_dotenv
+from requests_toolbelt import MultipartEncoder
+from utils.logger import logger
+from model.global_model import GlobalModel
+
+load_dotenv() # 加载环境变量
+APP_ID = os.getenv('APP_ID')
+APP_SECRET = os.getenv('APP_SECRET') 
+
+def get_tenant_token():
+    url = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal'
+    data = {
+        'app_id': APP_ID,
+        'app_secret': APP_SECRET
+    }
+    r = requests.post(url, json=data)
+    GlobalModel().tenant_token = r.json().get('tenant_access_token')
+
+def keep_tenant_access_token(f):
+    @wraps(f)
+    def decorated(*args, **kwargs):
+        if GlobalModel().tenant_token == "":
+            # tenant_token未初始化
+            get_tenant_token()
+            r = f(*args, **kwargs)
+        else:
+            # tenant_token已初始化
+            try:
+                r = f(*args, **kwargs)
+                if r:
+                    if r.json().get('code') == 99991661 or r.json().get('code') == 99991663:
+                        # tenant_token已过期
+                        get_tenant_token()
+                        r = f(*args, **kwargs)
+                    else:
+                        # 正常返回
+                        return r
+                else:
+                    get_tenant_token()
+                    r = f(*args, **kwargs)
+            except:
+                get_tenant_token()
+                r = f(*args, **kwargs)
+                return r
+        return r
+    return decorated
+
+@keep_tenant_access_token
+def reply_message(message_id, response_json):
+    headers = {
+        'Authorization': 'Bearer ' + GlobalModel().tenant_token,
+        'Content-Type': 'application/json; charset=utf-8'
+    }
+    url = f"https://open.feishu.cn/open-apis/im/v1/messages/{message_id}/reply"
+    res = requests.post(url, json=response_json, headers=headers)
+    if res.status_code == 200 and 'code' in res.json() and res.json()['code'] == 0:
+        logger.info(message_id + '回复成功')
+        return res
+    else:
+        logger.error(message_id + '回复失败\n' + str(res.json()))
+        return None
+
+@keep_tenant_access_token
+def get_img(message_id, img_key):
+    headers = {
+        'Authorization': 'Bearer ' + GlobalModel().tenant_token
+    }
+    url = f"https://open.feishu.cn/open-apis/im/v1/messages/{message_id}/resources/{img_key}?type=image"
+    response = requests.get(url, headers=headers)
+    return response
+
+@keep_tenant_access_token
+def update_message(message_id, update_json):
+    headers = {
+        'Authorization': 'Bearer ' + GlobalModel().tenant_token,
+        'Content-Type': 'application/json; charset=utf-8'
+    }
+    url = f"https://open.feishu.cn/open-apis/im/v1/messages/{message_id}"
+    res = requests.patch(url, json=update_json, headers=headers)
+    if res.status_code == 200 and 'code' in res.json() and res.json()['code'] == 0:
+        logger.info(message_id + '更新成功')
+        return res
+    else:
+        logger.error(message_id + '更新失败\n' + str(res.json()))
+        return None
+    
+@keep_tenant_access_token
+def upload_image(img_key, image_data, note):
+    url = "https://open.feishu.cn/open-apis/im/v1/images"
+    data = {
+        "image_type": "message",
+        'image': image_data
+    }
+    multi_form = MultipartEncoder(data)
+    headers = {
+        'Authorization': 'Bearer ' + GlobalModel().tenant_token,
+        'Content-Type': multi_form.content_type
+    }
+    res = requests.post(url, data=multi_form, headers=headers)
+    if res.status_code == 200 and 'code' in res.json() and res.json()['code'] == 0:
+        logger.info(img_key + "-" + note + '上传成功')
+        return res
+    else:
+        logger.error(img_key + "-"  + note + '上传失败\n' + str(res.json()))
+        return None

+ 10 - 0
utils/singleton.py

@@ -0,0 +1,10 @@
+class Singleton:
+    _instances = {}
+
+    def __init__(self, cls):
+        self._cls = cls
+
+    def __call__(self, *args, **kwargs):
+        if self._cls not in self._instances:
+            self._instances[self._cls] = self._cls(*args, **kwargs)
+        return self._instances[self._cls]

+ 115 - 0
utils/template_io.py

@@ -0,0 +1,115 @@
+import json
+import os
+from utils.config_io import if_mission_done, if_mission_fail, load_config_from_json, load_lora_options
+from utils.logger import logger
+
+
+# 从文件中加载模板
+def load_template(template_name, reply_message_id = None):
+    template_path = f"template/{template_name}.json"
+    with open(template_path, 'r', encoding='utf-8') as f:
+        template = json.load(f)
+    if template_name in ['choose_card', 'upload_card', 'start_card']:
+        if template_name == 'choose_card' or template_name == 'upload_card':
+            index = 3
+        elif template_name == 'start_card':
+            index = 4
+        lora_options= load_lora_options()
+        if reply_message_id:
+            config_data = load_config_from_json(reply_message_id)
+            template['elements'][index]['actions'][0]['initial_option'] = config_data['action']
+            template['elements'][index]['actions'][1]['initial_option'] = config_data['lora']
+        for option, value in lora_options.items():
+            template['elements'][index]['actions'][1]['options'].append(
+                {
+                    "text": {
+                        "tag": "lark_md",
+                        "content": option
+                    },
+                    "value": value
+                }
+            )
+    return template
+
+# 加载用户已选选项
+def load_initial_option_from_config(reply_message_id, card_content):
+    config_data = load_config_from_json(reply_message_id)
+    card_content['elements'][4]['actions'][0]['initial_option'] = config_data['action']
+    card_content['elements'][4]['actions'][1]['initial_option'] = config_data['lora']
+    return card_content
+
+# 保存waiting card
+def save_process_card(card_content, reply_message_id):
+    output_directory = f"data/input_images/{reply_message_id}"
+    os.makedirs(output_directory, exist_ok=True)
+    json_path = os.path.join(output_directory, "process_card.json")
+    with open(json_path, 'w', encoding='utf-8') as f:
+        json.dump(card_content, f, ensure_ascii=False, indent=4)
+
+# 更新waiting card
+def update_process_card(reply_message_id, img_key, upload_key):
+    output_directory = f"data/input_images/{reply_message_id}"
+    os.makedirs(output_directory, exist_ok=True)
+    json_path = os.path.join(output_directory, "process_card.json")
+    with open(json_path, 'r', encoding='utf-8') as f:
+        card_content = json.load(f)
+    # 更新图片
+    for column in card_content['elements'][1]['columns']:
+        for i, element in enumerate(column["elements"]):
+            if element.get("img_key") == img_key:
+                column["elements"].append(
+                    {
+                        "tag": "img",
+                        "img_key": upload_key,
+                        "alt": {
+                            "tag": "plain_text",
+                            "content": ""
+                        },
+                        "mode": "fit_horizontal",
+                        "preview": True
+                    }
+                )
+                break
+    config_data = load_config_from_json(reply_message_id)
+    user_open_id = config_data['user_open_id']
+    # 更新任务状态
+    if if_mission_fail(reply_message_id):
+        logger.error(reply_message_id + "任务失败")
+        card_content['header'] = {
+            "title": {
+                "tag": "plain_text",
+                "content": "任务执行出现错误,请检查"
+            },
+            "template": "red"
+        }
+        card_content['elements'][0] = {
+            "tag": "div",
+            "text": {
+                "content": f"<at id=\"{user_open_id}\"></at>---------------/(ㄒoㄒ)/~~",
+                "tag": "lark_md",
+            }
+        }
+        del card_content['elements'][4]
+        del card_content['elements'][4]
+    else:
+        if if_mission_done(reply_message_id):
+            logger.info(reply_message_id + "任务成功")
+            card_content['header'] = {
+                "title": {
+                    "tag": "plain_text",
+                    "content": "任务完成"
+                },
+                "template": "green"
+            }
+            card_content['elements'][0] = {
+                "tag": "div",
+                "text": {
+                    "content": f"<at id=\"{user_open_id}\"></at>---------------o(* ̄▽ ̄*)ブ",
+                    "tag": "lark_md"
+                }
+            }
+            del card_content['elements'][4]
+            del card_content['elements'][4]
+    with open(json_path, 'w', encoding='utf-8') as f:
+        json.dump(card_content, f, ensure_ascii=False, indent=4)
+    return card_content