Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion engine/extensions/remote-engine/remote_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
status["has_error"] = true;
status["is_stream"] = true;
status["status_code"] = k400BadRequest;
context->need_stop = false;
(*context->callback)(std::move(status), std::move(check_error));
return size * nmemb;
}
Expand All @@ -58,7 +59,8 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
status["is_done"] = true;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = 200;
status["status_code"] = k200OK;
context->need_stop = false;
(*context->callback)(std::move(status), Json::Value());
break;
}
Expand Down Expand Up @@ -169,6 +171,15 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest(

curl_slist_free_all(headers);
curl_easy_cleanup(curl);
if (context.need_stop) {
CTL_DBG("No stop message received, need to stop");
Json::Value status;
status["is_done"] = true;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = k200OK;
(*context.callback)(std::move(status), Json::Value());
}
return response;
}

Expand Down Expand Up @@ -626,6 +637,9 @@ void RemoteEngine::HandleChatCompletion(

try {
response_json["stream"] = false;
if (!response_json.isMember("model")) {
response_json["model"] = model;
}
response_str = renderer_.Render(template_str, response_json);
} catch (const std::exception& e) {
throw std::runtime_error("Template rendering error: " +
Expand All @@ -649,6 +663,7 @@ void RemoteEngine::HandleChatCompletion(
Json::Value error;
error["error"] = "Failed to parse response";
callback(std::move(status), std::move(error));
LOG_WARN << "Failed to parse response: " << response_str;
return;
}

Expand Down
1 change: 1 addition & 0 deletions engine/extensions/remote-engine/remote_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct StreamContext {
std::string model;
extensions::TemplateRenderer& renderer;
std::string stream_template;
bool need_stop = true;
};
struct CurlResponse {
std::string body;
Expand Down
5 changes: 1 addition & 4 deletions engine/extensions/template_renderer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ std::string TemplateRenderer::Render(const std::string& tmpl,

LOG_DEBUG << "Result: " << result;

// Validate JSON
auto parsed = nlohmann::json::parse(result);

return result;
} catch (const std::exception& e) {
LOG_ERROR << "Template rendering failed: " << e.what();
Expand Down Expand Up @@ -133,4 +130,4 @@ std::string TemplateRenderer::RenderFile(const std::string& template_path,
e.what());
}
}
} // namespace remote_engine
} // namespace extensions
233 changes: 233 additions & 0 deletions engine/test/components/test_remote_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,239 @@ TEST_F(RemoteEngineTest, AnthropicResponse) {
EXPECT_TRUE(res_json["choices"][0]["message"]["content"].isNull());
}

// Verifies that the Cohere request template maps an OpenAI-style chat request
// onto Cohere's API shape: a leading system message becomes "preamble",
// intermediate turns become "chatHistory" (USER/CHATBOT roles), and the last
// user message becomes "message". Covers both with- and without-system cases.
TEST_F(RemoteEngineTest, CohereRequest) {
  // Template under test — must stay byte-identical to the production template.
  std::string tpl =
      R"({
{% for key, value in input_request %}
{% if key == "messages" %}
{% if input_request.messages.0.role == "system" %}
"preamble": "{{ input_request.messages.0.content }}",
{% if length(input_request.messages) > 2 %}
"chatHistory": [
{% for message in input_request.messages %}
{% if not loop.is_first and not loop.is_last %}
{"role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": "{{ message.content }}" } {% if loop.index < length(input_request.messages) - 2 %},{% endif %}
{% endif %}
{% endfor %}
],
{% endif %}
"message": "{{ last(input_request.messages).content }}"
{% else %}
{% if length(input_request.messages) > 2 %}
"chatHistory": [
{% for message in input_request.messages %}
{% if not loop.is_last %}
{ "role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": "{{ message.content }}" } {% if loop.index < length(input_request.messages) - 2 %},{% endif %}
{% endif %}
{% endfor %}
],
{% endif %}
"message": "{{ last(input_request.messages).content }}"
{% endif %}
{% if not loop.is_last %},{% endif %}
{% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %}
"{{ key }}": {{ tojson(value) }}
{% if not loop.is_last %},{% endif %}
{% endif %}
{% endfor %} })";
  {
    // Case 1: system prompt present, 4 messages total. Expect the system
    // content to surface as "preamble" and the final user turn as "message".
    std::string message_with_system = R"({
"engine" : "cohere",
"max_tokens" : 1024,
"messages": [
{"role": "system", "content": "You are a seasoned data scientist at a Fortune 500 company."},
{"role": "user", "content": "Hello, world"},
{"role": "assistant", "content": "The man who is widely credited with discovering gravity is Sir Isaac Newton"},
{"role": "user", "content": "How are you today?"}
],
"model": "command-r-plus-08-2024",
"stream" : true
})";

    auto data = json_helper::ParseJsonString(message_with_system);

    extensions::TemplateRenderer rdr;
    auto res = rdr.Render(tpl, data);

    auto res_json = json_helper::ParseJsonString(res);
    // Pass-through scalar fields survive the rendering unchanged.
    EXPECT_EQ(data["model"].asString(), res_json["model"].asString());
    EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt());
    // The system message content must become Cohere's "preamble".
    for (auto const& msg : data["messages"]) {
      if (msg["role"].asString() == "system") {
        EXPECT_EQ(msg["content"].asString(), res_json["preamble"].asString());
      }
    }
    // The last user turn becomes the standalone "message" field.
    EXPECT_EQ(res_json["message"].asString(), "How are you today?");
  }

  {
    // Case 2: no system prompt, single user message. The sole message maps
    // directly onto "message" with no "preamble"/"chatHistory".
    // NOTE: the trailing comma after "max_tokens" was removed — it made the
    // fixture invalid JSON and strict parsers reject it.
    std::string message_without_system = R"({
"messages": [
{"role": "user", "content": "Hello, world"}
],
"model": "command-r-plus-08-2024",
"max_tokens": 1024
})";

    auto data = json_helper::ParseJsonString(message_without_system);

    extensions::TemplateRenderer rdr;
    auto res = rdr.Render(tpl, data);

    auto res_json = json_helper::ParseJsonString(res);
    EXPECT_EQ(data["model"].asString(), res_json["model"].asString());
    EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt());
    EXPECT_EQ(data["messages"][0]["content"].asString(),
              res_json["message"].asString());
  }
}

// Verifies that the Cohere response template converts Cohere API payloads into
// OpenAI-compatible chat-completion JSON for three cases: a streaming
// "text-generation" event, a streaming "stream-end" event, and a full
// non-streaming response (including usage/token fields).
TEST_F(RemoteEngineTest, CohereResponse) {
// Template under test: emits a "chat.completion.chunk" object when
// input_request.stream is true, otherwise a full "chat.completion" object
// with usage computed from input_request.meta.tokens.
std::string tpl = R"(
{% if input_request.stream %}
{"object": "chat.completion.chunk",
"model": "{{ input_request.model }}",
"choices": [{"index": 0, "delta": { {% if input_request.event_type == "text-generation" %} "role": "assistant", "content": "{{ input_request.text }}" {% else %} "role": "assistant", "content": null {% endif %} },
{% if input_request.event_type == "stream-end" %} "finish_reason": "{{ input_request.finish_reason }}" {% else %} "finish_reason": null {% endif %} }]
}
{% else %}
{"id": "{{ input_request.generation_id }}",
"created": null,
"object": "chat.completion",
"model": "{{ input_request.model }}",
"choices": [{ "index": 0, "message": { "role": "assistant", "content": {% if not input_request.text %} null {% else %} "{{input_request.text}}" {% endif %}, "refusal": null }, "logprobs": null, "finish_reason": "{{ input_request.finish_reason }}" } ], "usage": { "prompt_tokens": {{ input_request.meta.tokens.input_tokens }}, "completion_tokens": {{ input_request.meta.tokens.output_tokens }}, "total_tokens": {{ input_request.meta.tokens.input_tokens + input_request.meta.tokens.output_tokens }}, "prompt_tokens_details": { "cached_tokens": 0 }, "completion_tokens_details": { "reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0 } }, "system_fingerprint": "fp_6b68a8204b"} {% endif %})";
// Case 1: streaming "text-generation" event — the token text must land in
// choices[0].delta.content.
std::string message = R"({
"event_type": "text-generation",
"text": " help"
})";
auto data = json_helper::ParseJsonString(message);
// "stream" and "model" are injected by the engine before rendering, so the
// test sets them explicitly on the parsed payload.
data["stream"] = true;
data["model"] = "cohere";
extensions::TemplateRenderer rdr;
auto res = rdr.Render(tpl, data);
auto res_json = json_helper::ParseJsonString(res);
EXPECT_EQ(res_json["choices"][0]["delta"]["content"].asString(), " help");

// Case 2: streaming "stream-end" event — no token text, so delta.content
// must render as JSON null.
message = R"(
{
"event_type": "stream-end",
"response": {
"text": "Hello! How can I help you today?",
"generation_id": "29f14a5a-11de-4cae-9800-25e4747408ea",
"chat_history": [
{
"role": "USER",
"message": "hello world!"
},
{
"role": "CHATBOT",
"message": "Hello! How can I help you today?"
}
],
"finish_reason": "COMPLETE",
"meta": {
"api_version": {
"version": "1"
},
"billed_units": {
"input_tokens": 3,
"output_tokens": 9
},
"tokens": {
"input_tokens": 69,
"output_tokens": 9
}
}
},
"finish_reason": "COMPLETE"
})";
data = json_helper::ParseJsonString(message);
data["stream"] = true;
data["model"] = "cohere";
res = rdr.Render(tpl, data);
res_json = json_helper::ParseJsonString(res);
EXPECT_TRUE(res_json["choices"][0]["delta"]["content"].isNull());

// non-stream
// Case 3: full non-streaming Cohere response (with citations / search
// results / chat history) — the top-level "text" must become
// choices[0].message.content, with escape sequences preserved.
message = R"(
{
"text": "Isaac Newton was born on 25 December 1642 (Old Style) \n\nor 4 January 1643 (New Style).",
"generation_id": "0385c7cf-4247-43a3-a450-b25b547a31e1",
"citations": [
{
"start": 25,
"end": 41,
"text": "25 December 1642",
"document_ids": [
"web-search_0"
]
}
],
"search_queries": [
{
"text": "Isaac Newton birth year",
"generation_id": "9a497980-c3e2-4460-b81c-ef44d293f95d"
}
],
"search_results": [
{
"connector": {
"id": "web-search"
},
"document_ids": [
"web-search_0"
],
"search_query": {
"text": "Isaac Newton birth year",
"generation_id": "9a497980-c3e2-4460-b81c-ef44d293f95d"
}
}
],
"finish_reason": "COMPLETE",
"chat_history": [
{
"role": "USER",
"message": "Who discovered gravity?"
},
{
"role": "CHATBOT",
"message": "The man who is widely credited with discovering gravity is Sir Isaac Newton"
},
{
"role": "USER",
"message": "What year was he born?"
},
{
"role": "CHATBOT",
"message": "Isaac Newton was born on 25 December 1642 (Old Style) or 4 January 1643 (New Style)."
}
],
"meta": {
"api_version": {
"version": "1"
},
"billed_units": {
"input_tokens": 31738,
"output_tokens": 35
},
"tokens": {
"input_tokens": 32465,
"output_tokens": 205
}
}
}
)";

data = json_helper::ParseJsonString(message);
data["stream"] = false;
data["model"] = "cohere";
res = rdr.Render(tpl, data);
res_json = json_helper::ParseJsonString(res);
EXPECT_EQ(res_json["choices"][0]["message"]["content"].asString(),
"Isaac Newton was born on 25 December 1642 (Old Style) \n\nor 4 "
"January 1643 (New Style).");
}

TEST_F(RemoteEngineTest, HeaderTemplate) {
{
std::string header_template =
Expand Down
Loading