segment.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. import json
  2. import os
  3. import requests
  4. import io
  5. import base64
  6. from PIL import Image, PngImagePlugin
  7. from dotenv import load_dotenv
  8. from utils.logger import logger
  9. load_dotenv() # 加载环境变量
  10. segment_url = os.getenv('segment_url')
  11. def get_segment_mask(img_key, img_path, output_directory):
  12. with open(img_path, 'rb') as f:
  13. image_data = f.read()
  14. image_base64 = base64.b64encode(image_data).decode('utf-8')
  15. payload = {
  16. "sam_model_name": "sam_vit_h_4b8939.pth",
  17. "input_image": image_base64,
  18. "sam_positive_points": [],
  19. "sam_negative_points": [],
  20. "dino_enabled": True,
  21. "dino_model_name": "GroundingDINO_SwinB (938MB)",
  22. "dino_text_prompt": "clothes",
  23. "dino_box_threshold": 0.2
  24. }
  25. response = requests.post(url=f'{segment_url}/sam/sam-predict', json=payload)
  26. output_path = output_directory + f"/{img_key}-mask.jpg"
  27. r=response.json()
  28. try:
  29. image = Image.open(io.BytesIO(base64.b64decode(r['masks'][2].split(",", 1)[0])))
  30. image.save(output_path)
  31. return output_path
  32. except:
  33. logger.error('segment error' + str(r))
  34. return None