diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index 69e25258f..34fd2fc2e 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -23,10 +23,7 @@ def __init__(self, settings: Settings) -> None:
             case "local":
                 from llama_index.llms import LlamaCPP
 
-                prompt_style_cls = get_prompt_style(settings.local.prompt_style)
-                prompt_style = prompt_style_cls(
-                    default_system_prompt=settings.local.default_system_prompt
-                )
+                prompt_style = get_prompt_style(settings.local.prompt_style)
 
                 self.llm = LlamaCPP(
                     model_path=str(models_path / settings.local.llm_hf_model_file),
diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py
index e47b3fb9c..a8ca60f27 100644
--- a/private_gpt/components/llm/prompt_helper.py
+++ b/private_gpt/components/llm/prompt_helper.py
@@ -5,7 +5,6 @@
 
 from llama_index.llms import ChatMessage, MessageRole
 from llama_index.llms.llama_utils import (
-    DEFAULT_SYSTEM_PROMPT,
     completion_to_prompt,
     messages_to_prompt,
 )
@@ -29,7 +28,6 @@ class AbstractPromptStyle(abc.ABC):
     series of messages into a prompt.
     """
 
-    @abc.abstractmethod
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         logger.debug("Initializing prompt_style=%s", self.__class__.__name__)
 
@@ -52,15 +50,6 @@ def completion_to_prompt(self, completion: str) -> str:
         return prompt
 
 
-class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
-    _DEFAULT_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT
-
-    def __init__(self, default_system_prompt: str | None) -> None:
-        super().__init__()
-        logger.debug("Got default_system_prompt='%s'", default_system_prompt)
-        self.default_system_prompt = default_system_prompt
-
-
 class DefaultPromptStyle(AbstractPromptStyle):
     """Default prompt style that uses the defaults from llama_utils.
 
@@ -83,7 +72,7 @@ def _completion_to_prompt(self, completion: str) -> str:
         return ""
 
 
-class Llama2PromptStyle(AbstractPromptStyleWithSystemPrompt):
+class Llama2PromptStyle(AbstractPromptStyle):
     """Simple prompt style that just uses the default llama_utils functions.
 
     It transforms the sequence of messages into a prompt that should look like:
@@ -94,18 +83,14 @@ class Llama2PromptStyle(AbstractPromptStyleWithSystemPrompt):
     ```
     """
 
-    def __init__(self, default_system_prompt: str | None = None) -> None:
-        # If no system prompt is given, the default one of the implementation is used.
-        super().__init__(default_system_prompt=default_system_prompt)
-
     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
-        return messages_to_prompt(messages, self.default_system_prompt)
+        return messages_to_prompt(messages)
 
     def _completion_to_prompt(self, completion: str) -> str:
-        return completion_to_prompt(completion, self.default_system_prompt)
+        return completion_to_prompt(completion)
 
 
-class TagPromptStyle(AbstractPromptStyleWithSystemPrompt):
+class TagPromptStyle(AbstractPromptStyle):
     """Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.
 
     It transforms the sequence of messages into a prompt that should look like:
@@ -119,37 +104,8 @@ class TagPromptStyle(AbstractPromptStyleWithSystemPrompt):
     FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
     """
 
-    def __init__(self, default_system_prompt: str | None = None) -> None:
-        # We have to define a default system prompt here as the LLM will not
-        # use the default llama_utils functions.
-        default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
-        super().__init__(default_system_prompt)
-        self.system_prompt: str = default_system_prompt
-
     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
-        messages = list(messages)
-        if messages[0].role != MessageRole.SYSTEM:
-            logger.info(
-                "Adding system_promt='%s' to the given messages as there are none given in the session",
-                self.system_prompt,
-            )
-            messages = [
-                ChatMessage(content=self.system_prompt, role=MessageRole.SYSTEM),
-                *messages,
-            ]
-        return self._format_messages_to_prompt(messages)
-
-    def _completion_to_prompt(self, completion: str) -> str:
-        return (
-            f"<|system|>: {self.system_prompt.strip()}\n"
-            f"<|user|>: {completion.strip()}\n"
-            "<|assistant|>: "
-        )
-
-    @staticmethod
-    def _format_messages_to_prompt(messages: list[ChatMessage]) -> str:
         """Format message to prompt with `<|ROLE|>: MSG` style."""
-        assert messages[0].role == MessageRole.SYSTEM
         prompt = ""
         for message in messages:
             role = message.role
@@ -161,19 +117,24 @@ def _format_messages_to_prompt(messages: list[ChatMessage]) -> str:
         prompt += "<|assistant|>: "
         return prompt
 
+    def _completion_to_prompt(self, completion: str) -> str:
+        return self._messages_to_prompt(
+            [ChatMessage(content=completion, role=MessageRole.USER)]
+        )
+
 
 def get_prompt_style(
     prompt_style: Literal["default", "llama2", "tag"] | None
-) -> type[AbstractPromptStyle]:
+) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.
 
     :param prompt_style: The prompt style to use.
     :return: The prompt style to use.
     """
     if prompt_style is None or prompt_style == "default":
-        return DefaultPromptStyle
+        return DefaultPromptStyle()
     elif prompt_style == "llama2":
-        return Llama2PromptStyle
+        return Llama2PromptStyle()
     elif prompt_style == "tag":
-        return TagPromptStyle
+        return TagPromptStyle()
     raise ValueError(f"Unknown prompt_style='{prompt_style}'")
diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py
index 125396c3e..5d6310341 100644
--- a/private_gpt/settings/settings.py
+++ b/private_gpt/settings/settings.py
@@ -108,15 +108,6 @@ class LocalSettings(BaseModel):
             "`llama2` is the historic behaviour. `default` might work better with your custom models."
         ),
     )
-    default_system_prompt: str | None = Field(
-        None,
-        description=(
-            "The default system prompt to use for the chat engine. "
-            "If none is given - use the default system prompt (from the llama_index). "
-            "Please note that the default prompt might not be the same for all prompt styles. "
-            "Also note that this is only used if the first message is not a system message. "
-        ),
-    )
 
 
 class EmbeddingSettings(BaseModel):
diff --git a/tests/test_prompt_helper.py b/tests/test_prompt_helper.py
index 1f22a0692..48cac0ba5 100644
--- a/tests/test_prompt_helper.py
+++ b/tests/test_prompt_helper.py
@@ -18,7 +18,7 @@
     ],
 )
 def test_get_prompt_style_success(prompt_style, expected_prompt_style):
-    assert get_prompt_style(prompt_style) == expected_prompt_style
+    assert isinstance(get_prompt_style(prompt_style), expected_prompt_style)
 
 
 def test_get_prompt_style_failure():
@@ -45,20 +45,7 @@ def test_tag_prompt_style_format():
 
 
 def test_tag_prompt_style_format_with_system_prompt():
-    system_prompt = "This is a system prompt from configuration."
-    prompt_style = TagPromptStyle(default_system_prompt=system_prompt)
-    messages = [
-        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
-    ]
-
-    expected_prompt = (
-        f"<|system|>: {system_prompt}\n"
-        "<|user|>: Hello, how are you doing?\n"
-        "<|assistant|>: "
-    )
-
-    assert prompt_style.messages_to_prompt(messages) == expected_prompt
-
+    prompt_style = TagPromptStyle()
     messages = [
         ChatMessage(
             content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM
@@ -94,22 +81,7 @@ def test_llama2_prompt_style_format():
 
 
 def test_llama2_prompt_style_with_system_prompt():
-    system_prompt = "This is a system prompt from configuration."
-    prompt_style = Llama2PromptStyle(default_system_prompt=system_prompt)
-    messages = [
-        ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
-    ]
-
-    expected_prompt = (
-        "<s> [INST] <<SYS>>\n"
-        f" {system_prompt} \n"
-        "<</SYS>>\n"
-        "\n"
-        " Hello, how are you doing? [/INST]"
-    )
-
-    assert prompt_style.messages_to_prompt(messages) == expected_prompt
-
+    prompt_style = Llama2PromptStyle()
     messages = [
         ChatMessage(
             content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM