This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit a51c2c5

Feat/structured output (#308)

* Add openai compatible embedding
* Feat: add structured output in chat completion request
* chore: remove unnecessary cout
1 parent 2368b3f commit a51c2c5

File tree

src/chat_completion_request.h
src/llama_engine.cc

2 files changed: +16 −1 lines changed


src/chat_completion_request.h

Lines changed: 2 additions & 0 deletions
@@ -88,6 +88,7 @@ struct ChatCompletionRequest {
   bool include_usage = false;
   std::string grammar;
   Json::Value logit_bias = Json::Value(Json::arrayValue);
+  Json::Value json_schema;
 
   static Json::Value ConvertLogitBiasToArray(const Json::Value& input) {
     Json::Value result(Json::arrayValue);
@@ -155,6 +156,7 @@ inline ChatCompletionRequest fromJson(std::shared_ptr<Json::Value> jsonBody) {
   completion.min_keep = (*jsonBody).get("min_keep", 0).asInt();
   completion.n = (*jsonBody).get("n", 1).asInt();
   completion.grammar = (*jsonBody).get("grammar", "").asString();
+  completion.json_schema = (*jsonBody).get("response_format", Json::Value::null);
   const Json::Value& input_logit_bias = (*jsonBody)["logit_bias"];
   if (!input_logit_bias.isNull()) {
     completion.logit_bias =
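
For context (this illustration is not part of the commit), the sketch below shows what the new field holds, using JsonCpp as the header already does: a hypothetical request body carrying an OpenAI-style response_format object, read with the same get("response_format", Json::Value::null) lookup that fromJson now performs. The nested field names (type, json_schema, schema, properties) follow the OpenAI convention the diff assumes; the rest of ChatCompletionRequest is omitted.

// Sketch only: demonstrates the new json_schema field with JsonCpp.
#include <json/json.h>
#include <iostream>
#include <memory>

int main() {
  // Hypothetical request body; only the fields relevant to structured output
  // are shown.
  auto body = std::make_shared<Json::Value>();
  (*body)["grammar"] = "";
  (*body)["response_format"]["type"] = "json_schema";
  (*body)["response_format"]["json_schema"]["schema"]["type"] = "object";
  (*body)["response_format"]["json_schema"]["schema"]["properties"]["answer"]
      ["type"] = "string";

  // Mirrors the line added to fromJson(): a missing "response_format" leaves
  // json_schema null, so requests without it behave exactly as before.
  Json::Value json_schema = (*body).get("response_format", Json::Value::null);
  std::cout << (json_schema.isNull() ? "no structured output requested\n"
                                     : json_schema.toStyledString());
  return 0;
}

Because json_schema defaults to a null Json::Value, clients that never send response_format are unaffected by this change.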

src/llama_engine.cc

Lines changed: 14 additions & 1 deletion
@@ -3,10 +3,12 @@
 #include <cmath>
 #include <limits>
 #include <optional>
+#include "json-schema-to-grammar.h"
 #include "json/writer.h"
 #include "llama_utils.h"
 #include "trantor/utils/Logger.h"
 
+
 #if defined(_WIN32)
 #include <windows.h>
 #include <codecvt>
@@ -56,6 +58,7 @@ bool AreAllElementsInt32(const Json::Value& arr) {
   }
   // Check if value is within int32_t range
   auto value = element.asInt();
+
   if (value < std::numeric_limits<int32_t>::min() ||
       value > std::numeric_limits<int32_t>::max()) {
     return false;
@@ -748,6 +751,15 @@ void LlamaEngine::HandleInferenceImpl(
   data["n_probs"] = completion.n_probs;
   data["min_keep"] = completion.min_keep;
   data["grammar"] = completion.grammar;
+  if (!completion.json_schema.isNull() &&
+      (completion.json_schema.isMember("type") &&
+       (completion.json_schema["type"] == "json_object" ||
+        completion.json_schema["type"] == "json_schema"))) {
+
+    data["grammar"] =
+        json_schema_to_grammar(llama::inferences::ConvertJsonCppToNlohmann(
+            completion.json_schema["json_schema"]["schema"]));
+  }
   data["n"] = completion.n;  // number of choices to return
   json arr = json::array();
   for (const auto& elem : completion.logit_bias) {
@@ -1039,7 +1051,6 @@ void LlamaEngine::HandleInferenceImpl(
   status["is_stream"] = false;
   status["status_code"] = k200OK;
   cb(std::move(status), std::move(respData));
-
   LOG_INFO << "Request " << request_id << ": " << "Inference completed";
 }
 });
@@ -1091,6 +1102,7 @@ void LlamaEngine::HandleEmbeddingImpl(
   prompt_tokens +=
       static_cast<int>(result.result_json["tokens_evaluated"]);
   std::vector<float> embedding_result = result.result_json["embedding"];
+
   responseData.append(
       CreateEmbeddingPayload(embedding_result, 0, is_base64));
 } else {
@@ -1128,6 +1140,7 @@ void LlamaEngine::HandleEmbeddingImpl(
   prompt_tokens += cur_pt;
   std::vector<float> embedding_result =
       result.result_json["embedding"];
+
   responseData.append(
       CreateEmbeddingPayload(embedding_result, i, is_base64));
 }
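
To make the new branch in HandleInferenceImpl easier to follow, here is a minimal, self-contained sketch of what it does: when the parsed response_format has type json_object or json_schema, the schema under json_schema.schema is turned into a GBNF grammar that overrides any grammar string the client supplied. The sketch assumes llama.cpp's json_schema_to_grammar(const nlohmann::ordered_json&) and substitutes a simple round-trip helper (ToNlohmann) for llama::inferences::ConvertJsonCppToNlohmann, so treat it as an illustration of the mechanism rather than the exact code in this revision.

// Sketch only: the structured-output branch, reduced to a standalone program.
// json_schema_to_grammar() is taken from llama.cpp's common/json-schema-to-grammar.h
// (signature assumed); ToNlohmann() is a stand-in converter.
#include <iostream>
#include <string>
#include <json/json.h>
#include <nlohmann/json.hpp>
#include "json-schema-to-grammar.h"

// Stand-in converter: round-trips a JsonCpp value through its text form.
static nlohmann::ordered_json ToNlohmann(const Json::Value& v) {
  Json::StreamWriterBuilder builder;
  return nlohmann::ordered_json::parse(Json::writeString(builder, v));
}

int main() {
  // Hypothetical response_format, as parsed into completion.json_schema.
  Json::Value json_schema;
  json_schema["type"] = "json_schema";
  json_schema["json_schema"]["schema"]["type"] = "object";
  json_schema["json_schema"]["schema"]["properties"]["answer"]["type"] = "string";

  std::string grammar = "";  // whatever the request put in "grammar"

  // Same condition as the diff: accept both "json_object" and "json_schema".
  if (!json_schema.isNull() && json_schema.isMember("type") &&
      (json_schema["type"] == "json_object" ||
       json_schema["type"] == "json_schema")) {
    // The schema-derived GBNF grammar replaces any user-supplied grammar.
    grammar =
        json_schema_to_grammar(ToNlohmann(json_schema["json_schema"]["schema"]));
  }
  std::cout << grammar << std::endl;
  return 0;
}

Note the precedence the diff establishes: data["grammar"] is first assigned from completion.grammar and then overwritten inside this branch, so a schema supplied via response_format takes priority over an explicit grammar string in the same request.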
