From edb2ee4d2a533960c1d8a7debfd17e3c68e4a5e6 Mon Sep 17 00:00:00 2001
From: Chao Jiang
Date: Sat, 13 Apr 2024 11:38:29 +0800
Subject: [PATCH 1/4] Add chat template for command-r model series

---
 llama.cpp | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/llama.cpp b/llama.cpp
index 83dd55efa99e6..919ba77d68592 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -16476,6 +16476,21 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "### Response:\n";
         }
+    } else if (tmpl == "command-r" || (tmpl.find("<|START_OF_TURN_TOKEN|>") != std::string::npos && tmpl.find("<|USER_TOKEN|>") != std::string::npos)) {
+            // CohereForAI/c4ai-command-r-plus
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << message->content << "<|END_OF_TURN_TOKEN|>";
+            } else if (role == "user") {
+                ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << message->content << "<|END_OF_TURN_TOKEN|>";
+            } else if (role == "assistant") {
+                ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << message->content << "<|END_OF_TURN_TOKEN|>";
+            }
+        }
+        if (add_ass) {
+            ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
+        }
     } else {
         // template not supported
         return -1;

From 955efb618193e446425364000bf1242c44569dd4 Mon Sep 17 00:00:00 2001
From: Chao Jiang
Date: Sat, 13 Apr 2024 06:55:55 +0000
Subject: [PATCH 2/4] Fix indentation

---
 llama.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama.cpp b/llama.cpp
index 919ba77d68592..69a6339ee3fad 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -16477,7 +16477,7 @@ static int32_t llama_chat_apply_template_internal(
             ss << "### Response:\n";
         }
     } else if (tmpl == "command-r" || (tmpl.find("<|START_OF_TURN_TOKEN|>") != std::string::npos && tmpl.find("<|USER_TOKEN|>") != std::string::npos)) {
-            // CohereForAI/c4ai-command-r-plus
+        // CohereForAI/c4ai-command-r-plus
         for (auto message : chat) {
             std::string role(message->role);
             if (role == "system") {

From 3f1aa54a0b1f982613d47a57a4847c5808875c3c Mon Sep 17 00:00:00 2001
From: Chao Jiang
Date: Sun, 14 Apr 2024 00:42:54 +0800
Subject: [PATCH 3/4] Add chat template test for command-r models and update
 the implementation to trim whitespaces

---
 llama.cpp                    | 6 +++---
 tests/test-chat-template.cpp | 5 +++++
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 69a6339ee3fad..a3c70cb8c5891 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -16481,11 +16481,11 @@ static int32_t llama_chat_apply_template_internal(
         for (auto message : chat) {
             std::string role(message->role);
             if (role == "system") {
-                ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << message->content << "<|END_OF_TURN_TOKEN|>";
+                ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
             } else if (role == "user") {
-                ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << message->content << "<|END_OF_TURN_TOKEN|>";
+                ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
             } else if (role == "assistant") {
-                ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << message->content << "<|END_OF_TURN_TOKEN|>";
+                ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
             }
         }
         if (add_ass) {
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 73c3536fdb878..9bc1ab42218d8 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -45,6 +45,8 @@ int main(void) {
         // Orca-Vicuna
         // No template included in tokenizer_config.json, so this template likely needs to be manually set.
         "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
+        // CohereForAI/c4ai-command-r-plus
+        "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}"
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -69,6 +71,8 @@ int main(void) {
         "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:",
         // Orca-Vicuna
         "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:",
+        // CohereForAI/c4ai-command-r-plus
+        "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
@@ -93,6 +97,7 @@ int main(void) {
         formatted_chat.resize(res);
         std::string output(formatted_chat.data(), formatted_chat.size());
         std::cout << output << "\n-------------------------\n";
+        std::cout << expected << "\n-------------------------\n";
         assert(output == expected);
     }
     return 0;

From 38c9cedd81b2371203339bae5145ec59da7eadf2 Mon Sep 17 00:00:00 2001
From: Chao Jiang
Date: Sun, 14 Apr 2024 00:46:54 +0800
Subject: [PATCH 4/4] Remove debug print

---
 tests/test-chat-template.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 9bc1ab42218d8..522cc7d0d9e84 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -97,7 +97,6 @@ int main(void) {
         formatted_chat.resize(res);
         std::string output(formatted_chat.data(), formatted_chat.size());
         std::cout << output << "\n-------------------------\n";
-        std::cout << expected << "\n-------------------------\n";
         assert(output == expected);
     }
     return 0;
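
Note (not part of the patch series): below is a minimal sketch of how the branch added in PATCH 1/4 can be exercised through the public API, modeled on tests/test-chat-template.cpp above. It assumes the llama.h declarations of llama_chat_message and llama_chat_apply_template as used by that test. Passing a null model together with the curated name "command-r" skips the model's embedded template and matches the new branch directly; the raw Jinja template would also be detected via its <|START_OF_TURN_TOKEN|>/<|USER_TOKEN|> markers.

// Sketch only: format a short conversation with the new command-r template.
#include <cassert>
#include <cstdio>
#include <vector>

#include "llama.h"

int main(void) {
    // The three roles the new branch handles; contents get trimmed (PATCH 3/4).
    llama_chat_message chat[] = {
        {"system",    "You are a helpful assistant"},
        {"user",      "Hello"},
        {"assistant", "  Hi there  "},
    };
    const size_t n_msg = sizeof(chat) / sizeof(chat[0]);

    std::vector<char> buf(1024);
    // model == nullptr: no template is read from the model, so the curated
    // name "command-r" routes straight to the branch added in PATCH 1/4.
    const int32_t res = llama_chat_apply_template(
        nullptr, "command-r", chat, n_msg,
        /* add_ass = */ true, buf.data(), (int32_t) buf.size());
    assert(res >= 0 && res <= (int32_t) buf.size());

    // With add_ass == true the output ends with the assistant turn opener
    // <|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>, ready for generation.
    printf("%.*s\n", res, buf.data());
    return 0;
}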