@@ -3,10 +3,12 @@
 #include <cmath>
 #include <limits>
 #include <optional>
+#include "json-schema-to-grammar.h"
 #include "json/writer.h"
 #include "llama_utils.h"
 #include "trantor/utils/Logger.h"
 
+
 #if defined(_WIN32)
 #include <windows.h>
 #include <codecvt>
@@ -56,6 +58,7 @@ bool AreAllElementsInt32(const Json::Value& arr) {
     }
     // Check if value is within int32_t range
    auto value = element.asInt();
+
    if (value < std::numeric_limits<int32_t>::min() ||
        value > std::numeric_limits<int32_t>::max()) {
      return false;
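A side note on the range check in this hunk: JsonCpp's `asInt()` typically returns a 32-bit `Json::Int` and itself rejects values that don't fit, so reading the value through `asInt64()` first is what keeps the explicit `int32_t` comparison meaningful. A minimal standalone sketch of that variant, assuming JsonCpp only (the helper name is hypothetical):

```cpp
#include <cstdint>
#include <limits>
#include <json/json.h>  // JsonCpp

// Hypothetical standalone version of the range check above.
// Reading through asInt64() avoids narrowing the value to 32 bits
// before it is compared against the int32_t limits.
bool FitsInInt32(const Json::Value& element) {
  if (!element.isInt64()) return false;
  const int64_t value = element.asInt64();
  return value >= std::numeric_limits<int32_t>::min() &&
         value <= std::numeric_limits<int32_t>::max();
}
```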
@@ -748,6 +751,15 @@ void LlamaEngine::HandleInferenceImpl(
   data["n_probs"] = completion.n_probs;
   data["min_keep"] = completion.min_keep;
   data["grammar"] = completion.grammar;
+  if (!completion.json_schema.isNull() &&
+      (completion.json_schema.isMember("type") &&
+       (completion.json_schema["type"] == "json_object" ||
+        completion.json_schema["type"] == "json_schema"))) {
+
+    data["grammar"] =
+        json_schema_to_grammar(llama::inferences::ConvertJsonCppToNlohmann(
+            completion.json_schema["json_schema"]["schema"]));
+  }
   data["n"] = completion.n; // number of choices to return
   json arr = json::array();
   for (const auto& elem : completion.logit_bias) {
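The new branch turns an OpenAI-style `response_format` schema into a llama.cpp GBNF grammar via `json_schema_to_grammar()` (from the header added at the top of this diff) and a `ConvertJsonCppToNlohmann()` helper that is referenced but not shown here. A minimal sketch of what such a converter could look like, assuming a round-trip through a compact string is acceptable:

```cpp
#include <json/json.h>        // JsonCpp (Json::Value)
#include <nlohmann/json.hpp>  // nlohmann::json, consumed by json_schema_to_grammar

// Hypothetical sketch: convert a JsonCpp tree to nlohmann::json by
// serializing it to a compact string and re-parsing. Slower than a
// recursive tree walk, but short and lossless for standard JSON.
inline nlohmann::json ConvertJsonCppToNlohmann(const Json::Value& value) {
  Json::StreamWriterBuilder builder;
  builder["indentation"] = "";  // compact, single-line output
  return nlohmann::json::parse(Json::writeString(builder, value));
}
```

With a helper like that in place, a request whose `json_schema` field carries `{"type": "json_schema", "json_schema": {"schema": {...}}}` ends up overriding `data["grammar"]` with the generated GBNF string, matching the lookup path in the branch above.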
@@ -1039,7 +1051,6 @@ void LlamaEngine::HandleInferenceImpl(
       status["is_stream"] = false;
       status["status_code"] = k200OK;
       cb(std::move(status), std::move(respData));
-
       LOG_INFO << "Request " << request_id << ": " << "Inference completed";
     }
   });
@@ -1091,6 +1102,7 @@ void LlamaEngine::HandleEmbeddingImpl(
         prompt_tokens +=
             static_cast<int>(result.result_json["tokens_evaluated"]);
         std::vector<float> embedding_result = result.result_json["embedding"];
+
         responseData.append(
             CreateEmbeddingPayload(embedding_result, 0, is_base64));
       } else {
@@ -1128,6 +1140,7 @@ void LlamaEngine::HandleEmbeddingImpl(
           prompt_tokens += cur_pt;
           std::vector<float> embedding_result =
               result.result_json["embedding"];
+
           responseData.append(
               CreateEmbeddingPayload(embedding_result, i, is_base64));
         }
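Both embedding paths funnel into `CreateEmbeddingPayload(embedding, index, is_base64)`, which lies outside this diff. A minimal sketch of the float-array case of such a payload builder, assuming an OpenAI-style response entry (the base64 branch is omitted and the helper name is hypothetical):

```cpp
#include <vector>
#include <json/json.h>  // JsonCpp

// Hypothetical sketch of an OpenAI-compatible embedding entry.
// The real CreateEmbeddingPayload also supports base64 output when
// is_base64 is set; this sketch covers only the float-array case.
Json::Value MakeEmbeddingEntry(const std::vector<float>& embedding,
                               int index) {
  Json::Value entry;
  entry["object"] = "embedding";
  entry["index"] = index;
  Json::Value values(Json::arrayValue);
  for (float v : embedding) {
    values.append(v);  // each float becomes a JSON number
  }
  entry["embedding"] = values;
  return entry;
}
```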