diff --git a/private_gpt/components/llm/prompt_helper.py b/private_gpt/components/llm/prompt_helper.py
index b550020b9..0432e4964 100644
--- a/private_gpt/components/llm/prompt_helper.py
+++ b/private_gpt/components/llm/prompt_helper.py
@@ -169,7 +169,7 @@ class Llama3PromptStyle(AbstractPromptStyle):
     """
 
     def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
-        prompt = self.BOS
+        prompt = ""
         has_system_message = False
 
         for i, message in enumerate(messages):
@@ -189,8 +189,7 @@ def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
         # Add default system prompt if no system message was provided
         if not has_system_message:
             prompt = (
-                f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
-                + prompt[len(self.BOS) :]
+                f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}" + prompt
             )
 
         # TODO: Implement tool handling logic
@@ -199,7 +198,7 @@ def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
 
     def _completion_to_prompt(self, completion: str) -> str:
         return (
-            f"{self.BOS}{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
+            f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
             f"{self.B_INST}user{self.E_INST}\n\n{completion.strip()}{self.EOT}"
             f"{self.ASSISTANT_INST}\n\n"
         )
diff --git a/tests/test_prompt_helper.py b/tests/test_prompt_helper.py
index ad9349c8b..c0653e28f 100644
--- a/tests/test_prompt_helper.py
+++ b/tests/test_prompt_helper.py
@@ -150,7 +150,7 @@ def test_llama3_prompt_style_format():
     ]
 
     expected_prompt = (
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+        "<|start_header_id|>system<|end_header_id|>\n\n"
         "You are a helpful assistant<|eot_id|>"
         "<|start_header_id|>user<|end_header_id|>\n\n"
         "Hello, how are you doing?<|eot_id|>"
@@ -166,7 +166,7 @@ def test_llama3_prompt_style_with_default_system():
         ChatMessage(content="Hello!", role=MessageRole.USER),
     ]
     expected = (
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+        "<|start_header_id|>system<|end_header_id|>\n\n"
         f"{prompt_style.DEFAULT_SYSTEM_PROMPT}<|eot_id|>"
         "<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
@@ -185,7 +185,7 @@ def test_llama3_prompt_style_with_assistant_response():
     ]
 
     expected_prompt = (
-        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
+        "<|start_header_id|>system<|end_header_id|>\n\n"
         "You are a helpful assistant<|eot_id|>"
         "<|start_header_id|>user<|end_header_id|>\n\n"
         "What is the capital of France?<|eot_id|>"