12345678910111213141516171819202122232425262728293031323334353637 |
- import json
- import os
- import requests
- import io
- import base64
- from PIL import Image, PngImagePlugin
- from dotenv import load_dotenv
- from utils.logger import logger
# Populate os.environ from a local .env file before any configuration is read.
load_dotenv()
# Base URL of the segmentation service; None when the variable is not set.
segment_url = os.environ.get('segment_url')
def get_segment_mask(img_key, img_path, output_directory, timeout=None):
    """Segment clothing in an image via the remote SAM service and save the mask.

    The image at ``img_path`` is base64-encoded and POSTed to
    ``{segment_url}/sam/sam-predict`` with GroundingDINO enabled and the text
    prompt ``"clothes"``. The mask at index 2 of the response's ``masks`` list
    is decoded and written to ``<output_directory>/<img_key>-mask.jpg``.

    Args:
        img_key: Identifier used to build the output file name.
        img_path: Path of the input image file to segment.
        output_directory: Directory the mask image is written into (it is not
            created here).
        timeout: Optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple). ``None`` (the default) waits
            indefinitely, matching the previous behaviour.

    Returns:
        The path of the saved mask image, or ``None`` if the service response
        could not be parsed/decoded or the mask could not be saved. Each
        failure is logged.

    Raises:
        OSError: If ``img_path`` cannot be read.
        requests.RequestException: On request-level failures raised by
            ``requests.post`` (e.g. connection errors, or an invalid URL when
            ``segment_url`` is unset).
    """
    with open(img_path, 'rb') as f:
        image_base64 = base64.b64encode(f.read()).decode('utf-8')

    payload = {
        "sam_model_name": "sam_vit_h_4b8939.pth",
        "input_image": image_base64,
        "sam_positive_points": [],
        "sam_negative_points": [],
        "dino_enabled": True,
        "dino_model_name": "GroundingDINO_SwinB (938MB)",
        "dino_text_prompt": "clothes",
        "dino_box_threshold": 0.2,
    }
    response = requests.post(
        url=f'{segment_url}/sam/sam-predict', json=payload, timeout=timeout
    )
    output_path = f"{output_directory}/{img_key}-mask.jpg"

    try:
        r = response.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page). requests' JSONDecodeError
        # subclasses ValueError in every requests version.
        logger.error('segment error: non-JSON response, HTTP ' + str(response.status_code))
        return None

    try:
        # NOTE(review): index 2 selects the third mask returned by the
        # service; confirm this is the intended candidate.
        mask_b64 = r['masks'][2]
        # Keep only the payload after an optional "data:...;base64," header;
        # a plain base64 string (no comma) is used unchanged.
        mask_bytes = base64.b64decode(mask_b64.split(",", 1)[-1])
        with Image.open(io.BytesIO(mask_bytes)) as image:
            image.save(output_path)
    except (KeyError, IndexError, TypeError, AttributeError, ValueError,
            OSError, Image.DecompressionBombError):
        # Malformed response (missing/short 'masks', bad base64, undecodable
        # image) or a failed save: log the raw response and signal failure.
        logger.error('segment error' + str(r))
        return None
    return output_path
|