diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index 69e25258f..34fd2fc2e 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -23,10 +23,7 @@ def __init__(self, settings: Settings) -> None:
case "local":
from llama_index.llms import LlamaCPP
- prompt_style_cls = get_prompt_style(settings.local.prompt_style)
- prompt_style = prompt_style_cls(
- default_system_prompt=settings.local.default_system_prompt
- )
+ prompt_style = get_prompt_style(settings.local.prompt_style)
self.llm = LlamaCPP(
model_path=str(models_path / settings.local.llm_hf_model_file),
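(Below this hunk, presumably unchanged, the prompt-style callbacks are handed to LlamaCPP. A minimal sketch of the assumed wiring, now that get_prompt_style returns a ready-to-use instance instead of a class to construct; the kwarg names follow the llama_index LlamaCPP API, where messages_to_prompt/completion_to_prompt are plain callables:)

```python
# Hypothetical continuation of the call above, not part of this patch.
# Bound methods of the prompt-style instance slot in directly as callables.
prompt_style = get_prompt_style(settings.local.prompt_style)
self.llm = LlamaCPP(
    model_path=str(models_path / settings.local.llm_hf_model_file),
    messages_to_prompt=prompt_style.messages_to_prompt,
    completion_to_prompt=prompt_style.completion_to_prompt,
)
```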
diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py
index e47b3fb9c..a8ca60f27 100644
--- a/private_gpt/components/llm/prompt_helper.py
+++ b/private_gpt/components/llm/prompt_helper.py
@@ -5,7 +5,6 @@
from llama_index.llms import ChatMessage, MessageRole
from llama_index.llms.llama_utils import (
- DEFAULT_SYSTEM_PROMPT,
completion_to_prompt,
messages_to_prompt,
)
@@ -29,7 +28,6 @@ class AbstractPromptStyle(abc.ABC):
series of messages into a prompt.
"""
- @abc.abstractmethod
def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.debug("Initializing prompt_style=%s", self.__class__.__name__)
@@ -52,15 +50,6 @@ def completion_to_prompt(self, completion: str) -> str:
return prompt
-class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
- _DEFAULT_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT
-
- def __init__(self, default_system_prompt: str | None) -> None:
- super().__init__()
- logger.debug("Got default_system_prompt='%s'", default_system_prompt)
- self.default_system_prompt = default_system_prompt
-
-
class DefaultPromptStyle(AbstractPromptStyle):
"""Default prompt style that uses the defaults from llama_utils.
@@ -83,7 +72,7 @@ def _completion_to_prompt(self, completion: str) -> str:
return ""
-class Llama2PromptStyle(AbstractPromptStyleWithSystemPrompt):
+class Llama2PromptStyle(AbstractPromptStyle):
"""Simple prompt style that just uses the default llama_utils functions.
It transforms the sequence of messages into a prompt that should look like:
@@ -94,18 +83,14 @@ class Llama2PromptStyle(AbstractPromptStyleWithSystemPrompt):
```
"""
- def __init__(self, default_system_prompt: str | None = None) -> None:
- # If no system prompt is given, the default one of the implementation is used.
- super().__init__(default_system_prompt=default_system_prompt)
-
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
- return messages_to_prompt(messages, self.default_system_prompt)
+ return messages_to_prompt(messages)
def _completion_to_prompt(self, completion: str) -> str:
- return completion_to_prompt(completion, self.default_system_prompt)
+ return completion_to_prompt(completion)
-class TagPromptStyle(AbstractPromptStyleWithSystemPrompt):
+class TagPromptStyle(AbstractPromptStyle):
"""Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.
It transforms the sequence of messages into a prompt that should look like:
@@ -119,37 +104,8 @@ class TagPromptStyle(AbstractPromptStyleWithSystemPrompt):
FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
"""
- def __init__(self, default_system_prompt: str | None = None) -> None:
- # We have to define a default system prompt here as the LLM will not
- # use the default llama_utils functions.
- default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
- super().__init__(default_system_prompt)
- self.system_prompt: str = default_system_prompt
-
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
- messages = list(messages)
- if messages[0].role != MessageRole.SYSTEM:
- logger.info(
- "Adding system_promt='%s' to the given messages as there are none given in the session",
- self.system_prompt,
- )
- messages = [
- ChatMessage(content=self.system_prompt, role=MessageRole.SYSTEM),
- *messages,
- ]
- return self._format_messages_to_prompt(messages)
-
- def _completion_to_prompt(self, completion: str) -> str:
- return (
- f"<|system|>: {self.system_prompt.strip()}\n"
- f"<|user|>: {completion.strip()}\n"
- "<|assistant|>: "
- )
-
- @staticmethod
- def _format_messages_to_prompt(messages: list[ChatMessage]) -> str:
"""Format message to prompt with `<|ROLE|>: MSG` style."""
- assert messages[0].role == MessageRole.SYSTEM
prompt = ""
for message in messages:
role = message.role
@@ -161,19 +117,24 @@ def _format_messages_to_prompt(messages: list[ChatMessage]) -> str:
prompt += "<|assistant|>: "
return prompt
+ def _completion_to_prompt(self, completion: str) -> str:
+ return self._messages_to_prompt(
+ [ChatMessage(content=completion, role=MessageRole.USER)]
+ )
+
def get_prompt_style(
prompt_style: Literal["default", "llama2", "tag"] | None
-) -> type[AbstractPromptStyle]:
+) -> AbstractPromptStyle:
"""Get the prompt style to use from the given string.
:param prompt_style: The prompt style to use.
:return: The prompt style to use.
"""
if prompt_style is None or prompt_style == "default":
- return DefaultPromptStyle
+ return DefaultPromptStyle()
elif prompt_style == "llama2":
- return Llama2PromptStyle
+ return Llama2PromptStyle()
elif prompt_style == "tag":
- return TagPromptStyle
+ return TagPromptStyle()
raise ValueError(f"Unknown prompt_style='{prompt_style}'")
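(With the factory returning instances, callers no longer need to know about constructor arguments. A minimal usage sketch of the new contract, message contents illustrative; the expected output follows the `<|ROLE|>: MSG` loop kept above:)

```python
from llama_index.llms import ChatMessage, MessageRole

from private_gpt.components.llm.prompt_helper import get_prompt_style

# get_prompt_style("tag") now returns a TagPromptStyle() instance directly.
prompt_style = get_prompt_style("tag")
prompt = prompt_style.messages_to_prompt(
    [
        ChatMessage(content="You are a helpful assistant.", role=MessageRole.SYSTEM),
        ChatMessage(content="Hello!", role=MessageRole.USER),
    ]
)
# prompt is now:
# <|system|>: You are a helpful assistant.
# <|user|>: Hello!
# <|assistant|>: 
```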
diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py
index 125396c3e..5d6310341 100644
--- a/private_gpt/settings/settings.py
+++ b/private_gpt/settings/settings.py
@@ -108,15 +108,6 @@ class LocalSettings(BaseModel):
"`llama2` is the historic behaviour. `default` might work better with your custom models."
),
)
- default_system_prompt: str | None = Field(
- None,
- description=(
- "The default system prompt to use for the chat engine. "
- "If none is given - use the default system prompt (from the llama_index). "
- "Please note that the default prompt might not be the same for all prompt styles. "
- "Also note that this is only used if the first message is not a system message. "
- ),
- )
class EmbeddingSettings(BaseModel):
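(With default_system_prompt removed from LocalSettings, a custom system prompt is now supplied per session as the first chat message rather than via configuration. A small illustrative sketch, contents hypothetical:)

```python
from llama_index.llms import ChatMessage, MessageRole

# Each prompt style now falls back to its library defaults when no SYSTEM
# message is present; overriding happens in the message list itself.
messages = [
    ChatMessage(content="Answer concisely.", role=MessageRole.SYSTEM),
    ChatMessage(content="What does PrivateGPT do?", role=MessageRole.USER),
]
```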
diff --git a/tests/test_prompt_helper.py b/tests/test_prompt_helper.py
index 1f22a0692..48cac0ba5 100644
--- a/tests/test_prompt_helper.py
+++ b/tests/test_prompt_helper.py
@@ -18,7 +18,7 @@
],
)
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
- assert get_prompt_style(prompt_style) == expected_prompt_style
+ assert isinstance(get_prompt_style(prompt_style), expected_prompt_style)
def test_get_prompt_style_failure():
@@ -45,20 +45,7 @@ def test_tag_prompt_style_format():
def test_tag_prompt_style_format_with_system_prompt():
- system_prompt = "This is a system prompt from configuration."
- prompt_style = TagPromptStyle(default_system_prompt=system_prompt)
- messages = [
- ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
- ]
-
- expected_prompt = (
- f"<|system|>: {system_prompt}\n"
- "<|user|>: Hello, how are you doing?\n"
- "<|assistant|>: "
- )
-
- assert prompt_style.messages_to_prompt(messages) == expected_prompt
-
+ prompt_style = TagPromptStyle()
messages = [
ChatMessage(
content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM
@@ -94,22 +81,7 @@ def test_llama2_prompt_style_format():
def test_llama2_prompt_style_with_system_prompt():
- system_prompt = "This is a system prompt from configuration."
- prompt_style = Llama2PromptStyle(default_system_prompt=system_prompt)
- messages = [
- ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
- ]
-
- expected_prompt = (
- " [INST] <>\n"
- f" {system_prompt} \n"
- "<>\n"
- "\n"
- " Hello, how are you doing? [/INST]"
- )
-
- assert prompt_style.messages_to_prompt(messages) == expected_prompt
-
+ prompt_style = Llama2PromptStyle()
messages = [
ChatMessage(
content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM