prompter.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. """
  2. A dedicated helper to manage templates and prompt building.
  3. """
  4. import json
  5. import os.path as osp
  6. from typing import Union
  7. class Prompter(object):
  8. __slots__ = ("template", "_verbose")
  9. def __init__(self, template_name: str = "", verbose: bool = False):
  10. self._verbose = verbose
  11. if not template_name:
  12. # Enforce the default here, so the constructor can be called with '' and will not break.
  13. template_name = "alpaca"
  14. file_name = osp.join("templates", f"{template_name}.json")
  15. if not osp.exists(file_name):
  16. raise ValueError(f"Can't read {file_name}")
  17. with open(file_name) as fp:
  18. self.template = json.load(fp)
  19. if self._verbose:
  20. print(
  21. f"Using prompt template {template_name}: {self.template['description']}"
  22. )
  23. def generate_prompt(
  24. self,
  25. instruction: str,
  26. input: Union[None, str] = None,
  27. label: Union[None, str] = None,
  28. ) -> str:
  29. # returns the full prompt from instruction and optional input
  30. # if a label (=response, =output) is provided, it's also appended.
  31. if input:
  32. res = self.template["prompt_input"].format(
  33. instruction=instruction, input=input
  34. )
  35. else:
  36. res = self.template["prompt_no_input"].format(
  37. instruction=instruction
  38. )
  39. if label:
  40. res = f"{res}{label}"
  41. if self._verbose:
  42. print(res)
  43. return res
  44. def get_response(self, output: str) -> str:
  45. return output.split(self.template["response_split"])[1].strip()