diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index a15f809d..341091ff 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -4,10 +4,17 @@ on:
- pull_request
- workflow_dispatch
env:
- MODEL_URL: https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf
- MODEL_NAME: codellama-7b.Q2_K.gguf
+
+ REASONING_MODEL_URL: https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories260K.gguf
+ REASONING_MODEL_NAME: stories260K.gguf
+ INFILL_MODEL_URL: https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories260K-infill.gguf
+ INFILL_MODEL_NAME: stories260K-infill.gguf
+ MOE_MODEL_URL: https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/stories15M_MOE-F16.gguf
+ MOE_MODEL_NAME: stories15M_MOE-F16.gguf
RERANKING_MODEL_URL: https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf
RERANKING_MODEL_NAME: jina-reranker-v1-tiny-en-Q4_0.gguf
+ EMBEDDING_MODEL_URL: https://huggingface.co/ggml-org/models/resolve/main/bert-bge-small/ggml-model-f16.gguf
+ EMBEDDING_MODEL_NAME: ggml-model-f16.gguf
jobs:
build-and-test-linux:
@@ -23,10 +30,21 @@ jobs:
run: |
mvn compile
.github/build.sh -DLLAMA_VERBOSE=ON
- - name: Download text generation model
- run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
- name: Download reranking model
run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
+
+      - name: Download reasoning model
+        run: curl -L ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME}
+
+      - name: Download infill model
+        run: curl -L ${INFILL_MODEL_URL} --create-dirs -o models/${INFILL_MODEL_NAME}
+
+      - name: Download MoE model
+        run: curl -L ${MOE_MODEL_URL} --create-dirs -o models/${MOE_MODEL_NAME}
+
+      - name: Download embedding model
+        run: curl -L ${EMBEDDING_MODEL_URL} --create-dirs -o models/${EMBEDDING_MODEL_NAME}
+
- name: List files in models directory
run: ls -l models/
- name: Run tests
@@ -59,10 +77,22 @@ jobs:
run: |
mvn compile
.github/build.sh ${{ matrix.target.cmake }}
- - name: Download text generaton model model
- run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
+
- name: Download reranking model
run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
+
+      - name: Download reasoning model
+        run: curl -L ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME}
+
+      - name: Download infill model
+        run: curl -L ${INFILL_MODEL_URL} --create-dirs -o models/${INFILL_MODEL_NAME}
+
+      - name: Download MoE model
+        run: curl -L ${MOE_MODEL_URL} --create-dirs -o models/${MOE_MODEL_NAME}
+
+      - name: Download embedding model
+        run: curl -L ${EMBEDDING_MODEL_URL} --create-dirs -o models/${EMBEDDING_MODEL_NAME}
+
- name: List files in models directory
run: ls -l models/
- name: Run tests
@@ -87,10 +117,22 @@ jobs:
run: |
mvn compile
.github\build.bat -DLLAMA_VERBOSE=ON
- - name: Download model
- run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME
+
- name: Download reranking model
run: curl -L $env:RERANKING_MODEL_URL --create-dirs -o models/$env:RERANKING_MODEL_NAME
+
+      - name: Download reasoning model
+        run: curl -L $env:REASONING_MODEL_URL --create-dirs -o models/$env:REASONING_MODEL_NAME
+
+      - name: Download infill model
+        run: curl -L $env:INFILL_MODEL_URL --create-dirs -o models/$env:INFILL_MODEL_NAME
+
+      - name: Download MoE model
+        run: curl -L $env:MOE_MODEL_URL --create-dirs -o models/$env:MOE_MODEL_NAME
+
+      - name: Download embedding model
+        run: curl -L $env:EMBEDDING_MODEL_URL --create-dirs -o models/$env:EMBEDDING_MODEL_NAME
+
- name: List files in models directory
run: ls -l models/
- name: Run tests
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index 64032028..8718221e 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -9,10 +9,16 @@ on:
release:
types: [ created ]
env:
- MODEL_URL: "https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf"
- MODEL_NAME: "codellama-7b.Q2_K.gguf"
+ REASONING_MODEL_URL: "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories260K.gguf"
+ REASONING_MODEL_NAME: "stories260K.gguf"
+ INFILL_MODEL_URL: "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories260K-infill.gguf"
+ INFILL_MODEL_NAME: "stories260K-infill.gguf"
+ MOE_MODEL_URL: "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/stories15M_MOE-F16.gguf"
+ MOE_MODEL_NAME: "stories15M_MOE-F16.gguf"
RERANKING_MODEL_URL: "https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf"
RERANKING_MODEL_NAME: "jina-reranker-v1-tiny-en-Q4_0.gguf"
+ EMBEDDING_MODEL_URL: "https://huggingface.co/ggml-org/models/resolve/main/bert-bge-small/ggml-model-f16.gguf"
+ EMBEDDING_MODEL_NAME: "ggml-model-f16.gguf"
jobs:
# todo: doesn't work with the newest llama.cpp version
@@ -146,10 +152,21 @@ jobs:
with:
name: Linux-x86_64-libraries
path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/
- - name: Download text generation model
- run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
- - name: Download reranking model
+
+ - name: Download reranking model
run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
+
+      - name: Download reasoning model
+        run: curl -L ${REASONING_MODEL_URL} --create-dirs -o models/${REASONING_MODEL_NAME}
+
+      - name: Download infill model
+        run: curl -L ${INFILL_MODEL_URL} --create-dirs -o models/${INFILL_MODEL_NAME}
+
+      - name: Download MoE model
+        run: curl -L ${MOE_MODEL_URL} --create-dirs -o models/${MOE_MODEL_NAME}
+
+      - name: Download embedding model
+        run: curl -L ${EMBEDDING_MODEL_URL} --create-dirs -o models/${EMBEDDING_MODEL_NAME}
- uses: actions/setup-java@v4
with:
distribution: 'zulu'
diff --git a/.gitignore b/.gitignore
index 274f8687..0f023ba0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,4 +42,6 @@ src/test/resources/**/*.gbnf
**/*.etag
**/*.lastModified
-src/main/cpp/llama.cpp/
\ No newline at end of file
+src/main/cpp/llama.cpp/
+/.classpath
+/.project
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8f402fa2..45f44c25 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,7 +25,7 @@ set(LLAMA_BUILD_COMMON ON)
FetchContent_Declare(
llama.cpp
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
- GIT_TAG b4916
+ GIT_TAG b4940
)
FetchContent_MakeAvailable(llama.cpp)
diff --git a/pom.xml b/pom.xml
index 3916a9e7..eab32e55 100644
--- a/pom.xml
+++ b/pom.xml
@@ -5,7 +5,7 @@
     <groupId>de.kherud</groupId>
     <artifactId>llama</artifactId>
-    <version>4.1.0</version>
+    <version>4.1.1</version>
     <packaging>jar</packaging>
     <name>${project.groupId}:${project.artifactId}</name>
@@ -65,6 +65,16 @@
             <version>24.1.0</version>
             <scope>compile</scope>
+
+
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-databind</artifactId>
+            <version>2.18.3</version>
+        </dependency>
+
+
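For context on the jackson-databind dependency added above: the new JNI handlers in jllama.cpp (below) return whole responses as JSON strings, which the Java side can parse with Jackson. The following is a minimal, illustrative sketch, not part of this patch; it assumes handleCompletions(String, boolean) is the Java declaration backing the native symbol Java_de_kherud_llama_LlamaModel_handleCompletions, and the class and helper names are made up for the example.

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

import de.kherud.llama.LlamaModel;

// Hypothetical helper (not part of this patch): consumes a non-streaming completion response.
public class CompletionsJsonExample {

    static String firstCompletion(LlamaModel model) throws Exception {
        ObjectMapper mapper = new ObjectMapper();

        // Request follows the llama.cpp server conventions; "prompt" is required by the handler.
        String request = mapper.createObjectNode()
                .put("prompt", "Once upon a time")
                .put("n_predict", 16)
                .toString();

        // Non-streaming call: the handler returns a JSON document with "type", "streaming",
        // "completion_id" and a "result" object (see handleCompletions in jllama.cpp below).
        String response = model.handleCompletions(request, false);

        JsonNode root = mapper.readTree(response);
        return root.path("result").path("content").asText();
    }
}

Returning a single JSON string keeps the JNI surface small and lets the Java layer evolve its object mapping independently of the native code.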
diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp
index ac056b94..70db5d10 100644
--- a/src/main/cpp/jllama.cpp
+++ b/src/main/cpp/jllama.cpp
@@ -1,12 +1,10 @@
#include "jllama.h"
-
#include "arg.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "log.h"
#include "nlohmann/json.hpp"
#include "server.hpp"
-
#include
#include
#include
@@ -16,162 +14,200 @@
// The references remain valid throughout the whole life of the shared library, on `JNI_OnUnload` they are released.
namespace {
-JavaVM *g_vm = nullptr;
-
-// classes
-jclass c_llama_model = nullptr;
-jclass c_llama_iterator = nullptr;
-jclass c_standard_charsets = nullptr;
-jclass c_output = nullptr;
-jclass c_string = nullptr;
-jclass c_hash_map = nullptr;
-jclass c_map = nullptr;
-jclass c_set = nullptr;
-jclass c_entry = nullptr;
-jclass c_iterator = nullptr;
-jclass c_integer = nullptr;
-jclass c_float = nullptr;
-jclass c_biconsumer = nullptr;
-jclass c_llama_error = nullptr;
-jclass c_log_level = nullptr;
-jclass c_log_format = nullptr;
-jclass c_error_oom = nullptr;
-
-// constructors
-jmethodID cc_output = nullptr;
-jmethodID cc_hash_map = nullptr;
-jmethodID cc_integer = nullptr;
-jmethodID cc_float = nullptr;
-
-// methods
-jmethodID m_get_bytes = nullptr;
-jmethodID m_entry_set = nullptr;
-jmethodID m_set_iterator = nullptr;
-jmethodID m_iterator_has_next = nullptr;
-jmethodID m_iterator_next = nullptr;
-jmethodID m_entry_key = nullptr;
-jmethodID m_entry_value = nullptr;
-jmethodID m_map_put = nullptr;
-jmethodID m_int_value = nullptr;
-jmethodID m_float_value = nullptr;
-jmethodID m_biconsumer_accept = nullptr;
-
-// fields
-jfieldID f_model_pointer = nullptr;
-jfieldID f_task_id = nullptr;
-jfieldID f_utf_8 = nullptr;
-jfieldID f_iter_has_next = nullptr;
-jfieldID f_log_level_debug = nullptr;
-jfieldID f_log_level_info = nullptr;
-jfieldID f_log_level_warn = nullptr;
-jfieldID f_log_level_error = nullptr;
-jfieldID f_log_format_json = nullptr;
-jfieldID f_log_format_text = nullptr;
-
-// objects
-jobject o_utf_8 = nullptr;
-jobject o_log_level_debug = nullptr;
-jobject o_log_level_info = nullptr;
-jobject o_log_level_warn = nullptr;
-jobject o_log_level_error = nullptr;
-jobject o_log_format_json = nullptr;
-jobject o_log_format_text = nullptr;
-jobject o_log_callback = nullptr;
-
-/**
- * Convert a Java string to a std::string
- */
-std::string parse_jstring(JNIEnv *env, jstring java_string) {
- auto *const string_bytes = (jbyteArray)env->CallObjectMethod(java_string, m_get_bytes, o_utf_8);
-
- auto length = (size_t)env->GetArrayLength(string_bytes);
- jbyte *byte_elements = env->GetByteArrayElements(string_bytes, nullptr);
-
- std::string string = std::string((char *)byte_elements, length);
-
- env->ReleaseByteArrayElements(string_bytes, byte_elements, JNI_ABORT);
- env->DeleteLocalRef(string_bytes);
-
- return string;
-}
-
-char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const jsize length) {
- auto *const result = static_cast(malloc(length * sizeof(char *)));
+ JavaVM * g_vm = nullptr;
+
+ // classes
+ jclass c_llama_model = nullptr;
+ jclass c_standard_charsets = nullptr;
+ jclass c_string = nullptr;
+ jclass c_hash_map = nullptr;
+ jclass c_map = nullptr;
+ jclass c_set = nullptr;
+ jclass c_entry = nullptr;
+ jclass c_iterator = nullptr;
+ jclass c_integer = nullptr;
+ jclass c_float = nullptr;
+ jclass c_biconsumer = nullptr;
+ jclass c_llama_error = nullptr;
+ jclass c_log_level = nullptr;
+ jclass c_log_format = nullptr;
+ jclass c_error_oom = nullptr;
+ jclass c_charset_class = nullptr;
+
+
+ // constructors
+ jmethodID cc_hash_map = nullptr;
+ jmethodID cc_integer = nullptr;
+ jmethodID cc_float = nullptr;
+
+ // methods
+ jmethodID m_get_bytes = nullptr;
+ jmethodID m_entry_set = nullptr;
+ jmethodID m_set_iterator = nullptr;
+ jmethodID m_iterator_has_next = nullptr;
+ jmethodID m_iterator_next = nullptr;
+ jmethodID m_entry_key = nullptr;
+ jmethodID m_entry_value = nullptr;
+ jmethodID m_map_put = nullptr;
+ jmethodID m_int_value = nullptr;
+ jmethodID m_float_value = nullptr;
+ jmethodID m_biconsumer_accept = nullptr;
+ jmethodID m_forname = nullptr;
+
+
+ // fields
+ jfieldID f_model_pointer = nullptr;
+ jfieldID f_task_id = nullptr;
+ jfieldID f_utf_8 = nullptr;
+ jfieldID f_iter_has_next = nullptr;
+ jfieldID f_log_level_debug = nullptr;
+ jfieldID f_log_level_info = nullptr;
+ jfieldID f_log_level_warn = nullptr;
+ jfieldID f_log_level_error = nullptr;
+ jfieldID f_log_format_json = nullptr;
+ jfieldID f_log_format_text = nullptr;
+
+ // objects
+ jobject o_utf_8 = nullptr;
+ jobject o_log_level_debug = nullptr;
+ jobject o_log_level_info = nullptr;
+ jobject o_log_level_warn = nullptr;
+ jobject o_log_level_error = nullptr;
+ jobject o_log_format_json = nullptr;
+ jobject o_log_format_text = nullptr;
+ jobject o_log_callback = nullptr;
+
+ /**
+ * Convert a Java string to a std::string
+ */
+ std::string parse_jstring(JNIEnv* env, jstring java_string) {
+ const char* utf_chars = env->GetStringUTFChars(java_string, nullptr);
+ if (utf_chars == nullptr) {
+ return "";
+ }
+
+ std::string result(utf_chars);
+ env->ReleaseStringUTFChars(java_string, utf_chars);
+
+ return result;
+ }
+
+  char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const jsize length) {
+    auto *const result = static_cast<char **>(malloc(length * sizeof(char *)));
if (result == nullptr) {
- return nullptr;
+ return nullptr;
}
for (jsize i = 0; i < length; i++) {
- auto *const javaString = static_cast(env->GetObjectArrayElement(string_array, i));
- const char *cString = env->GetStringUTFChars(javaString, nullptr);
- result[i] = strdup(cString);
- env->ReleaseStringUTFChars(javaString, cString);
+      auto *const javaString = static_cast<jstring>(env->GetObjectArrayElement(string_array, i));
+      const char *cString = env->GetStringUTFChars(javaString, nullptr);
+      result[i] = strdup(cString);
+      env->ReleaseStringUTFChars(javaString, cString);
}
return result;
-}
+ }
-void free_string_array(char **array, jsize length) {
+ void free_string_array(char ** array, jsize length) {
if (array != nullptr) {
- for (jsize i = 0; i < length; i++) {
- free(array[i]);
- }
- free(array);
+ for (jsize i = 0; i < length; i++) {
+ free(array[i]);
+ }
+ free(array);
}
-}
-
-/**
- * Since Java expects utf16 but std::strings are utf8, we can't directly use `env->NewString` or `env-NewString`,
- * but we directly send the bytes and do the conversion in Java. Unfortunately, there isn't a nice/standardized way to
- * do this conversion in C++
- */
-jbyteArray parse_jbytes(JNIEnv *env, const std::string &string) {
+ }
+
+ /**
+ * Since Java expects utf16 but std::strings are utf8, we can't directly use `env->NewString` or `env-NewString`,
+ * but we directly send the bytes and do the conversion in Java. Unfortunately, there isn't a nice/standardized way to
+ * do this conversion in C++
+ */
+  jbyteArray parse_jbytes(JNIEnv *env, const std::string &string) {
jsize length = string.size(); // NOLINT(*-narrowing-conversions)
- jbyteArray bytes = env->NewByteArray(length);
- env->SetByteArrayRegion(bytes, 0, length, reinterpret_cast(string.c_str()));
+    jbyteArray bytes = env->NewByteArray(length);
+    env->SetByteArrayRegion(bytes, 0, length, reinterpret_cast<const jbyte *>(string.c_str()));
return bytes;
-}
+ }
-/**
- * Map a llama.cpp log level to its Java enumeration option.
- */
-jobject log_level_to_jobject(ggml_log_level level) {
+ /**
+ * Map a llama.cpp log level to its Java enumeration option.
+ */
+ jobject log_level_to_jobject(ggml_log_level level) {
switch (level) {
case GGML_LOG_LEVEL_ERROR:
- return o_log_level_error;
+ return o_log_level_error;
case GGML_LOG_LEVEL_WARN:
- return o_log_level_warn;
+ return o_log_level_warn;
default:
case GGML_LOG_LEVEL_INFO:
- return o_log_level_info;
+ return o_log_level_info;
case GGML_LOG_LEVEL_DEBUG:
- return o_log_level_debug;
+ return o_log_level_debug;
}
-}
-
-/**
- * Returns the JNIEnv of the current thread.
- */
-JNIEnv *get_jni_env() {
- JNIEnv *env = nullptr;
- if (g_vm == nullptr || g_vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) {
- throw std::runtime_error("Thread is not attached to the JVM");
+ }
+
+ /**
+ * Returns the JNIEnv of the current thread.
+ */
+  JNIEnv *get_jni_env() {
+    JNIEnv *env = nullptr;
+    if (g_vm == nullptr || g_vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
+ throw std::runtime_error("Thread is not attached to the JVM");
}
return env;
-}
-
-bool log_json;
-std::function log_callback;
-
+ }
+
+  bool log_json;
+  std::function<void(ggml_log_level, const char *, void *)> log_callback;
+
+ /**
+ * Format a log message as JSON
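+   * (illustrative output shape: {"timestamp":1742947200,"level":"INFO","message":"model loaded"})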
+ */
+ std::string format_log_as_json(ggml_log_level level, const char* text) {
+ std::string level_str;
+ switch (level) {
+ case GGML_LOG_LEVEL_ERROR: level_str = "ERROR"; break;
+ case GGML_LOG_LEVEL_WARN: level_str = "WARN"; break;
+ case GGML_LOG_LEVEL_INFO: level_str = "INFO"; break;
+ default:
+ case GGML_LOG_LEVEL_DEBUG: level_str = "DEBUG"; break;
+ }
+
+ // Create a JSON object with timestamp, level, and message
+ nlohmann::json log_json = {
+ {"timestamp", std::time(nullptr)},
+ {"level", level_str},
+ {"message", text}
+ };
+
+ return log_json.dump();
+ }
/**
* Invoke the log callback if there is any.
*/
-void log_callback_trampoline(ggml_log_level level, const char *text, void *user_data) {
- if (log_callback != nullptr) {
- log_callback(level, text, user_data);
- }
-}
+ void log_callback_trampoline(ggml_log_level level, const char* text, void* user_data) {
+ if (log_callback != nullptr) {
+ if (log_json) {
+ // Format the message as JSON before passing to callback
+ std::string json_text = format_log_as_json(level, text);
+ log_callback(level, json_text.c_str(), user_data);
+ } else {
+ // Pass the original text
+ log_callback(level, text, user_data);
+ }
+ }
+ }
} // namespace
/**
@@ -182,136 +218,133 @@ void log_callback_trampoline(ggml_log_level level, const char *text, void *user_
* only requires JNI version `JNI_VERSION_1_1`. If the VM does not recognize the version number returned by
`JNI_OnLoad`, the VM will unload the library and act as if the library was never loaded.
*/
-JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
- g_vm = vm;
- JNIEnv *env = nullptr;
-
- if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_1)) {
- goto error;
- }
-
- // find classes
- c_llama_model = env->FindClass("de/kherud/llama/LlamaModel");
- c_llama_iterator = env->FindClass("de/kherud/llama/LlamaIterator");
- c_standard_charsets = env->FindClass("java/nio/charset/StandardCharsets");
- c_output = env->FindClass("de/kherud/llama/LlamaOutput");
- c_string = env->FindClass("java/lang/String");
- c_hash_map = env->FindClass("java/util/HashMap");
- c_map = env->FindClass("java/util/Map");
- c_set = env->FindClass("java/util/Set");
- c_entry = env->FindClass("java/util/Map$Entry");
- c_iterator = env->FindClass("java/util/Iterator");
- c_integer = env->FindClass("java/lang/Integer");
- c_float = env->FindClass("java/lang/Float");
- c_biconsumer = env->FindClass("java/util/function/BiConsumer");
- c_llama_error = env->FindClass("de/kherud/llama/LlamaException");
- c_log_level = env->FindClass("de/kherud/llama/LogLevel");
- c_log_format = env->FindClass("de/kherud/llama/args/LogFormat");
- c_error_oom = env->FindClass("java/lang/OutOfMemoryError");
-
- if (!(c_llama_model && c_llama_iterator && c_standard_charsets && c_output && c_string && c_hash_map && c_map &&
- c_set && c_entry && c_iterator && c_integer && c_float && c_biconsumer && c_llama_error && c_log_level &&
- c_log_format && c_error_oom)) {
- goto error;
- }
-
- // create references
- c_llama_model = (jclass)env->NewGlobalRef(c_llama_model);
- c_llama_iterator = (jclass)env->NewGlobalRef(c_llama_iterator);
- c_output = (jclass)env->NewGlobalRef(c_output);
- c_string = (jclass)env->NewGlobalRef(c_string);
- c_hash_map = (jclass)env->NewGlobalRef(c_hash_map);
- c_map = (jclass)env->NewGlobalRef(c_map);
- c_set = (jclass)env->NewGlobalRef(c_set);
- c_entry = (jclass)env->NewGlobalRef(c_entry);
- c_iterator = (jclass)env->NewGlobalRef(c_iterator);
- c_integer = (jclass)env->NewGlobalRef(c_integer);
- c_float = (jclass)env->NewGlobalRef(c_float);
- c_biconsumer = (jclass)env->NewGlobalRef(c_biconsumer);
- c_llama_error = (jclass)env->NewGlobalRef(c_llama_error);
- c_log_level = (jclass)env->NewGlobalRef(c_log_level);
- c_log_format = (jclass)env->NewGlobalRef(c_log_format);
- c_error_oom = (jclass)env->NewGlobalRef(c_error_oom);
-
- // find constructors
-    cc_output = env->GetMethodID(c_output, "<init>", "([BLjava/util/Map;Z)V");
-    cc_hash_map = env->GetMethodID(c_hash_map, "<init>", "()V");
-    cc_integer = env->GetMethodID(c_integer, "<init>", "(I)V");
-    cc_float = env->GetMethodID(c_float, "<init>", "(F)V");
-
- if (!(cc_output && cc_hash_map && cc_integer && cc_float)) {
- goto error;
- }
-
- // find methods
- m_get_bytes = env->GetMethodID(c_string, "getBytes", "(Ljava/lang/String;)[B");
- m_entry_set = env->GetMethodID(c_map, "entrySet", "()Ljava/util/Set;");
- m_set_iterator = env->GetMethodID(c_set, "iterator", "()Ljava/util/Iterator;");
- m_iterator_has_next = env->GetMethodID(c_iterator, "hasNext", "()Z");
- m_iterator_next = env->GetMethodID(c_iterator, "next", "()Ljava/lang/Object;");
- m_entry_key = env->GetMethodID(c_entry, "getKey", "()Ljava/lang/Object;");
- m_entry_value = env->GetMethodID(c_entry, "getValue", "()Ljava/lang/Object;");
- m_map_put = env->GetMethodID(c_map, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
- m_int_value = env->GetMethodID(c_integer, "intValue", "()I");
- m_float_value = env->GetMethodID(c_float, "floatValue", "()F");
- m_biconsumer_accept = env->GetMethodID(c_biconsumer, "accept", "(Ljava/lang/Object;Ljava/lang/Object;)V");
-
- if (!(m_get_bytes && m_entry_set && m_set_iterator && m_iterator_has_next && m_iterator_next && m_entry_key &&
- m_entry_value && m_map_put && m_int_value && m_float_value && m_biconsumer_accept)) {
- goto error;
- }
-
- // find fields
- f_model_pointer = env->GetFieldID(c_llama_model, "ctx", "J");
- f_task_id = env->GetFieldID(c_llama_iterator, "taskId", "I");
- f_utf_8 = env->GetStaticFieldID(c_standard_charsets, "UTF_8", "Ljava/nio/charset/Charset;");
- f_iter_has_next = env->GetFieldID(c_llama_iterator, "hasNext", "Z");
- f_log_level_debug = env->GetStaticFieldID(c_log_level, "DEBUG", "Lde/kherud/llama/LogLevel;");
- f_log_level_info = env->GetStaticFieldID(c_log_level, "INFO", "Lde/kherud/llama/LogLevel;");
- f_log_level_warn = env->GetStaticFieldID(c_log_level, "WARN", "Lde/kherud/llama/LogLevel;");
- f_log_level_error = env->GetStaticFieldID(c_log_level, "ERROR", "Lde/kherud/llama/LogLevel;");
- f_log_format_json = env->GetStaticFieldID(c_log_format, "JSON", "Lde/kherud/llama/args/LogFormat;");
- f_log_format_text = env->GetStaticFieldID(c_log_format, "TEXT", "Lde/kherud/llama/args/LogFormat;");
-
- if (!(f_model_pointer && f_task_id && f_utf_8 && f_iter_has_next && f_log_level_debug && f_log_level_info &&
- f_log_level_warn && f_log_level_error && f_log_format_json && f_log_format_text)) {
- goto error;
- }
-
- o_utf_8 = env->NewStringUTF("UTF-8");
- o_log_level_debug = env->GetStaticObjectField(c_log_level, f_log_level_debug);
- o_log_level_info = env->GetStaticObjectField(c_log_level, f_log_level_info);
- o_log_level_warn = env->GetStaticObjectField(c_log_level, f_log_level_warn);
- o_log_level_error = env->GetStaticObjectField(c_log_level, f_log_level_error);
- o_log_format_json = env->GetStaticObjectField(c_log_format, f_log_format_json);
- o_log_format_text = env->GetStaticObjectField(c_log_format, f_log_format_text);
-
- if (!(o_utf_8 && o_log_level_debug && o_log_level_info && o_log_level_warn && o_log_level_error &&
- o_log_format_json && o_log_format_text)) {
- goto error;
- }
-
- o_utf_8 = env->NewGlobalRef(o_utf_8);
- o_log_level_debug = env->NewGlobalRef(o_log_level_debug);
- o_log_level_info = env->NewGlobalRef(o_log_level_info);
- o_log_level_warn = env->NewGlobalRef(o_log_level_warn);
- o_log_level_error = env->NewGlobalRef(o_log_level_error);
- o_log_format_json = env->NewGlobalRef(o_log_format_json);
- o_log_format_text = env->NewGlobalRef(o_log_format_text);
-
- if (env->ExceptionCheck()) {
- env->ExceptionDescribe();
- goto error;
- }
-
- llama_backend_init();
-
- goto success;
-
-error:
+JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
+  g_vm = vm;
+  JNIEnv *env = nullptr;
+
+  if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_1)) {
+ goto error;
+ }
+
+ // find classes
+ c_charset_class = env->FindClass("java/nio/charset/Charset");
+ c_llama_model = env -> FindClass("de/kherud/llama/LlamaModel");
+ c_standard_charsets = env -> FindClass("java/nio/charset/StandardCharsets");
+ c_string = env -> FindClass("java/lang/String");
+ c_hash_map = env -> FindClass("java/util/HashMap");
+ c_map = env -> FindClass("java/util/Map");
+ c_set = env -> FindClass("java/util/Set");
+ c_entry = env -> FindClass("java/util/Map$Entry");
+ c_iterator = env -> FindClass("java/util/Iterator");
+ c_integer = env -> FindClass("java/lang/Integer");
+ c_float = env -> FindClass("java/lang/Float");
+ c_biconsumer = env -> FindClass("java/util/function/BiConsumer");
+ c_llama_error = env -> FindClass("de/kherud/llama/LlamaException");
+ c_log_level = env -> FindClass("de/kherud/llama/LogLevel");
+ c_log_format = env -> FindClass("de/kherud/llama/args/LogFormat");
+ c_error_oom = env -> FindClass("java/lang/OutOfMemoryError");
+
+
+ if (!(c_llama_model && c_charset_class && c_standard_charsets && c_string && c_hash_map && c_map &&
+ c_set && c_entry && c_iterator && c_integer && c_float && c_biconsumer && c_llama_error && c_log_level &&
+ c_log_format && c_error_oom)) {
+ goto error;
+ }
+
+ // create references
+ c_charset_class = (jclass) env -> NewGlobalRef(c_charset_class);
+ c_llama_model = (jclass) env -> NewGlobalRef(c_llama_model);
+ c_string = (jclass) env -> NewGlobalRef(c_string);
+ c_hash_map = (jclass) env -> NewGlobalRef(c_hash_map);
+ c_map = (jclass) env -> NewGlobalRef(c_map);
+ c_set = (jclass) env -> NewGlobalRef(c_set);
+ c_entry = (jclass) env -> NewGlobalRef(c_entry);
+ c_iterator = (jclass) env -> NewGlobalRef(c_iterator);
+ c_integer = (jclass) env -> NewGlobalRef(c_integer);
+ c_float = (jclass) env -> NewGlobalRef(c_float);
+ c_biconsumer = (jclass) env -> NewGlobalRef(c_biconsumer);
+ c_llama_error = (jclass) env -> NewGlobalRef(c_llama_error);
+ c_log_level = (jclass) env -> NewGlobalRef(c_log_level);
+ c_log_format = (jclass) env -> NewGlobalRef(c_log_format);
+ c_error_oom = (jclass) env -> NewGlobalRef(c_error_oom);
+
+ // find constructors
+  cc_hash_map = env -> GetMethodID(c_hash_map, "<init>", "()V");
+  cc_integer = env -> GetMethodID(c_integer, "<init>", "(I)V");
+  cc_float = env -> GetMethodID(c_float, "<init>", "(F)V");
+
+ if (!(cc_hash_map && cc_integer && cc_float)) {
+ goto error;
+ }
+
+ // find methods
+ m_get_bytes = env -> GetMethodID(c_string, "getBytes", "(Ljava/lang/String;)[B");
+ m_entry_set = env -> GetMethodID(c_map, "entrySet", "()Ljava/util/Set;");
+ m_set_iterator = env -> GetMethodID(c_set, "iterator", "()Ljava/util/Iterator;");
+ m_iterator_has_next = env -> GetMethodID(c_iterator, "hasNext", "()Z");
+ m_iterator_next = env -> GetMethodID(c_iterator, "next", "()Ljava/lang/Object;");
+ m_entry_key = env -> GetMethodID(c_entry, "getKey", "()Ljava/lang/Object;");
+ m_entry_value = env -> GetMethodID(c_entry, "getValue", "()Ljava/lang/Object;");
+ m_map_put = env -> GetMethodID(c_map, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
+ m_int_value = env -> GetMethodID(c_integer, "intValue", "()I");
+ m_float_value = env -> GetMethodID(c_float, "floatValue", "()F");
+ m_biconsumer_accept = env -> GetMethodID(c_biconsumer, "accept", "(Ljava/lang/Object;Ljava/lang/Object;)V");
+ m_forname = env->GetStaticMethodID(c_charset_class, "forName", "(Ljava/lang/String;)Ljava/nio/charset/Charset;");
+
+ if (!(m_get_bytes && m_entry_set && m_set_iterator && m_iterator_has_next && m_iterator_next && m_entry_key &&
+ m_entry_value && m_map_put && m_int_value && m_float_value && m_biconsumer_accept && m_forname)) {
+ goto error;
+ }
+
+ // find fields
+ f_model_pointer = env -> GetFieldID(c_llama_model, "ctx", "J");
+ f_utf_8 = env -> GetStaticFieldID(c_standard_charsets, "UTF_8", "Ljava/nio/charset/Charset;");
+ f_log_level_debug = env -> GetStaticFieldID(c_log_level, "DEBUG", "Lde/kherud/llama/LogLevel;");
+ f_log_level_info = env -> GetStaticFieldID(c_log_level, "INFO", "Lde/kherud/llama/LogLevel;");
+ f_log_level_warn = env -> GetStaticFieldID(c_log_level, "WARN", "Lde/kherud/llama/LogLevel;");
+ f_log_level_error = env -> GetStaticFieldID(c_log_level, "ERROR", "Lde/kherud/llama/LogLevel;");
+ f_log_format_json = env -> GetStaticFieldID(c_log_format, "JSON", "Lde/kherud/llama/args/LogFormat;");
+ f_log_format_text = env -> GetStaticFieldID(c_log_format, "TEXT", "Lde/kherud/llama/args/LogFormat;");
+
+ if (!(f_model_pointer && f_utf_8 && f_log_level_debug && f_log_level_info &&
+ f_log_level_warn && f_log_level_error && f_log_format_json && f_log_format_text)) {
+ goto error;
+ }
+
+ o_utf_8 = env -> NewStringUTF("UTF-8");
+ o_log_level_debug = env -> GetStaticObjectField(c_log_level, f_log_level_debug);
+ o_log_level_info = env -> GetStaticObjectField(c_log_level, f_log_level_info);
+ o_log_level_warn = env -> GetStaticObjectField(c_log_level, f_log_level_warn);
+ o_log_level_error = env -> GetStaticObjectField(c_log_level, f_log_level_error);
+ o_log_format_json = env -> GetStaticObjectField(c_log_format, f_log_format_json);
+ o_log_format_text = env -> GetStaticObjectField(c_log_format, f_log_format_text);
+
+ if (!(o_utf_8 && o_log_level_debug && o_log_level_info && o_log_level_warn && o_log_level_error &&
+ o_log_format_json && o_log_format_text)) {
+ goto error;
+ }
+
+ o_utf_8 = env -> NewGlobalRef(o_utf_8);
+ o_log_level_debug = env -> NewGlobalRef(o_log_level_debug);
+ o_log_level_info = env -> NewGlobalRef(o_log_level_info);
+ o_log_level_warn = env -> NewGlobalRef(o_log_level_warn);
+ o_log_level_error = env -> NewGlobalRef(o_log_level_error);
+ o_log_format_json = env -> NewGlobalRef(o_log_format_json);
+ o_log_format_text = env -> NewGlobalRef(o_log_format_text);
+
+ if (env -> ExceptionCheck()) {
+ env -> ExceptionDescribe();
+ goto error;
+ }
+
+ llama_backend_init();
+
+ goto success;
+
+ error:
return JNI_ERR;
-success:
+ success:
return JNI_VERSION_1_6;
}
@@ -323,57 +356,62 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
* Note that `JNI_OnLoad` and `JNI_OnUnload` are two functions optionally supplied by JNI libraries, not exported from
* the VM.
*/
-JNIEXPORT void JNICALL JNI_OnUnload(JavaVM *vm, void *reserved) {
- JNIEnv *env = nullptr;
-
- if (JNI_OK != vm->GetEnv((void **)&env, JNI_VERSION_1_6)) {
- return;
- }
-
- env->DeleteGlobalRef(c_llama_model);
- env->DeleteGlobalRef(c_llama_iterator);
- env->DeleteGlobalRef(c_output);
- env->DeleteGlobalRef(c_string);
- env->DeleteGlobalRef(c_hash_map);
- env->DeleteGlobalRef(c_map);
- env->DeleteGlobalRef(c_set);
- env->DeleteGlobalRef(c_entry);
- env->DeleteGlobalRef(c_iterator);
- env->DeleteGlobalRef(c_integer);
- env->DeleteGlobalRef(c_float);
- env->DeleteGlobalRef(c_biconsumer);
- env->DeleteGlobalRef(c_llama_error);
- env->DeleteGlobalRef(c_log_level);
- env->DeleteGlobalRef(c_log_level);
- env->DeleteGlobalRef(c_error_oom);
-
- env->DeleteGlobalRef(o_utf_8);
- env->DeleteGlobalRef(o_log_level_debug);
- env->DeleteGlobalRef(o_log_level_info);
- env->DeleteGlobalRef(o_log_level_warn);
- env->DeleteGlobalRef(o_log_level_error);
- env->DeleteGlobalRef(o_log_format_json);
- env->DeleteGlobalRef(o_log_format_text);
-
- if (o_log_callback != nullptr) {
- env->DeleteGlobalRef(o_log_callback);
- }
-
- llama_backend_free();
+JNIEXPORT void JNICALL JNI_OnUnload(JavaVM * vm, void * reserved) {
+ JNIEnv * env = nullptr;
+
+ if (JNI_OK != vm -> GetEnv((void ** ) & env, JNI_VERSION_1_6)) {
+ return;
+ }
+
+ env -> DeleteGlobalRef(c_llama_model);
+ env -> DeleteGlobalRef(c_charset_class);
+ env -> DeleteGlobalRef(c_string);
+ env -> DeleteGlobalRef(c_hash_map);
+ env -> DeleteGlobalRef(c_map);
+ env -> DeleteGlobalRef(c_set);
+ env -> DeleteGlobalRef(c_entry);
+ env -> DeleteGlobalRef(c_iterator);
+ env -> DeleteGlobalRef(c_integer);
+ env -> DeleteGlobalRef(c_float);
+ env -> DeleteGlobalRef(c_biconsumer);
+ env -> DeleteGlobalRef(c_llama_error);
+ env -> DeleteGlobalRef(c_log_level);
+  env -> DeleteGlobalRef(c_log_format);
+ env -> DeleteGlobalRef(c_error_oom);
+
+ env -> DeleteGlobalRef(o_utf_8);
+ env -> DeleteGlobalRef(o_log_level_debug);
+ env -> DeleteGlobalRef(o_log_level_info);
+ env -> DeleteGlobalRef(o_log_level_warn);
+ env -> DeleteGlobalRef(o_log_level_error);
+ env -> DeleteGlobalRef(o_log_format_json);
+ env -> DeleteGlobalRef(o_log_format_text);
+
+ if (o_log_callback != nullptr) {
+ env -> DeleteGlobalRef(o_log_callback);
+ }
+
+ llama_backend_free();
}
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jobject obj, jobjectArray jparams) {
+/**
+ * Load a model with the given parameters.
+ * This function initializes the server context and loads the language model.
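+ * The jparams array is parsed with common_params_parse using the llama.cpp server argument set
+ * (LLAMA_EXAMPLE_SERVER), so the usual server flags such as "--model <path>" are accepted as-is.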
+ */
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv* env, jobject obj, jobjectArray jparams) {
common_params params;
const jsize argc = env->GetArrayLength(jparams);
- char **argv = parse_string_array(env, jparams, argc);
+ char** argv = parse_string_array(env, jparams, argc);
if (argv == nullptr) {
+ env->ThrowNew(c_error_oom, "Failed to allocate memory for parameters");
return;
}
const auto parsed_params = common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER);
free_string_array(argv, argc);
if (!parsed_params) {
+ env->ThrowNew(c_llama_error, "Failed to parse parameters");
return;
}
@@ -381,38 +419,43 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
common_init();
- // struct that contains llama context and inference
- auto *ctx_server = new server_context();
+ // Create server context structure that contains llama context and inference
+ auto* ctx_server = new server_context();
+ // Initialize NUMA if configured
llama_numa_init(params.numa);
- LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads,
- params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
+ // Log system information
+ LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n",
+ params.cpuparams.n_threads, params.cpuparams_batch.n_threads,
+ std::thread::hardware_concurrency());
LOG_INF("\n");
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
LOG_INF("\n");
+ // Initialize server state
std::atomic state{SERVER_STATE_LOADING_MODEL};
- // Necessary similarity of prompt for slot selection
+ // Set prompt similarity threshold for slot selection
ctx_server->slot_prompt_similarity = params.slot_prompt_similarity;
LOG_INF("%s: loading model\n", __func__);
- // load the model
+ // Load the model
if (!ctx_server->load_model(params)) {
+ delete ctx_server;
llama_backend_free();
- env->ThrowNew(c_llama_error, "could not load model from given file path");
+ env->ThrowNew(c_llama_error, "Could not load model from given file path");
return;
}
+ // Initialize the server context
ctx_server->init();
state.store(SERVER_STATE_READY);
LOG_INF("%s: model loaded\n", __func__);
- const auto model_meta = ctx_server->model_meta();
-
+ // Load draft model if configured (for speculative decoding)
if (!params.speculative.model.empty() || !params.speculative.hf_repo.empty()) {
SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str());
auto params_dft = params;
@@ -427,61 +470,55 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
params_dft.n_parallel = 1;
common_init_result llama_init_dft = common_init_from_params(params_dft);
-
- llama_model *model_dft = llama_init_dft.model.get();
+ llama_model* model_dft = llama_init_dft.model.get();
if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.c_str());
+ } else {
+ if (!common_speculative_are_compatible(ctx_server->ctx, llama_init_dft.context.get())) {
+ SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n",
+ params.speculative.model.c_str(), params.model.c_str());
+ } else {
+ const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
+ ctx_server->cparams_dft = common_context_params_to_llama(params_dft);
+ ctx_server->cparams_dft.n_batch = n_ctx_dft;
+
+ // force F16 KV cache for the draft model for extra performance
+ ctx_server->cparams_dft.type_k = GGML_TYPE_F16;
+ ctx_server->cparams_dft.type_v = GGML_TYPE_F16;
+
+ // the context is not needed - we will create one for each slot
+ llama_init_dft.context.reset();
+ }
}
-
- if (!common_speculative_are_compatible(ctx_server->ctx, llama_init_dft.context.get())) {
- SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n",
- params.speculative.model.c_str(), params.model.c_str());
- }
-
- const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
-
- ctx_server->cparams_dft = common_context_params_to_llama(params_dft);
- ctx_server->cparams_dft.n_batch = n_ctx_dft;
-
- // force F16 KV cache for the draft model for extra performance
- ctx_server->cparams_dft.type_k = GGML_TYPE_F16;
- ctx_server->cparams_dft.type_v = GGML_TYPE_F16;
-
- // the context is not needed - we will create one for each slot
- llama_init_dft.context.reset();
}
+ // Initialize chat templates
ctx_server->chat_templates = common_chat_templates_init(ctx_server->model, params.chat_template);
try {
common_chat_format_example(ctx_server->chat_templates.get(), params.use_jinja);
- } catch (const std::exception &e) {
+ } catch (const std::exception& e) {
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This "
- "may cause the model to output suboptimal responses\n",
- __func__);
+ "may cause the model to output suboptimal responses\n", __func__);
ctx_server->chat_templates = common_chat_templates_init(ctx_server->model, "chatml");
}
- // print sample chat example to make it clear which template is used
+ // Print sample chat example to make it clear which template is used
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
- common_chat_templates_source(ctx_server->chat_templates.get()),
- common_chat_format_example(ctx_server->chat_templates.get(), ctx_server->params_base.use_jinja).c_str());
-
- // print sample chat example to make it clear which template is used
- // LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
- // common_chat_templates_source(ctx_server->chat_templates.get()),
- // common_chat_format_example(*ctx_server->chat_templates.template_default,
- // ctx_server->params_base.use_jinja) .c_str());
+ common_chat_templates_source(ctx_server->chat_templates.get()),
+ common_chat_format_example(ctx_server->chat_templates.get(), ctx_server->params_base.use_jinja).c_str());
+ // Set up task handlers
ctx_server->queue_tasks.on_new_task(
std::bind(&server_context::process_single_task, ctx_server, std::placeholders::_1));
ctx_server->queue_tasks.on_update_slots(std::bind(&server_context::update_slots, ctx_server));
+ // Start task processing thread
std::thread t([ctx_server]() {
- JNIEnv *env;
- jint res = g_vm->GetEnv((void **)&env, JNI_VERSION_1_6);
+ JNIEnv* env;
+ jint res = g_vm->GetEnv((void**)&env, JNI_VERSION_1_6);
if (res == JNI_EDETACHED) {
- res = g_vm->AttachCurrentThread((void **)&env, nullptr);
+ res = g_vm->AttachCurrentThread((void**)&env, nullptr);
if (res != JNI_OK) {
throw std::runtime_error("Failed to attach thread to JVM");
}
@@ -490,374 +527,1913 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
});
t.detach();
+ // Store server context pointer in Java object
     env->SetLongField(obj, f_model_pointer, reinterpret_cast<jlong>(ctx_server));
}
-JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv *env, jobject obj, jstring jparams) {
- jlong server_handle = env->GetLongField(obj, f_model_pointer);
- auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr)
-
- std::string c_params = parse_jstring(env, jparams);
- json data = json::parse(c_params);
+/**
+ * Clean up resources and delete the model.
+ * This function shuts down the server context and frees memory.
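+ * It is safe to call when the model pointer is already zero; in that case the method returns
+ * without doing any work.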
+ */
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv* env, jobject obj) {
+ try {
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ return; // Already deleted or not initialized
+ }
- server_task_type type = SERVER_TASK_TYPE_COMPLETION;
+        auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
+
+ // Log shutdown
+ SRV_INF("%s: cleaning up before exit...\n", __func__);
+
+ // Cancel all pending tasks
+ ctx_server->queue_tasks.terminate();
+
+ // Wait for a brief moment to allow tasks to complete
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+ // Delete the server context
+ delete ctx_server;
+
+ // Clear the pointer in Java
+ env->SetLongField(obj, f_model_pointer, 0);
+
+ SRV_INF("%s: cleanup complete\n", __func__);
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception during shutdown: %s\n", e.what());
+ // We don't throw here, as this would prevent proper cleanup during JVM shutdown
+ }
+}
- if (data.contains("input_prefix") || data.contains("input_suffix")) {
- type = SERVER_TASK_TYPE_INFILL;
+/**
+ * Set a logger for llama.cpp logs.
+ * This function configures the logging system to forward messages to Java.
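+ * On the Java side the callback is consumed as a BiConsumer of (LogLevel, String); an illustrative
+ * call is LlamaModel.setLogger(LogFormat.JSON, (level, msg) -> System.out.println(msg)).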
+ */
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv* env, jclass clazz, jobject log_format, jobject jcallback) {
+ if (o_log_callback != nullptr) {
+ env->DeleteGlobalRef(o_log_callback);
+ o_log_callback = nullptr;
}
- auto completion_id = gen_chatcmplid();
- std::vector tasks;
+ log_json = env->IsSameObject(log_format, o_log_format_json);
+
+ if (jcallback == nullptr) {
+ // Disable logging if callback is null
+ log_callback = nullptr;
+ llama_log_set(nullptr, nullptr);
+ } else {
+ // Store a global reference to the callback object
+ o_log_callback = env->NewGlobalRef(jcallback);
+
+ // Create a C++ callback function that forwards to Java
+ log_callback = [](enum ggml_log_level level, const char* text, void* user_data) {
+ JNIEnv* env = get_jni_env();
+ jstring message = env->NewStringUTF(text);
+ jobject log_level = log_level_to_jobject(level);
+ env->CallVoidMethod(o_log_callback, m_biconsumer_accept, log_level, message);
+ env->DeleteLocalRef(message);
+ };
+
+ // Always set the logger, regardless of JSON format
+ llama_log_set(log_callback_trampoline, nullptr);
+
+ // For debugging, send an initial log message
+ LOG_INF("Logger initialized (JSON format: %s)\n", log_json ? "true" : "false");
+
+ }
+}
+/**
+ * Handle standard completions request.
+ * Equivalent to POST /completions endpoint.
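+ * The request body is a JSON object following the llama.cpp server conventions; "prompt" is
+ * required, e.g. (illustrative) {"prompt": "Hello", "n_predict": 16}.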
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleCompletions(JNIEnv* env, jobject obj, jstring jrequestData, jboolean jstream) {
try {
- const auto &prompt = data.at("prompt");
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return nullptr;
+ }
- std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true);
+        auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
- tasks.reserve(tokenized_prompts.size());
- for (size_t i = 0; i < tokenized_prompts.size(); i++) {
- server_task task = server_task(type);
+ // Check if embeddings mode is active (which would prevent completions)
+ if (ctx_server->params_base.embedding) {
+ env->ThrowNew(c_llama_error, "This server does not support completions. Start it without `--embeddings`");
+ return nullptr;
+ }
- task.id = ctx_server->queue_tasks.get_new_id();
- task.index = i;
+ // Parse request data from JSON
+ std::string request_str = parse_jstring(env, jrequestData);
+ json data = json::parse(request_str);
+
+ // Set streaming flag
+ bool stream = jstream;
+ data["stream"] = stream;
+
+ // Create a completion ID
+ auto completion_id = gen_chatcmplid();
+        std::vector<server_task> tasks;
+
+ try {
+ // Extract prompt from request data
+ const auto& prompt = data.at("prompt");
+
+ // Tokenize prompt
+ std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true);
+
+ // Create tasks for each tokenized prompt
+ tasks.reserve(tokenized_prompts.size());
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+ server_task task(SERVER_TASK_TYPE_COMPLETION);
+
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.index = i;
+
+ task.prompt_tokens = std::move(tokenized_prompts[i]);
+ task.params = server_task::params_from_json_cmpl(
+ ctx_server->ctx, ctx_server->params_base, data);
+
+ task.id_selected_slot = json_value(data, "id_slot", -1);
+
+ // Set completion ID (but not OAI compatibility for standard completion)
+ task.params.oaicompat = OAICOMPAT_TYPE_NONE;
+ task.params.oaicompat_cmpl_id = completion_id;
+
+ tasks.push_back(task);
+ }
+ } catch (const std::exception& e) {
+ const auto& err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST);
+ env->ThrowNew(c_llama_error, err.dump().c_str());
+ return nullptr;
+ }
- task.prompt_tokens = std::move(tokenized_prompts[i]);
- task.params = server_task::params_from_json_cmpl(ctx_server->ctx, ctx_server->params_base, data);
- task.id_selected_slot = json_value(data, "id_slot", -1);
+ // Add tasks to waiting queue and post them for processing
+ ctx_server->queue_results.add_waiting_tasks(tasks);
+ ctx_server->queue_tasks.post(tasks);
- // OAI-compat
- task.params.oaicompat = OAICOMPAT_TYPE_NONE;
- task.params.oaicompat_cmpl_id = completion_id;
- // oaicompat_model is already populated by params_from_json_cmpl
+ // Get task IDs
+ const auto task_ids = server_task::get_list_id(tasks);
- tasks.push_back(task);
- }
- } catch (const std::exception &e) {
- const auto &err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST);
- env->ThrowNew(c_llama_error, err.dump().c_str());
- return 0;
- }
+ // Create response JSON
+ json response;
- ctx_server->queue_results.add_waiting_tasks(tasks);
- ctx_server->queue_tasks.post(tasks);
+ if (!stream) {
+ // For non-streaming, collect all results
+            std::vector<server_task_result_ptr> results;
+ results.reserve(tasks.size());
- const auto task_ids = server_task::get_list_id(tasks);
+ for (size_t i = 0; i < tasks.size(); i++) {
+ server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
- if (task_ids.size() != 1) {
- env->ThrowNew(c_llama_error, "multitasking currently not supported");
- return 0;
- }
+ if (result->is_error()) {
+ // Clean up and throw error
+ ctx_server->queue_results.remove_waiting_task_ids(task_ids);
+                    std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
- return *task_ids.begin();
-}
+ results.push_back(std::move(result));
+ }
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *env, jobject obj, jint id_task) {
- jlong server_handle = env->GetLongField(obj, f_model_pointer);
- auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr)
- ctx_server->queue_results.remove_waiting_task_id(id_task);
-}
+ // Format the response
+ response["type"] = "completion";
+ response["streaming"] = false;
+ response["completion_id"] = completion_id;
+
+ if (results.size() == 1) {
+ // Single result - preserve all the data including token probabilities
+ auto result_json = results[0]->to_json();
+
+ // Check if this is a final completion result that might have probabilities
+                auto* cmpl_final = dynamic_cast<server_task_result_cmpl_final*>(results[0].get());
+
+ if (cmpl_final != nullptr && !cmpl_final->probs_output.empty() && cmpl_final->post_sampling_probs) {
+ // Make sure the token probabilities are included
+ result_json["completion_probabilities"] =
+ completion_token_output::probs_vector_to_json(cmpl_final->probs_output,
+ cmpl_final->post_sampling_probs);
+ }
+
+ response["result"] = result_json;
+ } else {
+ // Multiple results
+ json results_array = json::array();
+ for (auto& res: results) {
+ auto result_json = res->to_json();
+
+ // Check for token probabilities in each result
+                    auto* cmpl_final = dynamic_cast<server_task_result_cmpl_final*>(res.get());
+
+ if (cmpl_final != nullptr && !cmpl_final->probs_output.empty() && cmpl_final->post_sampling_probs) {
+ // Make sure the token probabilities are included
+ result_json["completion_probabilities"] =
+ completion_token_output::probs_vector_to_json(cmpl_final->probs_output,
+ cmpl_final->post_sampling_probs);
+ }
+
+ results_array.push_back(result_json);
+ }
+ response["results"] = results_array;
+ }
+
+ // Clean up
+ ctx_server->queue_results.remove_waiting_task_ids(task_ids);
+ } else {
+ // For streaming, return the task IDs
+ response["type"] = "stream_init";
+ response["streaming"] = true;
+ response["completion_id"] = completion_id;
+
+ // Convert set to array
+ json task_ids_array = json::array();
+ for (const auto& id: task_ids) {
+ task_ids_array.push_back(id);
+ }
+ response["task_ids"] = task_ids_array;
-JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIEnv *env, jobject obj, jint id_task) {
- jlong server_handle = env->GetLongField(obj, f_model_pointer);
- auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr)
+ SRV_INF("Started streaming completion with %zu task(s)\n", task_ids.size());
+ }
- server_task_result_ptr result = ctx_server->queue_results.recv(id_task);
+ // Return the response as a JSON string
+ std::string response_str = response.dump();
+ jstring result = env->NewStringUTF(response_str.c_str());
- if (result->is_error()) {
- std::string response = result->to_json()["message"].get();
- ctx_server->queue_results.remove_waiting_task_id(id_task);
- env->ThrowNew(c_llama_error, response.c_str());
+ return result;
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in handleCompletions: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
return nullptr;
}
- const auto out_res = result->to_json();
+}
- std::string response = out_res["content"].get();
- if (result->is_stop()) {
- ctx_server->queue_results.remove_waiting_task_id(id_task);
- }
+/**
+ * Handle OpenAI compatible completions request.
+ * Equivalent to POST /v1/completions endpoint.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleCompletionsOai(JNIEnv* env, jobject obj, jstring jrequestData, jboolean jstream) {
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return nullptr;
+ }
- jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map);
- if (out_res.contains("completion_probabilities")) {
- auto completion_probabilities = out_res["completion_probabilities"];
- for (const auto &entry : completion_probabilities) {
- auto probs = entry["probs"];
- for (const auto &tp : probs) {
- std::string tok_str = tp["tok_str"];
- jstring jtok_str = env->NewStringUTF(tok_str.c_str());
- float prob = tp["prob"];
- jobject jprob = env->NewObject(c_float, cc_float, prob);
- env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob);
- env->DeleteLocalRef(jtok_str);
- env->DeleteLocalRef(jprob);
+        auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
+
+ // Check if embeddings mode is active (which would prevent completions)
+ if (ctx_server->params_base.embedding) {
+ env->ThrowNew(c_llama_error, "This server does not support completions. Start it without `--embeddings`");
+ return nullptr;
+ }
+
+ // Parse request data from JSON
+ std::string request_str = parse_jstring(env, jrequestData);
+ json body = json::parse(request_str);
+
+ // Set streaming flag
+ bool stream = jstream;
+ body["stream"] = stream;
+
+ // Parse the OpenAI-compatible parameters
+ json data = oaicompat_completion_params_parse(body);
+
+ // Create a completion ID
+ auto completion_id = gen_chatcmplid();
+        std::vector<server_task> tasks;
+
+ try {
+ // Extract prompt from request data
+ const auto& prompt = data.at("prompt");
+
+ // Tokenize prompt
+ std::vector tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true);
+
+ // Create tasks for each tokenized prompt
+ tasks.reserve(tokenized_prompts.size());
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+ server_task task(SERVER_TASK_TYPE_COMPLETION);
+
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.index = i;
+
+ task.prompt_tokens = std::move(tokenized_prompts[i]);
+ task.params = server_task::params_from_json_cmpl(
+ ctx_server->ctx, ctx_server->params_base, data);
+
+ task.id_selected_slot = json_value(data, "id_slot", -1);
+
+ // Set OAI compatibility mode
+ task.params.oaicompat = OAICOMPAT_TYPE_COMPLETION;
+ task.params.oaicompat_cmpl_id = completion_id;
+
+ tasks.push_back(task);
}
+ } catch (const std::exception& e) {
+ const auto& err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST);
+ env->ThrowNew(c_llama_error, err.dump().c_str());
+ return nullptr;
+ }
+
+ // Add tasks to waiting queue and post them for processing
+ ctx_server->queue_results.add_waiting_tasks(tasks);
+ ctx_server->queue_tasks.post(tasks);
+
+ // Get task IDs
+ const auto task_ids = server_task::get_list_id(tasks);
+
+ // Create response JSON
+ json response;
+
+ if (!stream) {
+ // For non-streaming, collect all results
+            std::vector<server_task_result_ptr> results;
+ results.reserve(tasks.size());
+
+ for (size_t i = 0; i < tasks.size(); i++) {
+ server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
+
+ if (result->is_error()) {
+ // Clean up and throw error
+ ctx_server->queue_results.remove_waiting_task_ids(task_ids);
+                    std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
+
+ results.push_back(std::move(result));
+ }
+
+ // Format the response
+ response["type"] = "oai_completion";
+ response["streaming"] = false;
+ response["completion_id"] = completion_id;
+
+ if (results.size() == 1) {
+ // Single result
+ response["result"] = results[0]->to_json();
+ } else {
+ // Multiple results
+ json results_array = json::array();
+ for (auto& res: results) {
+ results_array.push_back(res->to_json());
+ }
+ response["results"] = results_array;
+ }
+
+ // Clean up
+ ctx_server->queue_results.remove_waiting_task_ids(task_ids);
+ } else {
+ // For streaming, return the task IDs
+ response["type"] = "oai_stream_init";
+ response["streaming"] = true;
+ response["completion_id"] = completion_id;
+
+ // Convert set to array
+ json task_ids_array = json::array();
+ for (const auto& id: task_ids) {
+ task_ids_array.push_back(id);
+ }
+ response["task_ids"] = task_ids_array;
+
+ SRV_INF("Started streaming OAI completion with %zu task(s)\n", task_ids.size());
}
- }
- jbyteArray jbytes = parse_jbytes(env, response);
- return env->NewObject(c_output, cc_output, jbytes, o_probabilities, result->is_stop());
-}
-JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring jprompt) {
- jlong server_handle = env->GetLongField(obj, f_model_pointer);
- auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr)
+ // Return the response as a JSON string
+ std::string response_str = response.dump();
+ jstring result = env->NewStringUTF(response_str.c_str());
- if (!ctx_server->params_base.embedding) {
- env->ThrowNew(c_llama_error,
- "model was not loaded with embedding support (see ModelParameters#setEmbedding(boolean))");
+ return result;
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in handleCompletionsOai: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
return nullptr;
}
+}
+
+/**
+ * Handle chat completions request.
+ * Equivalent to POST /chat/completions or POST /v1/chat/completions endpoints.
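+ * The body follows the OpenAI chat schema (a "messages" array plus sampling options) and is
+ * converted with oaicompat_completion_params_parse and the loaded chat template before tokenization.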
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleChatCompletions(JNIEnv* env, jobject obj, jstring jrequestData, jboolean jstream) {
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return nullptr;
+ }
+
+        auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
+
+ // Check if embeddings mode is active (which would prevent completions)
+ if (ctx_server->params_base.embedding) {
+ env->ThrowNew(c_llama_error, "This server does not support completions. Start it without `--embeddings`");
+ return nullptr;
+ }
+
+ // Parse request data from JSON
+ std::string request_str = parse_jstring(env, jrequestData);
+ json body = json::parse(request_str);
+
+ // Log debug information
+ LOG_DBG("Chat request: %s\n", request_str.c_str());
+
+ // Set streaming flag
+ bool stream = jstream;
+ body["stream"] = stream;
+
+ // Parse the OAI-compatible parameters with chat template application
+ json data = oaicompat_completion_params_parse(
+ body,
+ ctx_server->params_base.use_jinja,
+ ctx_server->params_base.reasoning_format,
+ ctx_server->chat_templates.get());
+
+ // Create a completion ID
+ auto completion_id = gen_chatcmplid();
+        std::vector<server_task> tasks;
+
+ try {
+ // Extract prompt from processed data
+ const auto& prompt = data.at("prompt");
+
+ // Tokenize prompt
+ std::vector tokenized_prompts = tokenize_input_prompts(
+ ctx_server->vocab, prompt, true, true);
+
+ // Create tasks for each tokenized prompt
+ tasks.reserve(tokenized_prompts.size());
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+ server_task task(SERVER_TASK_TYPE_COMPLETION);
+
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.index = i;
+
+ task.prompt_tokens = std::move(tokenized_prompts[i]);
+ task.params = server_task::params_from_json_cmpl(
+ ctx_server->ctx, ctx_server->params_base, data);
+
+ task.id_selected_slot = json_value(data, "id_slot", -1);
+
+ // Set OAI chat compatibility mode
+ task.params.oaicompat = OAICOMPAT_TYPE_CHAT;
+ task.params.oaicompat_cmpl_id = completion_id;
+
+ tasks.push_back(task);
+ }
+ } catch (const std::exception& e) {
+ const auto& err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST);
+ env->ThrowNew(c_llama_error, err.dump().c_str());
+ return nullptr;
+ }
- const std::string prompt = parse_jstring(env, jprompt);
+ // Add tasks to waiting queue and post them for processing
+ ctx_server->queue_results.add_waiting_tasks(tasks);
+ ctx_server->queue_tasks.post(tasks);
- SRV_INF("Calling embedding '%s'\n", prompt.c_str());
+ // Get task IDs
+ const auto task_ids = server_task::get_list_id(tasks);
- const auto tokens = tokenize_mixed(ctx_server->vocab, prompt, true, true);
- std::vector<server_task> tasks;
+ // Create response JSON
+ json response;
- server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+ if (!stream) {
+ // For non-streaming, collect all results
+ std::vector<server_task_result_ptr> results;
+ results.reserve(tasks.size());
- task.id = ctx_server->queue_tasks.get_new_id();
- task.index = 0;
- task.prompt_tokens = std::move(tokens);
+ for (size_t i = 0; i < tasks.size(); i++) {
+ server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
- // OAI-compat
- task.params.oaicompat = OAICOMPAT_TYPE_NONE;
+ if (result->is_error()) {
+ // Clean up and throw error
+ ctx_server->queue_results.remove_waiting_task_ids(task_ids);
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
- tasks.push_back(task);
+ results.push_back(std::move(result));
+ }
- ctx_server->queue_results.add_waiting_tasks(tasks);
- ctx_server->queue_tasks.post(tasks);
+ // Format the response
+ response["type"] = "oai_chat";
+ response["streaming"] = false;
+ response["completion_id"] = completion_id;
+
+ if (results.size() == 1) {
+ // Single result
+ response["result"] = results[0]->to_json();
+ } else {
+ // Multiple results
+ json results_array = json::array();
+ for (auto& res: results) {
+ results_array.push_back(res->to_json());
+ }
+ response["results"] = results_array;
+ }
- std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
- const auto id_task = *task_ids.begin();
- json responses = json::array();
+ // Clean up
+ ctx_server->queue_results.remove_waiting_task_ids(task_ids);
+ } else {
+ // For streaming, return the task IDs
+ response["type"] = "oai_chat_stream_init";
+ response["streaming"] = true;
+ response["completion_id"] = completion_id;
+
+ // Convert set to array
+ json task_ids_array = json::array();
+ for (const auto& id: task_ids) {
+ task_ids_array.push_back(id);
+ }
+ response["task_ids"] = task_ids_array;
- json error = nullptr;
+ SRV_INF("Started streaming OAI chat completion with %zu task(s)\n", task_ids.size());
+ }
- server_task_result_ptr result = ctx_server->queue_results.recv(id_task);
+ // Return the response as a JSON string
+ std::string response_str = response.dump();
+ jstring result = env->NewStringUTF(response_str.c_str());
- json response_str = result->to_json();
- if (result->is_error()) {
- std::string response = result->to_json()["message"].get<std::string>();
- ctx_server->queue_results.remove_waiting_task_id(id_task);
- env->ThrowNew(c_llama_error, response.c_str());
+ return result;
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in handleChatCompletions: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
return nullptr;
}
+}
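+
+// Illustrative request body for this handler. Only "id_slot" is read directly
+// by the code above; the "messages" layout is an assumption about the
+// OpenAI-style schema that oaicompat_completion_params_parse() accepts.
+//
+//   {"messages": [{"role": "user", "content": "Hello"}],
+//    "temperature": 0.7, "id_slot": -1}
+//
+// The response reuses the completion envelope, with "type" set to
+// "oai_chat" (non-streaming) or "oai_chat_stream_init" (streaming).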
- if (result->is_stop()) {
- ctx_server->queue_results.remove_waiting_task_id(id_task);
- }
+/**
+ * Handle text infill request (completing text with given prefix and suffix).
+ * Equivalent to POST /infill endpoint.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleInfill(JNIEnv* env, jobject obj, jstring jrequestData, jboolean jstream) {
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return nullptr;
+ }
- const auto out_res = result->to_json();
+ auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
- // Extract "embedding" as a vector of vectors (2D array)
- std::vector<std::vector<float>> embedding = out_res["embedding"].get<std::vector<std::vector<float>>>();
+ // Check if embeddings mode is active (which would prevent infill)
+ if (ctx_server->params_base.embedding) {
+ env->ThrowNew(c_llama_error, "This server does not support infill. Start it without `--embeddings`");
+ return nullptr;
+ }
- // Get total number of rows in the embedding
- jsize embedding_rows = embedding.size();
+ // Check model compatibility for infill
+ std::string err;
+ if (llama_vocab_fim_pre(ctx_server->vocab) == LLAMA_TOKEN_NULL) {
+ err += "prefix token is missing. ";
+ }
+ if (llama_vocab_fim_suf(ctx_server->vocab) == LLAMA_TOKEN_NULL) {
+ err += "suffix token is missing. ";
+ }
+ if (llama_vocab_fim_mid(ctx_server->vocab) == LLAMA_TOKEN_NULL) {
+ err += "middle token is missing. ";
+ }
+ if (!err.empty()) {
+ env->ThrowNew(c_llama_error, ("Infill is not supported by this model: " + err).c_str());
+ return nullptr;
+ }
- // Get total number of columns in the first row (assuming all rows are of equal length)
- jsize embedding_cols = embedding_rows > 0 ? embedding[0].size() : 0;
+ // Parse request data from JSON
+ std::string request_str = parse_jstring(env, jrequestData);
+ json data = json::parse(request_str);
- SRV_INF("Embedding has %d rows and %d columns\n", embedding_rows, embedding_cols);
+ // Validate input
+ if (data.contains("prompt") && !data.at("prompt").is_string()) {
+ env->ThrowNew(c_llama_error, "\"prompt\" must be a string");
+ return nullptr;
+ }
- // Ensure embedding is not empty
- if (embedding.empty() || embedding[0].empty()) {
- env->ThrowNew(c_error_oom, "embedding array is empty");
- return nullptr;
- }
+ if (!data.contains("input_prefix")) {
+ env->ThrowNew(c_llama_error, "\"input_prefix\" is required");
+ return nullptr;
+ }
- // Extract only the first row
- const std::vector<float> &first_row = embedding[0]; // Reference to avoid copying
+ if (!data.contains("input_suffix")) {
+ env->ThrowNew(c_llama_error, "\"input_suffix\" is required");
+ return nullptr;
+ }
- // Create a new float array in JNI
- jfloatArray j_embedding = env->NewFloatArray(embedding_cols);
- if (j_embedding == nullptr) {
- env->ThrowNew(c_error_oom, "could not allocate embedding");
- return nullptr;
- }
+ if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
+ env->ThrowNew(c_llama_error, "\"input_extra\" must be an array of {\"filename\": string, \"text\": string}");
+ return nullptr;
+ }
- // Copy the first row into the JNI float array
- env->SetFloatArrayRegion(j_embedding, 0, embedding_cols, reinterpret_cast<const jfloat *>(first_row.data()));
+ // Set streaming flag
+ bool stream = jstream;
+ data["stream"] = stream;
- return j_embedding;
-}
+ // Process input_extra (context chunks)
+ json input_extra = json_value(data, "input_extra", json::array());
+ for (const auto& chunk : input_extra) {
+ if (!chunk.contains("text") || !chunk.at("text").is_string()) {
+ env->ThrowNew(c_llama_error, "extra_context chunk must contain a \"text\" field with a string value");
+ return nullptr;
+ }
+ if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
+ env->ThrowNew(c_llama_error, "extra_context chunk's \"filename\" field must be a string");
+ return nullptr;
+ }
+ }
+ data["input_extra"] = input_extra;
+
+ // Format the infill prompt
+ std::string prompt = json_value(data, "prompt", std::string());
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, false, true);
+
+ data["prompt"] = format_infill(
+ ctx_server->vocab,
+ data.at("input_prefix"),
+ data.at("input_suffix"),
+ data.at("input_extra"),
+ ctx_server->params_base.n_batch,
+ ctx_server->params_base.n_predict,
+ ctx_server->slots[0].n_ctx,
+ ctx_server->params_base.spm_infill,
+ tokenized_prompts.empty() ? std::vector<llama_token>() : tokenized_prompts[0]
+ );
+
+ // Create a completion ID
+ auto completion_id = gen_chatcmplid();
+ std::vector<server_task> tasks;
+
+ try {
+ // Process formatted prompt
+ std::vector<llama_tokens> infill_prompts = tokenize_input_prompts(
+ ctx_server->vocab, data.at("prompt"), true, true);
+
+ tasks.reserve(infill_prompts.size());
+ for (size_t i = 0; i < infill_prompts.size(); i++) {
+ server_task task(SERVER_TASK_TYPE_INFILL);
+
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.index = i;
+
+ task.prompt_tokens = std::move(infill_prompts[i]);
+ task.params = server_task::params_from_json_cmpl(
+ ctx_server->ctx, ctx_server->params_base, data);
+
+ task.id_selected_slot = json_value(data, "id_slot", -1);
+
+ // Infill is not OAI compatible, but we still set the completion ID
+ task.params.oaicompat = OAICOMPAT_TYPE_NONE;
+ task.params.oaicompat_cmpl_id = completion_id;
+
+ tasks.push_back(task);
+ }
+ } catch (const std::exception& e) {
+ const auto& err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST);
+ env->ThrowNew(c_llama_error, err.dump().c_str());
+ return nullptr;
+ }
-JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jobject obj, jstring jprompt,
- jobjectArray documents) {
- jlong server_handle = env->GetLongField(obj, f_model_pointer);
- auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+ // Add tasks to waiting queue and post them for processing
+ ctx_server->queue_results.add_waiting_tasks(tasks);
+ ctx_server->queue_tasks.post(tasks);
- if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) {
- env->ThrowNew(c_llama_error,
- "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
- return nullptr;
- }
+ // Get task IDs
+ const auto task_ids = server_task::get_list_id(tasks);
- const std::string prompt = parse_jstring(env, jprompt);
+ // Create response JSON
+ json response;
- const auto tokenized_query = tokenize_mixed(ctx_server->vocab, prompt, true, true);
+ if (!stream) {
+ // For non-streaming, collect all results
+ std::vector<server_task_result_ptr> results;
+ results.reserve(tasks.size());
- json responses = json::array();
+ for (size_t i = 0; i < tasks.size(); i++) {
+ server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
- std::vector<server_task> tasks;
- const jsize amount_documents = env->GetArrayLength(documents);
- auto *document_array = parse_string_array(env, documents, amount_documents);
- auto document_vector = std::vector<std::string>(document_array, document_array + amount_documents);
- free_string_array(document_array, amount_documents);
+ if (result->is_error()) {
+ // Clean up and throw error
+ ctx_server->queue_results.remove_waiting_task_ids(task_ids);
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
- std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server->vocab, document_vector, true, true);
+ results.push_back(std::move(result));
+ }
- tasks.reserve(tokenized_docs.size());
- for (int i = 0; i < tokenized_docs.size(); i++) {
- auto task = server_task(SERVER_TASK_TYPE_RERANK);
- task.id = ctx_server->queue_tasks.get_new_id();
- task.index = i;
- task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]);
- tasks.push_back(task);
- }
- ctx_server->queue_results.add_waiting_tasks(tasks);
- ctx_server->queue_tasks.post(tasks);
+ // Format the response
+ response["type"] = "infill";
+ response["streaming"] = false;
+ response["completion_id"] = completion_id;
+
+ if (results.size() == 1) {
+ // Single result
+ response["result"] = results[0]->to_json();
+ } else {
+ // Multiple results
+ json results_array = json::array();
+ for (auto& res : results) {
+ results_array.push_back(res->to_json());
+ }
+ response["results"] = results_array;
+ }
- // get the result
- std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
- std::vector<server_task_result_ptr> results(task_ids.size());
+ // Clean up
+ ctx_server->queue_results.remove_waiting_task_ids(task_ids);
+ } else {
+ // For streaming, return the task IDs
+ response["type"] = "infill_stream_init";
+ response["streaming"] = true;
+ response["completion_id"] = completion_id;
+
+ // Convert set to array
+ json task_ids_array = json::array();
+ for (const auto& id : task_ids) {
+ task_ids_array.push_back(id);
+ }
+ response["task_ids"] = task_ids_array;
+
+ SRV_INF("Started streaming infill with %zu task(s)\n", task_ids.size());
+ }
- // Create a new HashMap instance
- jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map);
- if (o_probabilities == nullptr) {
- env->ThrowNew(c_llama_error, "Failed to create HashMap object.");
+ // Return the response as a JSON string
+ std::string response_str = response.dump();
+ jstring result = env->NewStringUTF(response_str.c_str());
+
+ return result;
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in handleInfill: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
return nullptr;
}
+}
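+
+// Illustrative infill request, limited to the fields validated above
+// ("input_prefix" and "input_suffix" are required, "prompt" and
+// "input_extra" are optional); the concrete values are hypothetical.
+//
+//   {"input_prefix": "int add(int a, int b) {\n    return ",
+//    "input_suffix": ";\n}\n",
+//    "input_extra": [{"filename": "math.h", "text": "int add(int, int);"}]}
+//
+// Non-streaming calls return a {"type": "infill", ...} envelope; streaming
+// calls return {"type": "infill_stream_init", "task_ids": [...]}.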
+
+/**
+ * Get the next chunk of streaming results for a completion task.
+ * Used to retrieve results during streaming.
+ */
+
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_getNextStreamResult(JNIEnv* env, jobject obj, jint taskId) {
+ auto* ctx_server = static_cast<server_context*>(nullptr);
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return nullptr;
+ }
+
+ ctx_server = reinterpret_cast<server_context*>(server_handle);
+
+ // Get next result chunk from the result queue
+ server_task_result_ptr result = ctx_server->queue_results.recv(taskId);
- for (int i = 0; i < (int)task_ids.size(); i++) {
- server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
if (result->is_error()) {
- auto response = result->to_json()["message"].get<std::string>();
- for (const int id_task : task_ids) {
- ctx_server->queue_results.remove_waiting_task_id(id_task);
- }
- env->ThrowNew(c_llama_error, response.c_str());
+ // If there's an error, clean up and throw
+ ctx_server->queue_results.remove_waiting_task_id(taskId);
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
return nullptr;
}
+
+ // Try to parse the result JSON (check for UTF-8 validity)
+ json resultJson;
+ try {
+ resultJson = result->to_json();
+ } catch (const json::exception& e) {
+ // If parsing fails, create a basic error response instead
+ SRV_WRN("JSON parsing error: %s\n", e.what());
+ resultJson = {
+ {"content", "[Content contains invalid characters]"},
+ {"error", e.what()}
+ };
+ }
- const auto out_res = result->to_json();
+ // Create response JSON with metadata
+ json response = {
+ {"type", "stream_chunk"},
+ {"task_id", taskId},
+ {"result", resultJson},
+ {"is_final", result->is_stop()}
+ };
+ // If this is the final result, remove the task from the queue
if (result->is_stop()) {
- for (const int id_task : task_ids) {
- ctx_server->queue_results.remove_waiting_task_id(id_task);
+ ctx_server->queue_results.remove_waiting_task_id(taskId);
+ }
+
+ // Create JSON string with extra safety measures
+ std::string response_str;
+ try {
+ response_str = response.dump();
+
+ // Verify JSON is parseable (double-check)
+ json::parse(response_str);
+ } catch (const json::exception& e) {
+ // If still failing, create a minimal valid JSON response
+ SRV_ERR("Failed to create valid JSON response: %s\n", e.what());
+ json fallback = {
+ {"type", "stream_chunk"},
+ {"task_id", taskId},
+ {"result", {{"content", "[INVALID CONTENT]"}}},
+ {"is_final", result->is_stop()},
+ {"error", "Failed to generate valid JSON"}
+ };
+ response_str = fallback.dump();
+ }
+
+ // Check for invalid UTF-8 characters
+ if (!is_valid_utf8(response_str)) {
+ SRV_WRN("Response contains invalid UTF-8, sanitizing\n", "");
+ response_str = sanitize_utf8(response_str);
+ }
+
+ // Create Java string
+ jstring result_str = env->NewStringUTF(response_str.c_str());
+
+ // Check if string creation succeeded
+ if (result_str == nullptr) {
+ // If NewStringUTF failed (due to invalid UTF-8), create a fallback response
+ SRV_ERR("Failed to create Java string from response\n","");
+
+ // Create a minimal ASCII-only response
+ json ascii_fallback = {
+ {"type", "stream_chunk"},
+ {"task_id", taskId},
+ {"result", {{"content", "[CONTENT CONTAINS INVALID CHARACTERS]"}}},
+ {"is_final", result->is_stop()},
+ {"error", "Invalid UTF-8 characters in response"}
+ };
+
+ // Use the ASCII-only fallback
+ result_str = env->NewStringUTF(ascii_fallback.dump().c_str());
+
+ // If still failing, something is very wrong
+ if (result_str == nullptr) {
+ env->ThrowNew(c_llama_error, "Critical error: Unable to create response string");
+ return nullptr;
}
}
+
+ return result_str;
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in getNextStreamResult: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
+ if (ctx_server != nullptr) {
+ ctx_server->queue_results.remove_waiting_task_id(taskId);
+ }
+ return nullptr;
+ }
+}
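+
+// Streaming protocol sketch (hypothetical Java-side caller, shown only as
+// pseudocode in a comment): start a request with jstream == true, read the
+// "task_ids" array from the returned envelope, then poll each task:
+//
+//   while (true) {
+//       chunk = getNextStreamResult(taskId);  // {"type": "stream_chunk", ...}
+//       consume(chunk["result"]);
+//       if (chunk["is_final"]) break;         // waiting id is removed on stop
+//   }
+//   // on early termination call cancelCompletion(taskId) or releaseTask(taskId)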
+
+/**
+ * Release resources associated with a task.
+ * Used to clean up after a task is complete.
+ */
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv* env, jobject obj, jint taskId) {
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return;
+ }
- int index = out_res["index"].get<int>();
- float score = out_res["score"].get<float>();
- std::string tok_str = document_vector[index];
- jstring jtok_str = env->NewStringUTF(tok_str.c_str());
+ auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
- jobject jprob = env->NewObject(c_float, cc_float, score);
- env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob);
- env->DeleteLocalRef(jtok_str);
- env->DeleteLocalRef(jprob);
+ // Remove the task from the waiting tasks queue
+ ctx_server->queue_results.remove_waiting_task_id(taskId);
+
+ SRV_INF("Task %d released\n", taskId);
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in releaseTask: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
}
- jbyteArray jbytes = parse_jbytes(env, prompt);
- return env->NewObject(c_output, cc_output, jbytes, o_probabilities, true);
}
-JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams) {
- jlong server_handle = env->GetLongField(obj, f_model_pointer);
- auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+/**
+ * Cancel an ongoing completion.
+ * Stops generation and cleans up resources.
+ */
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv* env, jobject obj, jint taskId) {
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return;
+ }
- std::string c_params = parse_jstring(env, jparams);
- json data = json::parse(c_params);
+ auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
+
+ // Create a set with the task ID
+ std::unordered_set<int> task_ids = {taskId};
+
+ // Cancel the tasks in the server context
+ ctx_server->cancel_tasks(task_ids);
+
+ // Remove the task from the waiting tasks queue
+ ctx_server->queue_results.remove_waiting_task_id(taskId);
+
+ SRV_INF("Task %d canceled\n", taskId);
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in cancelCompletion: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
+ }
+}
- json templateData =
- oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja,
- ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get());
- std::string tok_str = templateData.at("prompt");
- jstring jtok_str = env->NewStringUTF(tok_str.c_str());
- return jtok_str;
-}
+/**
+ * Handle embeddings request.
+ * Equivalent to POST /embeddings endpoint.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEnv* env, jobject obj, jstring jrequestData, jboolean joaiCompat) {
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return nullptr;
+ }
-JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) {
- jlong server_handle = env->GetLongField(obj, f_model_pointer);
- auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+ auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
+
+ // Check if embeddings mode is enabled
+ if (!ctx_server->params_base.embedding) {
+ env->ThrowNew(c_llama_error, "Model was not loaded with embedding support (see ModelParameters#setEmbedding(boolean))");
+ return nullptr;
+ }
- const std::string c_prompt = parse_jstring(env, jprompt);
+ // Set compatibility mode
+ oaicompat_type oaicompat = joaiCompat ? OAICOMPAT_TYPE_EMBEDDING : OAICOMPAT_TYPE_NONE;
+
+ // Check if pooling type is compatible with OAI mode
+ if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server->ctx) == LLAMA_POOLING_TYPE_NONE) {
+ env->ThrowNew(c_llama_error, "Pooling type 'none' is not OAI compatible. Please use a different pooling type");
+ return nullptr;
+ }
+
+ // Parse request data from JSON
+ std::string request_str = parse_jstring(env, jrequestData);
+ json body = json::parse(request_str);
+
+ // Check for input field
+ json prompt;
+ if (body.count("input") != 0) {
+ prompt = body.at("input");
+ } else if (body.contains("content")) {
+ // "content" field is not OAI compatible
+ oaicompat = OAICOMPAT_TYPE_NONE;
+ prompt = body.at("content");
+ } else {
+ env->ThrowNew(c_llama_error, "\"input\" or \"content\" must be provided");
+ return nullptr;
+ }
+
+ // Check encoding format
+ bool use_base64 = false;
+ if (body.count("encoding_format") != 0) {
+ const std::string& format = body.at("encoding_format");
+ if (format == "base64") {
+ use_base64 = true;
+ } else if (format != "float") {
+ env->ThrowNew(c_llama_error, "The format to return the embeddings in. Can be either float or base64");
+ return nullptr;
+ }
+ }
+
+ // Tokenize the prompts
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true);
+
+ // Check for empty input
+ for (const auto& tokens : tokenized_prompts) {
+ if (tokens.empty()) {
+ env->ThrowNew(c_llama_error, "Input content cannot be empty");
+ return nullptr;
+ }
+ }
+
+ // Create embedding tasks
+ json responses = json::array();
+ std::vector<server_task> tasks;
+ tasks.reserve(tokenized_prompts.size());
+
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+ server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
- llama_tokens tokens = tokenize_mixed(ctx_server->vocab, c_prompt, false, true);
- jsize token_size = tokens.size(); // NOLINT(*-narrowing-conversions)
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.index = i;
+ task.prompt_tokens = std::move(tokenized_prompts[i]);
+ task.params.oaicompat = oaicompat;
- jintArray java_tokens = env->NewIntArray(token_size);
- if (java_tokens == nullptr) {
- env->ThrowNew(c_error_oom, "could not allocate token memory");
+ tasks.push_back(task);
+ }
+
+ // Submit tasks for processing
+ ctx_server->queue_results.add_waiting_tasks(tasks);
+ ctx_server->queue_tasks.post(tasks);
+
+ // Get task IDs
+ std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+ // Get task results
+ for (size_t i = 0; i < tasks.size(); i++) {
+ server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
+
+ if (result->is_error()) {
+ ctx_server->queue_results.remove_waiting_task_ids(task_ids);
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
+
+ responses.push_back(result->to_json());
+ }
+
+ // Clean up
+ ctx_server->queue_results.remove_waiting_task_ids(task_ids);
+
+ // Format response based on compatibility mode
+ json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
+ ? format_embeddings_response_oaicompat(body, responses, use_base64)
+ : json(responses);
+
+ // Return the response as a JSON string
+ std::string response_str = root.dump(2);
+ jstring result = env->NewStringUTF(response_str.c_str());
+
+ return result;
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in handleEmbeddings: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
return nullptr;
}
+}
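+
+// Illustrative embeddings request, using only the fields handled above
+// ("input" or "content", plus an optional "encoding_format" of "float" or
+// "base64"):
+//
+//   {"input": ["first text", "second text"], "encoding_format": "float"}
+//
+// With joaiCompat == true the result is shaped by
+// format_embeddings_response_oaicompat(); otherwise the raw array of
+// per-task JSON results is returned.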
- env->SetIntArrayRegion(java_tokens, 0, token_size, reinterpret_cast<const jint *>(tokens.data()));
+/**
+ * Handle reranking request.
+ * Equivalent to POST /rerank endpoint.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleRerank(JNIEnv* env, jobject obj, jstring jrequestData) {
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return nullptr;
+ }
- return java_tokens;
+ auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
+
+ // Check if reranking mode is enabled and embedding mode is disabled
+ if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) {
+ env->ThrowNew(c_llama_error,
+ "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
+ return nullptr;
+ }
+
+ // Parse request data from JSON
+ std::string request_str = parse_jstring(env, jrequestData);
+ json body = json::parse(request_str);
+
+ // Check if using TEI or Jina API format
+ bool is_tei_format = body.contains("texts");
+
+ // Validate and get query
+ json query;
+ if (body.count("query") == 1) {
+ query = body.at("query");
+ if (!query.is_string()) {
+ env->ThrowNew(c_llama_error, "\"query\" must be a string");
+ return nullptr;
+ }
+ } else {
+ env->ThrowNew(c_llama_error, "\"query\" must be provided");
+ return nullptr;
+ }
+
+ // Get documents/texts
+ std::vector<std::string> documents = json_value(body, "documents",
+ json_value(body, "texts", std::vector<std::string>()));
+ if (documents.empty()) {
+ env->ThrowNew(c_llama_error, "\"documents\" must be a non-empty string array");
+ return nullptr;
+ }
+
+ // Tokenize query
+ llama_tokens tokenized_query = tokenize_input_prompts(ctx_server->vocab, query, false, true)[0];
+
+ // Create rerank tasks
+ json responses = json::array();
+ std::vector<server_task> tasks;
+ std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server->vocab, documents, false, true);
+
+ tasks.reserve(tokenized_docs.size());
+ for (size_t i = 0; i < tokenized_docs.size(); i++) {
+ server_task task = server_task(SERVER_TASK_TYPE_RERANK);
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.index = i;
+ task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]);
+ tasks.push_back(task);
+ }
+
+ // Submit tasks for processing
+ ctx_server->queue_results.add_waiting_tasks(tasks);
+ ctx_server->queue_tasks.post(tasks);
+
+ // Get task IDs
+ std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+ // Get task results
+ for (size_t i = 0; i < tasks.size(); i++) {
+ server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
+
+ if (result->is_error()) {
+ ctx_server->queue_results.remove_waiting_task_ids(task_ids);
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
+
+ responses.push_back(result->to_json());
+ }
+
+ // Clean up
+ ctx_server->queue_results.remove_waiting_task_ids(task_ids);
+
+ // Format the rerank response
+ json root = format_response_rerank(
+ body,
+ responses,
+ is_tei_format,
+ documents);
+
+ // Return the response as a JSON string
+ std::string response_str = root.dump(2);
+ jstring result = env->NewStringUTF(response_str.c_str());
+
+ return result;
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in handleRerank: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
+ return nullptr;
+ }
}
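+
+// Illustrative rerank request; "query" is required, and the documents may be
+// passed either as "documents" (Jina style) or "texts" (TEI style), as
+// detected above. The sample texts are hypothetical.
+//
+//   {"query": "What is a panda?",
+//    "documents": ["hi", "The giant panda is a bear species native to China."]}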
-JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes(JNIEnv *env, jobject obj,
- jintArray java_tokens) {
- jlong server_handle = env->GetLongField(obj, f_model_pointer);
- auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
- jsize length = env->GetArrayLength(java_tokens);
- jint *elements = env->GetIntArrayElements(java_tokens, nullptr);
- std::vector<llama_token> tokens(elements, elements + length);
- std::string text = tokens_to_str(ctx_server->ctx, tokens.cbegin(), tokens.cend());
+/**
+ * Handle tokenization request.
+ * Equivalent to POST /tokenize endpoint.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleTokenize(JNIEnv* env, jobject obj, jstring jcontent, jboolean jaddSpecial, jboolean jwithPieces) {
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return nullptr;
+ }
+
+ auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
+
+ // Parse parameters
+ const std::string content = parse_jstring(env, jcontent);
+ const bool add_special = jaddSpecial;
+ const bool with_pieces = jwithPieces;
+
+ // Tokenize the content
+ llama_tokens tokens = tokenize_mixed(ctx_server->vocab, content, add_special, true);
+
+ // Create response JSON
+ json tokens_response = json::array();
+
+ if (with_pieces) {
+ // If detailed token info is requested, include token pieces
+ for (const auto& token : tokens) {
+ std::string piece = common_token_to_piece(ctx_server->ctx, token);
+ json piece_json;
+
+ // Check if the piece is valid UTF-8
+ if (is_valid_utf8(piece)) {
+ piece_json = piece;
+ } else {
+ // If not valid UTF-8, store as array of byte values
+ piece_json = json::array();
+ for (unsigned char c : piece) {
+ piece_json.push_back(static_cast<int>(c));
+ }
+ }
+
+ tokens_response.push_back({
+ {"id", token},
+ {"piece", piece_json}
+ });
+ }
+ } else {
+ // Otherwise just include token IDs
+ tokens_response = tokens;
+ }
+
+ // Format the response
+ json data = format_tokenizer_response(tokens_response);
+
+ // Return as JSON string
+ std::string response_str = data.dump(2);
+ jstring result = env->NewStringUTF(response_str.c_str());
+
+ return result;
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in handleTokenize: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
+ return nullptr;
+ }
+}
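+
+// Illustrative output shapes for this handler (the exact wrapping comes from
+// format_tokenizer_response()): with jwithPieces == false the token list is a
+// plain id array such as [1, 15043, 3186]; with jwithPieces == true each entry
+// becomes {"id": 15043, "piece": "Hello"}, and pieces that are not valid UTF-8
+// are emitted as an array of byte values instead. The ids shown are made up.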
- env->ReleaseIntArrayElements(java_tokens, elements, 0);
+/**
+ * Handle detokenization request.
+ * Equivalent to POST /detokenize endpoint.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleDetokenize(JNIEnv* env, jobject obj, jintArray jtokens) {
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return nullptr;
+ }
- return parse_jbytes(env, text);
+ auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
+
+ // Convert Java tokens to C++ vector
+ jsize length = env->GetArrayLength(jtokens);
+ jint* elements = env->GetIntArrayElements(jtokens, nullptr);
+ std::vector<llama_token> tokens(elements, elements + length);
+
+ // Convert tokens to string
+ std::string content = tokens_to_str(ctx_server->ctx, tokens.cbegin(), tokens.cend());
+
+ // Release Java array elements
+ env->ReleaseIntArrayElements(jtokens, elements, JNI_ABORT);
+
+ // Format the response
+ json data = format_detokenized_response(content);
+
+ // Return as JSON string
+ std::string response_str = data.dump(2);
+ jstring result = env->NewStringUTF(response_str.c_str());
+
+ return result;
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in handleDetokenize: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
+ return nullptr;
+ }
}
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv *env, jobject obj) {
- jlong server_handle = env->GetLongField(obj, f_model_pointer);
- auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
- ctx_server->queue_tasks.terminate();
- // delete ctx_server;
+/**
+ * Apply chat template to messages.
+ * Equivalent to POST /apply-template endpoint.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv* env, jobject obj, jstring jrequestData) {
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return nullptr;
+ }
+
+ auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
+
+ // Parse request data
+ std::string request_str = parse_jstring(env, jrequestData);
+ json body = json::parse(request_str);
+
+ // Apply the template using the OpenAI parameter parsing function
+ // This function processes the messages using the model's chat template
+ json templateData = oaicompat_completion_params_parse(
+ body,
+ ctx_server->params_base.use_jinja,
+ ctx_server->params_base.reasoning_format,
+ ctx_server->chat_templates.get()
+ );
+
+ // Extract the formatted prompt
+ std::string formatted_prompt = templateData.at("prompt");
+
+ // Create response JSON
+ json response = {
+ {"prompt", formatted_prompt}
+ };
+
+ // Return as JSON string
+ std::string response_str = response.dump(2);
+ jstring result = env->NewStringUTF(response_str.c_str());
+
+ return result;
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in applyTemplate: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
+ return nullptr;
+ }
}
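+
+// Illustrative applyTemplate exchange; the "messages" layout is an assumption
+// about the OpenAI-style schema accepted by oaicompat_completion_params_parse(),
+// and the rendered text depends entirely on the model's chat template.
+//
+//   request:  {"messages": [{"role": "user", "content": "Hi"}]}
+//   response: {"prompt": "<chat template rendering of the messages>"}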
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv *env, jobject obj, jint id_task) {
- jlong server_handle = env->GetLongField(obj, f_model_pointer);
- auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
- std::unordered_set<int> id_tasks = {id_task};
- ctx_server->cancel_tasks(id_tasks);
- ctx_server->queue_results.remove_waiting_task_id(id_task);
+/**
+ * Handle slot management operations.
+ * Consolidates GET /slots and POST /slots/:id_slot endpoints.
+ *
+ * @param env JNI environment
+ * @param obj Java object
+ * @param action Action to perform: 0=GET (list), 1=SAVE, 2=RESTORE, 3=ERASE
+ * @param slotId Slot ID (ignored for GET action)
+ * @param jfilename Filename for save/restore (ignored for GET and ERASE actions)
+ * @return JSON string describing the result of the requested action
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleSlotAction(JNIEnv* env, jobject obj, jint action, jint slotId, jstring jfilename) {
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return nullptr;
+ }
+
+ auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
+
+ // Process based on action type
+ switch (action) {
+ case 0: { // GET - List slots
+ // Check if slots endpoint is enabled
+ if (!ctx_server->params_base.endpoint_slots) {
+ env->ThrowNew(c_llama_error, "This server does not support slots endpoint. Start it with `--slots`");
+ return nullptr;
+ }
+
+ // Request slots data using task queue
+ server_task task(SERVER_TASK_TYPE_METRICS);
+ task.id = ctx_server->queue_tasks.get_new_id();
+ ctx_server->queue_results.add_waiting_task_id(task.id);
+ ctx_server->queue_tasks.post(task, true); // high-priority task
+
+ // Get the result
+ server_task_result_ptr result = ctx_server->queue_results.recv(task.id);
+ ctx_server->queue_results.remove_waiting_task_id(task.id);
+
+ if (result->is_error()) {
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
+
+ // Parse metrics result
+ auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
+ if (res_metrics == nullptr) {
+ env->ThrowNew(c_llama_error, "Invalid metrics result");
+ return nullptr;
+ }
+
+ // Create JSON response with slots data
+ json response = {
+ {"slots", res_metrics->slots_data},
+ {"n_idle_slots", res_metrics->n_idle_slots},
+ {"success", true}
+ };
+
+ // Return as JSON string
+ std::string response_str = response.dump(2);
+ return env->NewStringUTF(response_str.c_str());
+ }
+
+ case 1: { // SAVE - Save slot state
+ // Check if slot save is enabled
+ if (ctx_server->params_base.slot_save_path.empty()) {
+ env->ThrowNew(c_llama_error, "This server does not support slot save. Start it with `--slot-save-path`");
+ return nullptr;
+ }
+
+ // Get the filename
+ std::string filename = parse_jstring(env, jfilename);
+ if (!fs_validate_filename(filename)) {
+ env->ThrowNew(c_llama_error, "Invalid filename");
+ return nullptr;
+ }
+
+ std::string filepath = ctx_server->params_base.slot_save_path + filename;
+
+ // Create the save task
+ server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.slot_action.slot_id = slotId;
+ task.slot_action.filename = filename;
+ task.slot_action.filepath = filepath;
+
+ ctx_server->queue_results.add_waiting_task_id(task.id);
+ ctx_server->queue_tasks.post(task);
+
+ server_task_result_ptr result = ctx_server->queue_results.recv(task.id);
+ ctx_server->queue_results.remove_waiting_task_id(task.id);
+
+ if (result->is_error()) {
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
+
+ // Create JSON response indicating success
+ json response = {
+ {"action", "save"},
+ {"slot_id", slotId},
+ {"filename", filename},
+ {"success", true}
+ };
+
+ SRV_INF("Slot %d saved to file %s\n", slotId, filename.c_str());
+
+ // Return as JSON string
+ std::string response_str = response.dump(2);
+ return env->NewStringUTF(response_str.c_str());
+ }
+
+ case 2: { // RESTORE - Restore slot state
+ // Check if slot save is enabled
+ if (ctx_server->params_base.slot_save_path.empty()) {
+ env->ThrowNew(c_llama_error, "This server does not support slot restore. Start it with `--slot-save-path`");
+ return nullptr;
+ }
+
+ // Get the filename
+ std::string filename = parse_jstring(env, jfilename);
+ if (!fs_validate_filename(filename)) {
+ env->ThrowNew(c_llama_error, "Invalid filename");
+ return nullptr;
+ }
+
+ std::string filepath = ctx_server->params_base.slot_save_path + filename;
+
+ // Create the restore task
+ server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.slot_action.slot_id = slotId;
+ task.slot_action.filename = filename;
+ task.slot_action.filepath = filepath;
+
+ ctx_server->queue_results.add_waiting_task_id(task.id);
+ ctx_server->queue_tasks.post(task);
+
+ server_task_result_ptr result = ctx_server->queue_results.recv(task.id);
+ ctx_server->queue_results.remove_waiting_task_id(task.id);
+
+ if (result->is_error()) {
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
+
+ // Create JSON response indicating success
+ json response = {
+ {"action", "restore"},
+ {"slot_id", slotId},
+ {"filename", filename},
+ {"success", true}
+ };
+
+ SRV_INF("Slot %d restored from file %s\n", slotId, filename.c_str());
+
+ // Return as JSON string
+ std::string response_str = response.dump(2);
+ return env->NewStringUTF(response_str.c_str());
+ }
+
+ case 3: { // ERASE - Erase slot state
+ // Create the erase task
+ server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.slot_action.slot_id = slotId;
+
+ ctx_server->queue_results.add_waiting_task_id(task.id);
+ ctx_server->queue_tasks.post(task);
+
+ server_task_result_ptr result = ctx_server->queue_results.recv(task.id);
+ ctx_server->queue_results.remove_waiting_task_id(task.id);
+
+ if (result->is_error()) {
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
+
+ // Create JSON response indicating success
+ json response = {
+ {"action", "erase"},
+ {"slot_id", slotId},
+ {"success", true}
+ };
+
+ SRV_INF("Slot %d erased\n", slotId);
+
+ // Return as JSON string
+ std::string response_str = response.dump(2);
+ return env->NewStringUTF(response_str.c_str());
+ }
+
+ default:
+ env->ThrowNew(c_llama_error, "Invalid slot action");
+ return nullptr;
+ }
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in handleSlotAction: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
+ return nullptr;
+ }
}
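+
+// Usage sketch for the action codes handled above (hypothetical Java-side
+// calls; filenames are examples only):
+//
+//   handleSlotAction(0, -1, null)        // list slots, requires --slots
+//   handleSlotAction(1, 0, "slot0.bin")  // save slot 0, requires --slot-save-path
+//   handleSlotAction(2, 0, "slot0.bin")  // restore slot 0
+//   handleSlotAction(3, 0, null)         // erase slot 0
+//
+// Each call returns a small JSON status object such as
+//   {"action": "save", "slot_id": 0, "filename": "slot0.bin", "success": true}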
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv *env, jclass clazz, jobject log_format,
- jobject jcallback) {
- if (o_log_callback != nullptr) {
- env->DeleteGlobalRef(o_log_callback);
+/**
+ * Convert a JSON schema to a grammar.
+ * Utility method for generating grammar rules from JSON schema definitions.
+ */
+JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv* env, jclass clazz, jstring j_schema) {
+ try {
+ // Parse the JSON schema string
+ const std::string c_schema = parse_jstring(env, j_schema);
+
+ // Parse the schema as ordered JSON (to maintain property order)
+ nlohmann::ordered_json c_schema_json;
+ try {
+ c_schema_json = nlohmann::ordered_json::parse(c_schema);
+ } catch (const nlohmann::json::exception& e) {
+ env->ThrowNew(c_llama_error, ("Failed to parse JSON schema: " + std::string(e.what())).c_str());
+ return nullptr;
+ }
+
+ // Convert JSON schema to grammar
+ std::string c_grammar;
+ try {
+ c_grammar = json_schema_to_grammar(c_schema_json);
+ } catch (const std::exception& e) {
+ env->ThrowNew(c_llama_error, ("Failed to convert schema to grammar: " + std::string(e.what())).c_str());
+ return nullptr;
+ }
+
+ // Convert the grammar string to a byte array
+ jbyteArray result = parse_jbytes(env, c_grammar);
+
+ SRV_INF("JSON schema converted to grammar (%zu bytes)\n", c_grammar.size());
+ return result;
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in jsonSchemaToGrammarBytes: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
+ return nullptr;
}
+}
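+
+// Illustrative conversion handled by this method; the schema is a made-up
+// example, and the returned bytes are the grammar text produced by
+// json_schema_to_grammar() (GBNF rules such as root ::= ...).
+//
+//   schema: {"type": "object",
+//            "properties": {"name": {"type": "string"}},
+//            "required": ["name"]}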
- log_json = env->IsSameObject(log_format, o_log_format_json);
+JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv* env, jobject obj, jstring jprompt) {
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ auto* ctx_server = reinterpret_cast<server_context*>(server_handle); // NOLINT(*-no-int-to-ptr)
- if (jcallback == nullptr) {
- log_callback = nullptr;
- llama_log_set(nullptr, nullptr);
- } else {
- o_log_callback = env->NewGlobalRef(jcallback);
- log_callback = [](enum ggml_log_level level, const char *text, void *user_data) {
- JNIEnv *env = get_jni_env();
- jstring message = env->NewStringUTF(text);
- jobject log_level = log_level_to_jobject(level);
- env->CallVoidMethod(o_log_callback, m_biconsumer_accept, log_level, message);
- env->DeleteLocalRef(message);
- };
- if (!log_json) {
- llama_log_set(log_callback_trampoline, nullptr);
+ const std::string c_prompt = parse_jstring(env, jprompt);
+
+ llama_tokens tokens = tokenize_mixed(ctx_server->vocab, c_prompt, false, true);
+ jsize token_size = tokens.size(); // NOLINT(*-narrowing-conversions)
+
+ jintArray java_tokens = env->NewIntArray(token_size);
+ if (java_tokens == nullptr) {
+ env->ThrowNew(c_error_oom, "could not allocate token memory");
+ return nullptr;
+ }
+
+ env->SetIntArrayRegion(java_tokens, 0, token_size, reinterpret_cast<const jint *>(tokens.data()));
+
+ return java_tokens;
+}
+
+/**
+ * Manage KV cache operations for a specific slot.
+ * Consolidated function for KV cache info, clear, save, and load.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleKVCacheAction(JNIEnv* env, jobject obj, jint action, jint slotId, jstring jfilename) {
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return nullptr;
+ }
+
+ auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
+
+ // Process based on action type
+ switch (action) {
+ case 0: { // INFO - Get KV cache information
+ // Create a task to get KV cache info
+ server_task task(SERVER_TASK_TYPE_METRICS); // Use metrics task to get info
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.slot_action.slot_id = slotId;
+
+ ctx_server->queue_results.add_waiting_task_id(task.id);
+ ctx_server->queue_tasks.post(task, true); // High priority
+
+ server_task_result_ptr result = ctx_server->queue_results.recv(task.id);
+ ctx_server->queue_results.remove_waiting_task_id(task.id);
+
+ if (result->is_error()) {
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
+
+ // Extract KV cache information from metrics
+ auto* metrics_result = dynamic_cast<server_task_result_metrics*>(result.get());
+ if (metrics_result == nullptr) {
+ env->ThrowNew(c_llama_error, "Invalid metrics result");
+ return nullptr;
+ }
+
+ // Create response with KV cache information
+ json kv_info = {
+ {"action", "info"},
+ {"slot_id", slotId},
+ {"kv_cache_tokens", metrics_result->kv_cache_tokens_count},
+ {"kv_cache_used_cells", metrics_result->kv_cache_used_cells},
+ {"success", true}
+ };
+
+ // Return as JSON string
+ std::string info_str = kv_info.dump(2);
+ return env->NewStringUTF(info_str.c_str());
+ }
+
+ case 1: { // CLEAR - Clear KV cache
+ // Create a task to clear KV cache
+ server_task task(SERVER_TASK_TYPE_SLOT_ERASE); // Use slot erase to clear cache
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.slot_action.slot_id = slotId;
+
+ ctx_server->queue_results.add_waiting_task_id(task.id);
+ ctx_server->queue_tasks.post(task);
+
+ server_task_result_ptr result = ctx_server->queue_results.recv(task.id);
+ ctx_server->queue_results.remove_waiting_task_id(task.id);
+
+ if (result->is_error()) {
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
+
+ // Create response indicating success
+ json clear_response = {
+ {"action", "clear"},
+ {"slot_id", slotId},
+ {"success", true}
+ };
+
+ SRV_INF("KV cache cleared for slot %d\n", slotId);
+
+ // Return as JSON string
+ std::string clear_str = clear_response.dump(2);
+ return env->NewStringUTF(clear_str.c_str());
+ }
+
+ case 2: { // SAVE - Save KV cache
+ // Check if slot save is enabled
+ if (ctx_server->params_base.slot_save_path.empty()) {
+ env->ThrowNew(c_llama_error, "This server does not support KV cache save. Start it with `--slot-save-path`");
+ return nullptr;
+ }
+
+ // Get the filename
+ std::string filename = parse_jstring(env, jfilename);
+ if (!fs_validate_filename(filename)) {
+ env->ThrowNew(c_llama_error, "Invalid filename");
+ return nullptr;
+ }
+
+ std::string filepath = ctx_server->params_base.slot_save_path + filename;
+
+ // Create the save task
+ server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.slot_action.slot_id = slotId;
+ task.slot_action.filename = filename;
+ task.slot_action.filepath = filepath;
+
+ ctx_server->queue_results.add_waiting_task_id(task.id);
+ ctx_server->queue_tasks.post(task);
+
+ server_task_result_ptr result = ctx_server->queue_results.recv(task.id);
+ ctx_server->queue_results.remove_waiting_task_id(task.id);
+
+ if (result->is_error()) {
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
+
+ // Create response indicating success
+ json save_response = {
+ {"action", "save"},
+ {"slot_id", slotId},
+ {"filename", filename},
+ {"success", true}
+ };
+
+ SRV_INF("KV cache saved for slot %d to file %s\n", slotId, filename.c_str());
+
+ // Return as JSON string
+ std::string save_str = save_response.dump(2);
+ return env->NewStringUTF(save_str.c_str());
+ }
+
+ case 3: { // LOAD - Load KV cache
+ // Check if slot save is enabled
+ if (ctx_server->params_base.slot_save_path.empty()) {
+ env->ThrowNew(c_llama_error, "This server does not support KV cache load. Start it with `--slot-save-path`");
+ return nullptr;
+ }
+
+ // Get the filename
+ std::string filename = parse_jstring(env, jfilename);
+ if (!fs_validate_filename(filename)) {
+ env->ThrowNew(c_llama_error, "Invalid filename");
+ return nullptr;
+ }
+
+ std::string filepath = ctx_server->params_base.slot_save_path + filename;
+
+ // Create the restore task
+ server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
+ task.id = ctx_server->queue_tasks.get_new_id();
+ task.slot_action.slot_id = slotId;
+ task.slot_action.filename = filename;
+ task.slot_action.filepath = filepath;
+
+ ctx_server->queue_results.add_waiting_task_id(task.id);
+ ctx_server->queue_tasks.post(task);
+
+ server_task_result_ptr result = ctx_server->queue_results.recv(task.id);
+ ctx_server->queue_results.remove_waiting_task_id(task.id);
+
+ if (result->is_error()) {
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return nullptr;
+ }
+
+ // Create response indicating success
+ json load_response = {
+ {"action", "load"},
+ {"slot_id", slotId},
+ {"filename", filename},
+ {"success", true}
+ };
+
+ SRV_INF("KV cache loaded for slot %d from file %s\n", slotId, filename.c_str());
+
+ // Return as JSON string
+ std::string load_str = load_response.dump(2);
+ return env->NewStringUTF(load_str.c_str());
+ }
+
+ default:
+ env->ThrowNew(c_llama_error, "Invalid KV cache action");
+ return nullptr;
}
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in handleKVCacheAction: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
+ return nullptr;
}
}
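+
+// Usage sketch for the KV cache action codes above (hypothetical Java-side
+// calls; filenames are examples only):
+//
+//   handleKVCacheAction(0, 0, null)       // info, reported via the metrics task
+//   handleKVCacheAction(1, 0, null)       // clear, implemented as a slot erase
+//   handleKVCacheAction(2, 0, "kv0.bin")  // save, requires --slot-save-path
+//   handleKVCacheAction(3, 0, "kv0.bin")  // load, requires --slot-save-path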
-JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv *env, jclass clazz,
- jstring j_schema) {
- const std::string c_schema = parse_jstring(env, j_schema);
- nlohmann::ordered_json c_schema_json = nlohmann::ordered_json::parse(c_schema);
- const std::string c_grammar = json_schema_to_grammar(c_schema_json);
- return parse_jbytes(env, c_grammar);
+/**
+ * Configure parallel inference settings.
+ * Controls how inference tasks are distributed and executed in parallel.
+ */
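+// Illustrative configuration JSON for this handler; only the keys read below
+// are honored, and n_parallel cannot currently be changed at runtime:
+//
+//   {"slot_prompt_similarity": 0.5, "n_threads": 8, "n_threads_batch": 8}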
+JNIEXPORT jboolean JNICALL Java_de_kherud_llama_LlamaModel_configureParallelInference(JNIEnv* env, jobject obj, jstring jconfig) {
+ try {
+ // Get server context pointer from Java object
+ jlong server_handle = env->GetLongField(obj, f_model_pointer);
+ if (server_handle == 0) {
+ env->ThrowNew(c_llama_error, "Model is not loaded");
+ return JNI_FALSE;
+ }
+
+ auto* ctx_server = reinterpret_cast<server_context*>(server_handle);
+
+ // Parse configuration from JSON
+ std::string config_str = parse_jstring(env, jconfig);
+ json config = json::parse(config_str);
+
+ // Store original settings for rollback in case of failure
+ int original_n_parallel = ctx_server->params_base.n_parallel;
+ float original_similarity_threshold = ctx_server->slot_prompt_similarity;
+
+ // Track changes to report
+ json changes = json::object();
+ bool changes_made = false;
+
+ if (config.contains("n_parallel")) {
+ int n_parallel = config["n_parallel"].get<int>();
+ if (n_parallel <= 0) {
+ env->ThrowNew(c_llama_error, "n_parallel must be greater than 0");
+ return JNI_FALSE;
+ }
+
+ if (n_parallel != ctx_server->params_base.n_parallel) {
+ // Changing the number of parallel slots requires model reloading
+ // which isn't supported at runtime, so we'll throw an error
+ env->ThrowNew(c_llama_error, "Changing the number of parallel slots requires restarting the model");
+ return JNI_FALSE;
+ }
+
+ changes["n_parallel"] = n_parallel;
+ }
+
+ if (config.contains("slot_prompt_similarity")) {
+ float similarity = config["slot_prompt_similarity"].get<float>();
+ if (similarity < 0.0f || similarity > 1.0f) {
+ env->ThrowNew(c_llama_error, "slot_prompt_similarity must be between 0.0 and 1.0");
+ return JNI_FALSE;
+ }
+
+ ctx_server->slot_prompt_similarity = similarity;
+ changes["slot_prompt_similarity"] = similarity;
+ changes_made = true;
+ }
+
+ // Check for other parameters in server context that you want to configure
+ // For example, n_threads, n_threads_batch, etc.
+ if (config.contains("n_threads")) {
+ int n_threads = config["n_threads"].get<int>();
+ if (n_threads <= 0) {
+ env->ThrowNew(c_llama_error, "n_threads must be greater than 0");
+ return JNI_FALSE;
+ }
+
+ ctx_server->params_base.cpuparams.n_threads = n_threads;
+ changes["n_threads"] = n_threads;
+ changes_made = true;
+ }
+
+ if (config.contains("n_threads_batch")) {
+ int n_threads_batch = config["n_threads_batch"].get<int>();
+ if (n_threads_batch <= 0) {
+ env->ThrowNew(c_llama_error, "n_threads_batch must be greater than 0");
+ return JNI_FALSE;
+ }
+
+ ctx_server->params_base.cpuparams_batch.n_threads = n_threads_batch;
+ changes["n_threads_batch"] = n_threads_batch;
+ changes_made = true;
+ }
+
+ // Since there's no dedicated task type for updating parallel config,
+ // we'll use the metrics task to ensure the changes are propagated
+ // through the server context
+ if (changes_made) {
+ // Request metrics to ensure changes are propagated
+ server_task task(SERVER_TASK_TYPE_METRICS);
+ task.id = ctx_server->queue_tasks.get_new_id();
+
+ ctx_server->queue_results.add_waiting_task_id(task.id);
+ ctx_server->queue_tasks.post(task, true); // High priority
+
+ // Wait for the result
+ server_task_result_ptr result = ctx_server->queue_results.recv(task.id);
+ ctx_server->queue_results.remove_waiting_task_id(task.id);
+
+ if (result->is_error()) {
+ // Rollback changes if there was an error
+ ctx_server->params_base.n_parallel = original_n_parallel;
+ ctx_server->slot_prompt_similarity = original_similarity_threshold;
+
+ std::string error_msg = result->to_json()["message"].get<std::string>();
+ env->ThrowNew(c_llama_error, error_msg.c_str());
+ return JNI_FALSE;
+ }
+
+ // Create a success response
+ json response = {
+ {"success", true},
+ {"changes", changes}
+ };
+
+ SRV_INF("Parallel inference configuration updated: %s\n", changes.dump().c_str());
+ return JNI_TRUE;
+ } else {
+ SRV_INF("No parallel inference parameters were changed\n", " ");
+ return JNI_TRUE;
+ }
+ } catch (const std::exception& e) {
+ SRV_ERR("Exception in configureParallelInference: %s\n", e.what());
+ env->ThrowNew(c_llama_error, e.what());
+ return JNI_FALSE;
+ }
}
\ No newline at end of file
diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h
index dc17fa83..00d651bb 100644
--- a/src/main/cpp/jllama.h
+++ b/src/main/cpp/jllama.h
@@ -1,104 +1,165 @@
-/* DO NOT EDIT THIS FILE - it is machine generated */
-#include <jni.h>
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include <jni.h>
+
/* Header for class de_kherud_llama_LlamaModel */
#ifndef _Included_de_kherud_llama_LlamaModel
#define _Included_de_kherud_llama_LlamaModel
#ifdef __cplusplus
extern "C" {
-#endif
-/*
- * Class: de_kherud_llama_LlamaModel
- * Method: embed
- * Signature: (Ljava/lang/String;)[F
+ #endif
+ // Core Functions
+
+/**
+ * Load a llama.cpp model with the given parameters.
+ */
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv* env, jobject obj, jobjectArray jparams);
+
+/**
+ * Clean up resources and delete the model.
*/
-JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *, jobject, jstring);
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv* env, jobject obj);
-/*
- * Class: de_kherud_llama_LlamaModel
- * Method: encode
- * Signature: (Ljava/lang/String;)[I
+/**
+ * Set a logger for llama.cpp logs.
*/
-JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *, jobject, jstring);
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv* env, jclass clazz, jobject log_format, jobject jcallback);
-/*
- * Class: de_kherud_llama_LlamaModel
- * Method: setLogger
- * Signature: (Lde/kherud/llama/args/LogFormat;Ljava/util/function/BiConsumer;)V
+// Server Information Endpoints
+
+/**
+ * Get server health status.
*/
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv *, jclass, jobject, jobject);
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_getHealth(JNIEnv* env, jobject obj);
-/*
- * Class: de_kherud_llama_LlamaModel
- * Method: requestCompletion
- * Signature: (Ljava/lang/String;)I
+/**
+ * Get detailed server metrics.
*/
-JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv *, jobject, jstring);
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_getMetrics(JNIEnv* env, jobject obj);
-/*
- * Class: de_kherud_llama_LlamaModel
- * Method: receiveCompletion
- * Signature: (I)Lde/kherud/llama/LlamaOutput;
+/**
+ * Get model properties.
*/
-JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIEnv *, jobject, jint);
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_getProps(JNIEnv* env, jobject obj);
-/*
- * Class: de_kherud_llama_LlamaModel
- * Method: cancelCompletion
- * Signature: (I)V
+/**
+ * Update model properties.
*/
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv *, jobject, jint);
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_updateProps(JNIEnv* env, jobject obj, jstring jprops);
-/*
- * Class: de_kherud_llama_LlamaModel
- * Method: decodeBytes
- * Signature: ([I)[B
+/**
+ * Get list of available models.
*/
-JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes(JNIEnv *, jobject, jintArray);
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_getModels(JNIEnv* env, jobject obj);
-/*
- * Class: de_kherud_llama_LlamaModel
- * Method: loadModel
- * Signature: ([Ljava/lang/String;)V
+/**
+ * Get current server state.
*/
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *, jobject, jobjectArray);
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_getServerState(JNIEnv* env, jobject obj);
+
+// Text Generation Endpoints
-/*
- * Class: de_kherud_llama_LlamaModel
- * Method: delete
- * Signature: ()V
+/**
+ * Handle standard completions request.
*/
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv *, jobject);
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleCompletions(JNIEnv* env, jobject obj, jstring jrequestData, jboolean jstream);
-/*
- * Class: de_kherud_llama_LlamaModel
- * Method: releaseTask
- * Signature: (I)V
+/**
+ * Handle OpenAI compatible completions request.
*/
-JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *, jobject, jint);
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleCompletionsOai(JNIEnv* env, jobject obj, jstring jrequestData, jboolean jstream);
-/*
- * Class: de_kherud_llama_LlamaModel
- * Method: jsonSchemaToGrammarBytes
- * Signature: (Ljava/lang/String;)[B
+/**
+ * Handle chat completions request.
*/
-JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv *, jclass, jstring);
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleChatCompletions(JNIEnv* env, jobject obj, jstring jrequestData, jboolean jstream);
-/*
- * Class: de_kherud_llama_LlamaModel
- * Method: rerank
- * Signature: (Ljava/lang/String;[Ljava/lang/String;)Lde/kherud/llama/LlamaOutput;
+/**
+ * Handle text infill request.
*/
-JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *, jobject, jstring, jobjectArray);
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleInfill(JNIEnv* env, jobject obj, jstring jrequestData, jboolean jstream);
-/*
- * Class: de_kherud_llama_LlamaModel
- * Method: applyTemplate
- * Signature: (Ljava/lang/String;)Ljava/lang/String;;
+/**
+ * Get next streaming result chunk.
*/
-JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *, jobject, jstring);
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_getNextStreamResult(JNIEnv* env, jobject obj, jint taskId);
-#ifdef __cplusplus
+/**
+ * Release task resources.
+ */
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv* env, jobject obj, jint taskId);
+
+/**
+ * Cancel ongoing completion.
+ */
+JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv* env, jobject obj, jint taskId);
+
+// Embeddings and Reranking Endpoints
+
+/**
+ * Handle embeddings request.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEnv* env, jobject obj, jstring jrequestData, jboolean joaiCompat);
+
+/**
+ * Handle reranking request.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleRerank(JNIEnv* env, jobject obj, jstring jrequestData);
+
+// Tokenization Endpoints
+
+/**
+ * Handle tokenization request.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleTokenize(JNIEnv* env, jobject obj, jstring jcontent, jboolean jaddSpecial, jboolean jwithPieces);
+
+/**
+ * Handle detokenization request.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleDetokenize(JNIEnv* env, jobject obj, jintArray jtokens);
+
+/**
+ * Apply chat template to messages.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv* env, jobject obj, jstring jparams);
+
+// LoRA Adapters Endpoints
+
+/**
+ * Get list of available LoRA adapters.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_getLoraAdapters(JNIEnv* env, jobject obj);
+
+/**
+ * Apply LoRA adapters to model.
+ */
+JNIEXPORT jboolean JNICALL Java_de_kherud_llama_LlamaModel_applyLoraAdapters(JNIEnv* env, jobject obj, jstring jadapters);
+
+// Slots Management Endpoints
+/**
+ * Handle slot management operations.
+ * Consolidates GET /slots and POST /slots/:id_slot endpoints.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleSlotAction(JNIEnv* env, jobject obj, jint action, jint slotId, jstring jfilename);
+
+
+// Utility Methods
+
+/**
+ * Convert JSON schema to grammar.
+ */
+JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv* env, jclass clazz, jstring j_schema);
+
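+/**
+ * Tokenize text into token ids.
+ */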
+JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv*, jobject, jstring);
+
+/**
+ * Manage KV cache operations for a specific slot.
+ * Consolidated function for KV cache info, clear, save, and load.
+ */
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleKVCacheAction(JNIEnv* env, jobject obj, jint action, jint slotId, jstring jfilename);
+
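+/**
+ * Configure parallel inference parameters at runtime.
+ */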
+JNIEXPORT jboolean JNICALL Java_de_kherud_llama_LlamaModel_configureParallelInference(JNIEnv*, jobject, jstring);
+
+ #ifdef __cplusplus
}
#endif
-#endif
+#endif
\ No newline at end of file
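Each declaration above pairs a Java native method with a C++ entry point that converts between JNI and server types. The skeleton below illustrates that bridging pattern for getHealth; it is a sketch only, not the project's actual implementation, and the returned payload is made up:

```cpp
#include <jni.h>
#include <string>

// Illustrative JNI bridge: do the C++ work, then hand the result back as a jstring.
JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_getHealth(JNIEnv * env, jobject /*obj*/) {
    const std::string payload = "{\"status\":\"ok\"}";   // placeholder for the real server status
    return env->NewStringUTF(payload.c_str());
}

// Functions that accept a request string (e.g. handleCompletions) typically use:
//   const char * chars = env->GetStringUTFChars(jrequestData, nullptr);
//   std::string request(chars);
//   env->ReleaseStringUTFChars(jrequestData, chars);
```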
diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp
index 66169a83..d0154ab6 100644
--- a/src/main/cpp/server.hpp
+++ b/src/main/cpp/server.hpp
@@ -31,16 +31,15 @@ enum stop_type {
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
enum slot_state {
SLOT_STATE_IDLE,
- SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it
- // with launch_slot_with_task in the future
+ SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
SLOT_STATE_PROCESSING_PROMPT,
SLOT_STATE_DONE_PROMPT,
SLOT_STATE_GENERATING,
};
enum server_state {
- SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
- SERVER_STATE_READY, // Server is ready and model is loaded
+ SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
+ SERVER_STATE_READY, // Server is ready and model is loaded
};
enum server_task_type {
@@ -71,22 +70,21 @@ enum error_type {
ERROR_TYPE_SERVER,
ERROR_TYPE_NOT_FOUND,
ERROR_TYPE_PERMISSION,
- ERROR_TYPE_UNAVAILABLE, // custom error
+ ERROR_TYPE_UNAVAILABLE, // custom error
ERROR_TYPE_NOT_SUPPORTED, // custom error
};
struct slot_params {
- bool stream = true;
- bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
+ bool stream = true;
+ bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
bool return_tokens = false;
- int32_t n_keep = 0; // number of tokens to keep from initial prompt
- int32_t n_discard =
- 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
- int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
+    int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
- int64_t t_max_prompt_ms = -1; // TODO: implement
+ int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
std::vector lora;
@@ -101,16 +99,16 @@ struct slot_params {
struct common_params_speculative speculative;
// OAI-compat fields
- bool verbose = false;
- oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- std::string oaicompat_model;
- std::string oaicompat_cmpl_id;
- common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+ bool verbose = false;
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
+ common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
json to_json() const {
std::vector<std::string> samplers;
samplers.reserve(sampling.samplers.size());
- for (const auto &sampler : sampling.samplers) {
+ for (const auto & sampler : sampling.samplers) {
samplers.emplace_back(common_sampler_type_to_str(sampler));
}
@@ -120,61 +118,61 @@ struct slot_params {
}
auto grammar_triggers = json::array();
- for (const auto &trigger : sampling.grammar_triggers) {
+ for (const auto & trigger : sampling.grammar_triggers) {
grammar_triggers.push_back(trigger.to_json());
}
- return json{
- {"n_predict", n_predict}, // Server configured n_predict
- {"seed", sampling.seed},
- {"temperature", sampling.temp},
- {"dynatemp_range", sampling.dynatemp_range},
- {"dynatemp_exponent", sampling.dynatemp_exponent},
- {"top_k", sampling.top_k},
- {"top_p", sampling.top_p},
- {"min_p", sampling.min_p},
- {"xtc_probability", sampling.xtc_probability},
- {"xtc_threshold", sampling.xtc_threshold},
- {"typical_p", sampling.typ_p},
- {"repeat_last_n", sampling.penalty_last_n},
- {"repeat_penalty", sampling.penalty_repeat},
- {"presence_penalty", sampling.penalty_present},
- {"frequency_penalty", sampling.penalty_freq},
- {"dry_multiplier", sampling.dry_multiplier},
- {"dry_base", sampling.dry_base},
- {"dry_allowed_length", sampling.dry_allowed_length},
- {"dry_penalty_last_n", sampling.dry_penalty_last_n},
- {"dry_sequence_breakers", sampling.dry_sequence_breakers},
- {"mirostat", sampling.mirostat},
- {"mirostat_tau", sampling.mirostat_tau},
- {"mirostat_eta", sampling.mirostat_eta},
- {"stop", antiprompt},
- {"max_tokens", n_predict}, // User configured n_predict
- {"n_keep", n_keep},
- {"n_discard", n_discard},
- {"ignore_eos", sampling.ignore_eos},
- {"stream", stream},
- {"logit_bias", format_logit_bias(sampling.logit_bias)},
- {"n_probs", sampling.n_probs},
- {"min_keep", sampling.min_keep},
- {"grammar", sampling.grammar},
- {"grammar_lazy", sampling.grammar_lazy},
- {"grammar_triggers", grammar_triggers},
- {"preserved_tokens", sampling.preserved_tokens},
- {"chat_format", common_chat_format_name(oaicompat_chat_format)},
- {"samplers", samplers},
- {"speculative.n_max", speculative.n_max},
- {"speculative.n_min", speculative.n_min},
- {"speculative.p_min", speculative.p_min},
- {"timings_per_token", timings_per_token},
- {"post_sampling_probs", post_sampling_probs},
- {"lora", lora},
+ return json {
+ {"n_predict", n_predict}, // Server configured n_predict
+ {"seed", sampling.seed},
+ {"temperature", sampling.temp},
+ {"dynatemp_range", sampling.dynatemp_range},
+ {"dynatemp_exponent", sampling.dynatemp_exponent},
+ {"top_k", sampling.top_k},
+ {"top_p", sampling.top_p},
+ {"min_p", sampling.min_p},
+ {"xtc_probability", sampling.xtc_probability},
+ {"xtc_threshold", sampling.xtc_threshold},
+ {"typical_p", sampling.typ_p},
+ {"repeat_last_n", sampling.penalty_last_n},
+ {"repeat_penalty", sampling.penalty_repeat},
+ {"presence_penalty", sampling.penalty_present},
+ {"frequency_penalty", sampling.penalty_freq},
+ {"dry_multiplier", sampling.dry_multiplier},
+ {"dry_base", sampling.dry_base},
+ {"dry_allowed_length", sampling.dry_allowed_length},
+ {"dry_penalty_last_n", sampling.dry_penalty_last_n},
+ {"dry_sequence_breakers", sampling.dry_sequence_breakers},
+ {"mirostat", sampling.mirostat},
+ {"mirostat_tau", sampling.mirostat_tau},
+ {"mirostat_eta", sampling.mirostat_eta},
+ {"stop", antiprompt},
+ {"max_tokens", n_predict}, // User configured n_predict
+ {"n_keep", n_keep},
+ {"n_discard", n_discard},
+ {"ignore_eos", sampling.ignore_eos},
+ {"stream", stream},
+ {"logit_bias", format_logit_bias(sampling.logit_bias)},
+ {"n_probs", sampling.n_probs},
+ {"min_keep", sampling.min_keep},
+ {"grammar", sampling.grammar},
+ {"grammar_lazy", sampling.grammar_lazy},
+ {"grammar_triggers", grammar_triggers},
+ {"preserved_tokens", sampling.preserved_tokens},
+ {"chat_format", common_chat_format_name(oaicompat_chat_format)},
+ {"samplers", samplers},
+ {"speculative.n_max", speculative.n_max},
+ {"speculative.n_min", speculative.n_min},
+ {"speculative.p_min", speculative.p_min},
+ {"timings_per_token", timings_per_token},
+ {"post_sampling_probs", post_sampling_probs},
+ {"lora", lora},
};
}
};
struct server_task {
- int id = -1; // to be filled by server_queue
+ int id = -1; // to be filled by server_queue
int index = -1; // used when there are multiple prompts (batch request)
server_task_type type;
@@ -183,7 +181,7 @@ struct server_task {
int id_target = -1;
// used by SERVER_TASK_TYPE_INFERENCE
- slot_params params;
+ slot_params params;
llama_tokens prompt_tokens;
int id_selected_slot = -1;
@@ -203,61 +201,59 @@ struct server_task {
server_task(server_task_type type) : type(type) {}
- static slot_params params_from_json_cmpl(const llama_context *ctx, const common_params ¶ms_base,
- const json &data) {
- const llama_model *model = llama_get_model(ctx);
- const llama_vocab *vocab = llama_model_get_vocab(model);
+ static slot_params params_from_json_cmpl(
+ const llama_context * ctx,
+ const common_params & params_base,
+ const json & data) {
+ const llama_model * model = llama_get_model(ctx);
+ const llama_vocab * vocab = llama_model_get_vocab(model);
slot_params params;
- // Sampling parameter defaults are loaded from the global server context (but individual requests can still
- // override them)
+ // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
slot_params defaults;
- defaults.sampling = params_base.sampling;
+ defaults.sampling = params_base.sampling;
defaults.speculative = params_base.speculative;
// enabling this will output extra debug information in the HTTP responses from the server
- params.verbose = params_base.verbosity > 9;
+ params.verbose = params_base.verbosity > 9;
params.timings_per_token = json_value(data, "timings_per_token", false);
- params.stream = json_value(data, "stream", false);
- params.cache_prompt = json_value(data, "cache_prompt", true);
- params.return_tokens = json_value(data, "return_tokens", false);
- params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
- params.n_indent = json_value(data, "n_indent", defaults.n_indent);
- params.n_keep = json_value(data, "n_keep", defaults.n_keep);
- params.n_discard = json_value(data, "n_discard", defaults.n_discard);
- // params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO:
- // implement
- params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
- params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
-
- params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
- params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
- params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
- params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
- params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
- params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
- params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
- params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
- params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
- params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
- params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
- params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
- params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
- params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
- params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
- params.sampling.dry_allowed_length =
- json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
- params.sampling.dry_penalty_last_n =
- json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
- params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
- params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
- params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
- params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
- params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
- params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
- params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
+ params.stream = json_value(data, "stream", false);
+ params.cache_prompt = json_value(data, "cache_prompt", true);
+ params.return_tokens = json_value(data, "return_tokens", false);
+ params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
+ params.n_indent = json_value(data, "n_indent", defaults.n_indent);
+ params.n_keep = json_value(data, "n_keep", defaults.n_keep);
+ params.n_discard = json_value(data, "n_discard", defaults.n_discard);
+ //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
+ params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+ params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
+
+ params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
+ params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
+ params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
+ params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
+ params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
+ params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
+ params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
+ params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
+ params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
+ params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
+ params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
+ params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
+ params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
+ params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
+ params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
+ params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
+ params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
+ params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
+ params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
+ params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
+ params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
+ params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
+ params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
+ params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
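Every request field above is read with the same fall-back rule: take the value from the request JSON if present, otherwise use the server-wide default (n_predict additionally honours the OpenAI-style max_tokens alias). A self-contained sketch of that rule; the local json_value below is a stand-in for the helper used by the server code:

```cpp
#include <nlohmann/json.hpp>
#include <string>

using json = nlohmann::json;

// Stand-in for json_value(): request value if present and non-null, otherwise the default.
template <typename T>
static T json_value(const json & body, const std::string & key, const T & def) {
    return body.contains(key) && !body.at(key).is_null() ? body.at(key).get<T>() : def;
}

int main() {
    const json defaults = { {"n_predict", -1}, {"top_k", 40} };
    const json request  = { {"max_tokens", 128} };

    // n_predict falls back to max_tokens, then to the server default
    const int n_predict = json_value(request, "n_predict",
                                     json_value(request, "max_tokens", defaults.at("n_predict").get<int>()));
    const int top_k     = json_value(request, "top_k", defaults.at("top_k").get<int>());
    // n_predict == 128, top_k == 40
    return 0;
}
```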
@@ -268,7 +264,7 @@ struct server_task {
params.speculative.n_max = std::max(params.speculative.n_max, 0);
// Use OpenAI API logprobs only if n_probs wasn't provided
- if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs) {
+ if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
}
@@ -308,12 +304,10 @@ struct server_task {
// sequence breakers for DRY
{
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
- // Ref:
- // https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+ // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
if (data.contains("dry_sequence_breakers")) {
- params.sampling.dry_sequence_breakers =
- json_value(data, "dry_sequence_breakers", std::vector());
+ params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector());
if (params.sampling.dry_sequence_breakers.empty()) {
throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
}
@@ -323,15 +317,15 @@ struct server_task {
// process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
- auto schema = json_value(data, "json_schema", json::object());
+ auto schema = json_value(data, "json_schema", json::object());
SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
- params.sampling.grammar = json_schema_to_grammar(schema);
+ params.sampling.grammar = json_schema_to_grammar(schema);
SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
- } catch (const std::exception &e) {
+ } catch (const std::exception & e) {
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
}
} else {
- params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
+ params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
@@ -350,39 +344,35 @@ struct server_task {
{
const auto preserved_tokens = data.find("preserved_tokens");
if (preserved_tokens != data.end()) {
- for (const auto &t : *preserved_tokens) {
- auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false,
- /* parse_special= */ true);
+ for (const auto & t : *preserved_tokens) {
+ auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
SRV_DBG("Preserved token: %d\n", ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
} else {
- // This may happen when using a tool call style meant for a model with special tokens to
- // preserve on a model without said tokens.
+ // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str());
}
}
}
const auto grammar_triggers = data.find("grammar_triggers");
if (grammar_triggers != data.end()) {
- for (const auto &t : *grammar_triggers) {
+ for (const auto & t : *grammar_triggers) {
auto ct = common_grammar_trigger::from_json(t);
if (ct.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
- const auto &word = ct.value;
+ const auto & word = ct.value;
auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
auto token = ids[0];
- if (std::find(params.sampling.preserved_tokens.begin(),
- params.sampling.preserved_tokens.end(),
- (llama_token)token) == params.sampling.preserved_tokens.end()) {
- throw std::runtime_error("Grammar trigger word should be marked as preserved token: " +
- word);
+ if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) {
+ throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word);
}
SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str());
common_grammar_trigger trigger;
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
- trigger.value = (llama_token)token;
- params.sampling.grammar_triggers.push_back(trigger);
+ trigger.value = word;
+ trigger.token = token;
+ params.sampling.grammar_triggers.push_back(std::move(trigger));
} else {
SRV_DBG("Grammar trigger word: `%s`\n", word.c_str());
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
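From the client side, the parsing above means a tool-call style that relies on trigger tokens should list those strings under preserved_tokens so they survive tokenization as single ids. A hedged request fragment; the token strings are hypothetical examples, and the exact object layout accepted by common_grammar_trigger::from_json is not reproduced here:

```cpp
// Illustrative request fragment (nlohmann::json), not taken from this diff.
json request = {
    {"prompt", "..."},
    {"grammar_lazy", true},
    // kept only if each string maps to exactly one token id on the loaded model
    {"preserved_tokens", {"<tool_call>", "</tool_call>"}},
    // "grammar_triggers" would be added here; its entries are parsed by
    // common_grammar_trigger::from_json, whose layout this sketch does not assume.
};
```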
@@ -401,10 +391,10 @@ struct server_task {
params.sampling.logit_bias.clear();
params.ignore_eos = json_value(data, "ignore_eos", false);
- const auto &logit_bias = data.find("logit_bias");
+ const auto & logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_vocab_n_tokens(vocab);
- for (const auto &el : *logit_bias) {
+ for (const auto & el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect
if (el.is_array() && el.size() == 2) {
float bias;
@@ -435,9 +425,9 @@ struct server_task {
{
params.antiprompt.clear();
- const auto &stop = data.find("stop");
+ const auto & stop = data.find("stop");
if (stop != data.end() && stop->is_array()) {
- for (const auto &word : *stop) {
+ for (const auto & word : *stop) {
if (!word.empty()) {
params.antiprompt.push_back(word);
}
@@ -450,7 +440,7 @@ struct server_task {
if (samplers != data.end()) {
if (samplers->is_array()) {
params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
- } else if (samplers->is_string()) {
+ } else if (samplers->is_string()){
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
}
} else {
@@ -465,7 +455,7 @@ struct server_task {
}
// utility function
- static std::unordered_set<int> get_list_id(const std::vector<server_task> &tasks) {
+ static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
std::unordered_set<int> ids(tasks.size());
for (size_t i = 0; i < tasks.size(); i++) {
ids.insert(tasks[i].id);
@@ -487,22 +477,22 @@ struct result_timings {
json to_json() const {
return {
- {"prompt_n", prompt_n},
- {"prompt_ms", prompt_ms},
- {"prompt_per_token_ms", prompt_per_token_ms},
- {"prompt_per_second", prompt_per_second},
+ {"prompt_n", prompt_n},
+ {"prompt_ms", prompt_ms},
+ {"prompt_per_token_ms", prompt_per_token_ms},
+ {"prompt_per_second", prompt_per_second},
- {"predicted_n", predicted_n},
- {"predicted_ms", predicted_ms},
+ {"predicted_n", predicted_n},
+ {"predicted_ms", predicted_ms},
{"predicted_per_token_ms", predicted_per_token_ms},
- {"predicted_per_second", predicted_per_second},
+ {"predicted_per_second", predicted_per_second},
};
}
};
struct server_task_result {
- int id = -1;
- int id_slot = -1;
+ int id = -1;
+ int id_slot = -1;
virtual bool is_error() {
// only used by server_task_result_error
return false;
@@ -511,7 +501,9 @@ struct server_task_result {
// only used by server_task_result_cmpl_*
return false;
}
- virtual int get_index() { return -1; }
+ virtual int get_index() {
+ return -1;
+ }
virtual json to_json() = 0;
virtual ~server_task_result() = default;
};
@@ -521,14 +513,10 @@ using server_task_result_ptr = std::unique_ptr<server_task_result>;
inline std::string stop_type_to_str(stop_type type) {
switch (type) {
- case STOP_TYPE_EOS:
- return "eos";
- case STOP_TYPE_WORD:
- return "word";
- case STOP_TYPE_LIMIT:
- return "limit";
- default:
- return "none";
+ case STOP_TYPE_EOS: return "eos";
+ case STOP_TYPE_WORD: return "word";
+ case STOP_TYPE_LIMIT: return "limit";
+ default: return "none";
}
}
@@ -545,30 +533,39 @@ struct completion_token_output {
json to_json(bool post_sampling_probs) const {
json probs_for_token = json::array();
- for (const auto &p : probs) {
+ for (const auto & p : probs) {
std::string txt(p.txt);
txt.resize(validate_utf8(txt));
- probs_for_token.push_back(json{
- {"id", p.tok},
- {"token", txt},
- {"bytes", str_to_bytes(p.txt)},
- {post_sampling_probs ? "prob" : "logprob", post_sampling_probs ? p.prob : logarithm(p.prob)},
+ probs_for_token.push_back(json {
+ {"id", p.tok},
+ {"token", txt},
+ {"bytes", str_to_bytes(p.txt)},
+ {
+ post_sampling_probs ? "prob" : "logprob",
+ post_sampling_probs ? p.prob : logarithm(p.prob)
+ },
});
}
return probs_for_token;
}
- static json probs_vector_to_json(const std::vector<completion_token_output> &probs, bool post_sampling_probs) {
+ static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
json out = json::array();
- for (const auto &p : probs) {
+ for (const auto & p : probs) {
std::string txt(p.text_to_send);
txt.resize(validate_utf8(txt));
- out.push_back(json{
- {"id", p.tok},
- {"token", txt},
- {"bytes", str_to_bytes(p.text_to_send)},
- {post_sampling_probs ? "prob" : "logprob", post_sampling_probs ? p.prob : logarithm(p.prob)},
- {post_sampling_probs ? "top_probs" : "top_logprobs", p.to_json(post_sampling_probs)},
+ out.push_back(json {
+ {"id", p.tok},
+ {"token", txt},
+ {"bytes", str_to_bytes(p.text_to_send)},
+ {
+ post_sampling_probs ? "prob" : "logprob",
+ post_sampling_probs ? p.prob : logarithm(p.prob)
+ },
+ {
+ post_sampling_probs ? "top_probs" : "top_logprobs",
+ p.to_json(post_sampling_probs)
+ },
});
}
return out;
@@ -579,7 +576,7 @@ struct completion_token_output {
return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
}
- static std::vector<unsigned char> str_to_bytes(const std::string &str) {
+ static std::vector<unsigned char> str_to_bytes(const std::string & str) {
std::vector<unsigned char> bytes;
for (unsigned char c : str) {
bytes.push_back(c);
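The logarithm() helper above guards against a probability of exactly zero, which would otherwise produce -inf and make the logprob fields non-serializable; the lowest finite float is used instead. Restated as a standalone function:

```cpp
#include <cmath>
#include <limits>

// Same guard as logarithm() above: avoid -inf for zero probabilities.
static float safe_log(float x) {
    return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
}
// safe_log(0.5f) is about -0.693f; safe_log(0.0f) equals std::numeric_limits<float>::lowest() (about -3.4e38f)
```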
@@ -608,18 +605,20 @@ struct server_task_result_cmpl_final : server_task_result {
bool post_sampling_probs;
std::vector<completion_token_output> probs_output;
- std::vector<std::string> response_fields;
+ std::vector<std::string> response_fields;
slot_params generation_params;
// OAI-compat fields
- bool verbose = false;
- oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- std::string oaicompat_model;
- std::string oaicompat_cmpl_id;
- common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+ bool verbose = false;
+ oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
+ common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
- virtual int get_index() override { return index; }
+ virtual int get_index() override {
+ return index;
+ }
virtual bool is_stop() override {
return true; // in stream mode, final responses are considered stop
@@ -627,39 +626,38 @@ struct server_task_result_cmpl_final : server_task_result {
virtual json to_json() override {
switch (oaicompat) {
- case OAICOMPAT_TYPE_NONE:
- return to_json_non_oaicompat();
- case OAICOMPAT_TYPE_COMPLETION:
- return to_json_oaicompat();
- case OAICOMPAT_TYPE_CHAT:
- return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
- default:
- GGML_ASSERT(false && "Invalid oaicompat_type");
+ case OAICOMPAT_TYPE_NONE:
+ return to_json_non_oaicompat();
+ case OAICOMPAT_TYPE_COMPLETION:
+ return to_json_oaicompat();
+ case OAICOMPAT_TYPE_CHAT:
+ return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
+ default:
+ GGML_ASSERT(false && "Invalid oaicompat_type");
}
}
json to_json_non_oaicompat() {
- json res = json{
- {"index", index},
- {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
- {"tokens", stream ? llama_tokens{} : tokens},
- {"id_slot", id_slot},
- {"stop", true},
- {"model", oaicompat_model},
- {"tokens_predicted", n_decoded},
- {"tokens_evaluated", n_prompt_tokens},
+ json res = json {
+ {"index", index},
+ {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+ {"tokens", stream ? llama_tokens {} : tokens},
+ {"id_slot", id_slot},
+ {"stop", true},
+ {"model", oaicompat_model},
+ {"tokens_predicted", n_decoded},
+ {"tokens_evaluated", n_prompt_tokens},
{"generation_settings", generation_params.to_json()},
- {"prompt", prompt},
- {"has_new_line", has_new_line},
- {"truncated", truncated},
- {"stop_type", stop_type_to_str(stop)},
- {"stopping_word", stopping_word},
- {"tokens_cached", n_tokens_cached},
- {"timings", timings.to_json()},
+ {"prompt", prompt},
+ {"has_new_line", has_new_line},
+ {"truncated", truncated},
+ {"stop_type", stop_type_to_str(stop)},
+ {"stopping_word", stopping_word},
+ {"tokens_cached", n_tokens_cached},
+ {"timings", timings.to_json()},
};
if (!stream && !probs_output.empty()) {
- res["completion_probabilities"] =
- completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
+ res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
}
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
}
@@ -676,21 +674,26 @@ struct server_task_result_cmpl_final : server_task_result {
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
}
- json res = json{
- {"choices", json::array({json{
- {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
- {"index", index},
- {"logprobs", logprobs},
- {"finish_reason", finish_reason},
- }})},
- {"created", t},
- {"model", oaicompat_model},
+ json res = json {
+ {"choices", json::array({
+ json{
+ {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+ {"index", index},
+ {"logprobs", logprobs},
+ {"finish_reason", finish_reason},
+ }
+ })},
+ {"created", t},
+ {"model", oaicompat_model},
{"system_fingerprint", build_info},
- {"object", "text_completion"},
- {"usage", json{{"completion_tokens", n_decoded},
- {"prompt_tokens", n_prompt_tokens},
- {"total_tokens", n_decoded + n_prompt_tokens}}},
- {"id", oaicompat_cmpl_id}};
+ {"object", "text_completion"},
+ {"usage", json {
+ {"completion_tokens", n_decoded},
+ {"prompt_tokens", n_prompt_tokens},
+ {"total_tokens", n_decoded + n_prompt_tokens}
+ }},
+ {"id", oaicompat_cmpl_id}
+ };
// extra fields for debugging purposes
if (verbose) {
@@ -714,7 +717,7 @@ struct server_task_result_cmpl_final : server_task_result {
msg.content = content;
}
- json message{
+ json message {
{"role", "assistant"},
};
if (!msg.reasoning_content.empty()) {
@@ -727,21 +730,23 @@ struct server_task_result_cmpl_final : server_task_result {
}
if (!msg.tool_calls.empty()) {
auto tool_calls = json::array();
- for (const auto &tc : msg.tool_calls) {
+ for (const auto & tc : msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
- {"function",
- {
- {"name", tc.name},
- {"arguments", tc.arguments},
- }},
- {"id", tc.id},
+ {"function", {
+ {"name", tc.name},
+ {"arguments", tc.arguments},
+ }},
+ // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
+ // We only generate a random id for the ones that don't generate one by themselves
+ // (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
+ {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
});
}
message["tool_calls"] = tool_calls;
}
- json choice{
+ json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"message", message},
@@ -755,15 +760,19 @@ struct server_task_result_cmpl_final : server_task_result {
std::time_t t = std::time(0);
- json res = json{{"choices", json::array({choice})},
- {"created", t},
- {"model", oaicompat_model},
- {"system_fingerprint", build_info},
- {"object", "chat.completion"},
- {"usage", json{{"completion_tokens", n_decoded},
- {"prompt_tokens", n_prompt_tokens},
- {"total_tokens", n_decoded + n_prompt_tokens}}},
- {"id", oaicompat_cmpl_id}};
+ json res = json {
+ {"choices", json::array({choice})},
+ {"created", t},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "chat.completion"},
+ {"usage", json {
+ {"completion_tokens", n_decoded},
+ {"prompt_tokens", n_prompt_tokens},
+ {"total_tokens", n_decoded + n_prompt_tokens}
+ }},
+ {"id", oaicompat_cmpl_id}
+ };
// extra fields for debugging purposes
if (verbose) {
@@ -783,21 +792,24 @@ struct server_task_result_cmpl_final : server_task_result {
finish_reason = "stop";
}
- json choice = json{{"finish_reason", finish_reason}, {"index", 0}, {"delta", json::object()}};
+ json choice = json {
+ {"finish_reason", finish_reason},
+ {"index", 0},
+ {"delta", json::object()}
+ };
- json ret = json{
- {"choices", json::array({choice})},
- {"created", t},
- {"id", oaicompat_cmpl_id},
- {"model", oaicompat_model},
+ json ret = json {
+ {"choices", json::array({choice})},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
{"system_fingerprint", build_info},
- {"object", "chat.completion.chunk"},
- {"usage",
- json{
- {"completion_tokens", n_decoded},
- {"prompt_tokens", n_prompt_tokens},
- {"total_tokens", n_decoded + n_prompt_tokens},
- }},
+ {"object", "chat.completion.chunk"},
+ {"usage", json {
+ {"completion_tokens", n_decoded},
+ {"prompt_tokens", n_prompt_tokens},
+ {"total_tokens", n_decoded + n_prompt_tokens},
+ }},
};
if (timings.prompt_n >= 0) {
@@ -811,7 +823,7 @@ struct server_task_result_cmpl_final : server_task_result {
struct server_task_result_cmpl_partial : server_task_result {
int index = 0;
- std::string content;
+ std::string content;
llama_tokens tokens;
int32_t n_decoded;
@@ -822,12 +834,14 @@ struct server_task_result_cmpl_partial : server_task_result {
result_timings timings;
// OAI-compat fields
- bool verbose = false;
+ bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- std::string oaicompat_model;
- std::string oaicompat_cmpl_id;
+ std::string oaicompat_model;
+ std::string oaicompat_cmpl_id;
- virtual int get_index() override { return index; }
+ virtual int get_index() override {
+ return index;
+ }
virtual bool is_stop() override {
return false; // in stream mode, partial responses are not considered stop
@@ -835,25 +849,25 @@ struct server_task_result_cmpl_partial : server_task_result {
virtual json to_json() override {
switch (oaicompat) {
- case OAICOMPAT_TYPE_NONE:
- return to_json_non_oaicompat();
- case OAICOMPAT_TYPE_COMPLETION:
- return to_json_oaicompat();
- case OAICOMPAT_TYPE_CHAT:
- return to_json_oaicompat_chat();
- default:
- GGML_ASSERT(false && "Invalid oaicompat_type");
+ case OAICOMPAT_TYPE_NONE:
+ return to_json_non_oaicompat();
+ case OAICOMPAT_TYPE_COMPLETION:
+ return to_json_oaicompat();
+ case OAICOMPAT_TYPE_CHAT:
+ return to_json_oaicompat_chat();
+ default:
+ GGML_ASSERT(false && "Invalid oaicompat_type");
}
}
json to_json_non_oaicompat() {
// non-OAI-compat JSON
- json res = json{
- {"index", index},
- {"content", content},
- {"tokens", tokens},
- {"stop", false},
- {"id_slot", id_slot},
+ json res = json {
+ {"index", index},
+ {"content", content},
+ {"tokens", tokens},
+ {"stop", false},
+ {"id_slot", id_slot},
{"tokens_predicted", n_decoded},
{"tokens_evaluated", n_prompt_tokens},
};
@@ -862,8 +876,7 @@ struct server_task_result_cmpl_partial : server_task_result {
res.push_back({"timings", timings.to_json()});
}
if (!prob_output.probs.empty()) {
- res["completion_probabilities"] =
- completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
+ res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
}
return res;
}
@@ -876,17 +889,21 @@ struct server_task_result_cmpl_partial : server_task_result {
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
- json res = json{{"choices", json::array({json{
- {"text", content},
- {"index", index},
- {"logprobs", logprobs},
- {"finish_reason", nullptr},
- }})},
- {"created", t},
- {"model", oaicompat_model},
- {"system_fingerprint", build_info},
- {"object", "text_completion"},
- {"id", oaicompat_cmpl_id}};
+ json res = json {
+ {"choices", json::array({
+ json{
+ {"text", content},
+ {"index", index},
+ {"logprobs", logprobs},
+ {"finish_reason", nullptr},
+ }
+ })},
+ {"created", t},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "text_completion"},
+ {"id", oaicompat_cmpl_id}
+ };
// extra fields for debugging purposes
if (verbose) {
@@ -906,26 +923,32 @@ struct server_task_result_cmpl_partial : server_task_result {
if (first) {
if (content.empty()) {
- choices = json::array(
- {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"role", "assistant"}}}}});
+ choices = json::array({json{{"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
- json initial_ret = json{{"choices", json::array({json{{"finish_reason", nullptr},
- {"index", 0},
- {"delta", json{{"role", "assistant"}}}}})},
- {"created", t},
- {"id", oaicompat_cmpl_id},
- {"model", oaicompat_model},
- {"object", "chat.completion.chunk"}};
-
- json second_ret =
- json{{"choices",
- json::array(
- {json{{"finish_reason", nullptr}, {"index", 0}, {"delta", json{{"content", content}}}}})},
- {"created", t},
- {"id", oaicompat_cmpl_id},
- {"model", oaicompat_model},
- {"object", "chat.completion.chunk"}};
+ json initial_ret = json{{"choices", json::array({json{
+ {"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", json{
+ {"role", "assistant"}
+ }}}})},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"object", "chat.completion.chunk"}};
+
+ json second_ret = json{
+ {"choices", json::array({json{{"finish_reason", nullptr},
+ {"index", 0},
+ {"delta", json {
+ {"content", content}}}
+ }})},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"object", "chat.completion.chunk"}};
return std::vector<json>({initial_ret, second_ret});
}
@@ -934,9 +957,9 @@ struct server_task_result_cmpl_partial : server_task_result {
{"finish_reason", nullptr},
{"index", 0},
{"delta",
- json{
- {"content", content},
- }},
+ json {
+ {"content", content},
+ }},
}});
}
@@ -948,12 +971,14 @@ struct server_task_result_cmpl_partial : server_task_result {
};
}
- json ret = json{{"choices", choices},
- {"created", t},
- {"id", oaicompat_cmpl_id},
- {"model", oaicompat_model},
- {"system_fingerprint", build_info},
- {"object", "chat.completion.chunk"}};
+ json ret = json {
+ {"choices", choices},
+ {"created", t},
+ {"id", oaicompat_cmpl_id},
+ {"model", oaicompat_model},
+ {"system_fingerprint", build_info},
+ {"object", "chat.completion.chunk"}
+ };
if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
@@ -972,23 +997,27 @@ struct server_task_result_embd : server_task_result {
// OAI-compat fields
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
- virtual int get_index() override { return index; }
+ virtual int get_index() override {
+ return index;
+ }
virtual json to_json() override {
- return oaicompat == OAICOMPAT_TYPE_EMBEDDING ? to_json_oaicompat() : to_json_non_oaicompat();
+ return oaicompat == OAICOMPAT_TYPE_EMBEDDING
+ ? to_json_oaicompat()
+ : to_json_non_oaicompat();
}
json to_json_non_oaicompat() {
- return json{
- {"index", index},
+ return json {
+ {"index", index},
{"embedding", embedding},
};
}
json to_json_oaicompat() {
- return json{
- {"index", index},
- {"embedding", embedding[0]},
+ return json {
+ {"index", index},
+ {"embedding", embedding[0]},
{"tokens_evaluated", n_tokens},
};
}
@@ -1000,52 +1029,54 @@ struct server_task_result_rerank : server_task_result {
int32_t n_tokens;
- virtual int get_index() override { return index; }
+ virtual int get_index() override {
+ return index;
+ }
virtual json to_json() override {
- return json{
- {"index", index},
- {"score", score},
+ return json {
+ {"index", index},
+ {"score", score},
{"tokens_evaluated", n_tokens},
};
}
};
// this function maybe used outside of server_task_result_error
-static json format_error_response(const std::string &message, const enum error_type type) {
+static json format_error_response(const std::string & message, const enum error_type type) {
std::string type_str;
int code = 500;
switch (type) {
- case ERROR_TYPE_INVALID_REQUEST:
- type_str = "invalid_request_error";
- code = 400;
- break;
- case ERROR_TYPE_AUTHENTICATION:
- type_str = "authentication_error";
- code = 401;
- break;
- case ERROR_TYPE_NOT_FOUND:
- type_str = "not_found_error";
- code = 404;
- break;
- case ERROR_TYPE_SERVER:
- type_str = "server_error";
- code = 500;
- break;
- case ERROR_TYPE_PERMISSION:
- type_str = "permission_error";
- code = 403;
- break;
- case ERROR_TYPE_NOT_SUPPORTED:
- type_str = "not_supported_error";
- code = 501;
- break;
- case ERROR_TYPE_UNAVAILABLE:
- type_str = "unavailable_error";
- code = 503;
- break;
- }
- return json{
+ case ERROR_TYPE_INVALID_REQUEST:
+ type_str = "invalid_request_error";
+ code = 400;
+ break;
+ case ERROR_TYPE_AUTHENTICATION:
+ type_str = "authentication_error";
+ code = 401;
+ break;
+ case ERROR_TYPE_NOT_FOUND:
+ type_str = "not_found_error";
+ code = 404;
+ break;
+ case ERROR_TYPE_SERVER:
+ type_str = "server_error";
+ code = 500;
+ break;
+ case ERROR_TYPE_PERMISSION:
+ type_str = "permission_error";
+ code = 403;
+ break;
+ case ERROR_TYPE_NOT_SUPPORTED:
+ type_str = "not_supported_error";
+ code = 501;
+ break;
+ case ERROR_TYPE_UNAVAILABLE:
+ type_str = "unavailable_error";
+ code = 503;
+ break;
+ }
+ return json {
{"code", code},
{"message", message},
{"type", type_str},
@@ -1057,9 +1088,13 @@ struct server_task_result_error : server_task_result {
error_type err_type = ERROR_TYPE_SERVER;
std::string err_msg;
- virtual bool is_error() override { return true; }
+ virtual bool is_error() override {
+ return true;
+ }
- virtual json to_json() override { return format_error_response(err_msg, err_type); }
+ virtual json to_json() override {
+ return format_error_response(err_msg, err_type);
+ }
};
struct server_task_result_metrics : server_task_result {
@@ -1073,17 +1108,17 @@ struct server_task_result_metrics : server_task_result {
// TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
uint64_t n_prompt_tokens_processed_total = 0;
- uint64_t t_prompt_processing_total = 0;
- uint64_t n_tokens_predicted_total = 0;
- uint64_t t_tokens_generation_total = 0;
+ uint64_t t_prompt_processing_total = 0;
+ uint64_t n_tokens_predicted_total = 0;
+ uint64_t t_tokens_generation_total = 0;
uint64_t n_prompt_tokens_processed = 0;
- uint64_t t_prompt_processing = 0;
+ uint64_t t_prompt_processing = 0;
- uint64_t n_tokens_predicted = 0;
+ uint64_t n_tokens_predicted = 0;
uint64_t t_tokens_generation = 0;
- uint64_t n_decode_total = 0;
+ uint64_t n_decode_total = 0;
uint64_t n_busy_slots_total = 0;
// while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
@@ -1091,29 +1126,29 @@ struct server_task_result_metrics : server_task_result {
json slots_data = json::array();
virtual json to_json() override {
- return json{
- {"idle", n_idle_slots},
- {"processing", n_processing_slots},
- {"deferred", n_tasks_deferred},
- {"t_start", t_start},
+ return json {
+ { "idle", n_idle_slots },
+ { "processing", n_processing_slots },
+ { "deferred", n_tasks_deferred },
+ { "t_start", t_start },
- {"n_prompt_tokens_processed_total", n_prompt_tokens_processed_total},
- {"t_tokens_generation_total", t_tokens_generation_total},
- {"n_tokens_predicted_total", n_tokens_predicted_total},
- {"t_prompt_processing_total", t_prompt_processing_total},
+ { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total },
+ { "t_tokens_generation_total", t_tokens_generation_total },
+ { "n_tokens_predicted_total", n_tokens_predicted_total },
+ { "t_prompt_processing_total", t_prompt_processing_total },
- {"n_prompt_tokens_processed", n_prompt_tokens_processed},
- {"t_prompt_processing", t_prompt_processing},
- {"n_tokens_predicted", n_tokens_predicted},
- {"t_tokens_generation", t_tokens_generation},
+ { "n_prompt_tokens_processed", n_prompt_tokens_processed },
+ { "t_prompt_processing", t_prompt_processing },
+ { "n_tokens_predicted", n_tokens_predicted },
+ { "t_tokens_generation", t_tokens_generation },
- {"n_decode_total", n_decode_total},
- {"n_busy_slots_total", n_busy_slots_total},
+ { "n_decode_total", n_decode_total },
+ { "n_busy_slots_total", n_busy_slots_total },
- {"kv_cache_tokens_count", kv_cache_tokens_count},
- {"kv_cache_used_cells", kv_cache_used_cells},
+ { "kv_cache_tokens_count", kv_cache_tokens_count },
+ { "kv_cache_used_cells", kv_cache_used_cells },
- {"slots", slots_data},
+ { "slots", slots_data },
};
}
};
@@ -1128,17 +1163,24 @@ struct server_task_result_slot_save_load : server_task_result {
virtual json to_json() override {
if (is_save) {
- return json{
- {"id_slot", id_slot}, {"filename", filename}, {"n_saved", n_tokens},
- {"n_written", n_bytes}, {"timings", {{"save_ms", t_ms}}},
+ return json {
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "n_saved", n_tokens },
+ { "n_written", n_bytes },
+ { "timings", {
+ { "save_ms", t_ms }
+ }},
};
} else {
- return json{
- {"id_slot", id_slot},
- {"filename", filename},
- {"n_restored", n_tokens},
- {"n_read", n_bytes},
- {"timings", {{"restore_ms", t_ms}}},
+ return json {
+ { "id_slot", id_slot },
+ { "filename", filename },
+ { "n_restored", n_tokens },
+ { "n_read", n_bytes },
+ { "timings", {
+ { "restore_ms", t_ms }
+ }},
};
}
}
@@ -1148,15 +1190,17 @@ struct server_task_result_slot_erase : server_task_result {
size_t n_erased;
virtual json to_json() override {
- return json{
- {"id_slot", id_slot},
- {"n_erased", n_erased},
+ return json {
+ { "id_slot", id_slot },
+ { "n_erased", n_erased },
};
}
};
struct server_task_result_apply_lora : server_task_result {
- virtual json to_json() override { return json{{"success", true}}; }
+ virtual json to_json() override {
+ return json {{ "success", true }};
+ }
};
struct server_slot {
@@ -1168,10 +1212,10 @@ struct server_slot {
llama_batch batch_spec = {};
- llama_context *ctx = nullptr;
- llama_context *ctx_dft = nullptr;
+ llama_context * ctx = nullptr;
+ llama_context * ctx_dft = nullptr;
- common_speculative *spec = nullptr;
+ common_speculative * spec = nullptr;
std::vector lora;
@@ -1186,15 +1230,15 @@ struct server_slot {
int64_t t_last_used = -1;
// generation props
- int32_t n_ctx = 0; // context size per slot
- int32_t n_past = 0;
- int32_t n_decoded = 0;
+ int32_t n_ctx = 0; // context size per slot
+ int32_t n_past = 0;
+ int32_t n_decoded = 0;
int32_t n_remaining = -1;
- int32_t i_batch = -1;
- int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
+ int32_t i_batch = -1;
+ int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
// n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
- int32_t n_prompt_tokens = 0;
+ int32_t n_prompt_tokens = 0;
int32_t n_prompt_tokens_processed = 0;
// input prompt tokens
@@ -1202,7 +1246,7 @@ struct server_slot {
size_t last_nl_pos = 0;
- std::string generated_text;
+ std::string generated_text;
llama_tokens generated_tokens;
llama_tokens cache_tokens;
@@ -1210,8 +1254,8 @@ struct server_slot {
std::vector<completion_token_output> generated_token_probs;
bool has_next_token = true;
- bool has_new_line = false;
- bool truncated = false;
+ bool has_new_line = false;
+ bool truncated = false;
stop_type stop;
std::string stopping_word;
@@ -1219,14 +1263,14 @@ struct server_slot {
// sampling
json json_schema;
- struct common_sampler *smpl = nullptr;
+ struct common_sampler * smpl = nullptr;
llama_token sampled;
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
// stats
- size_t n_sent_text = 0; // number of sent text character
+    size_t n_sent_text = 0; // number of sent text characters
int64_t t_start_process_prompt;
int64_t t_start_generation;
@@ -1239,16 +1283,16 @@ struct server_slot {
void reset() {
SLT_DBG(*this, "%s", "\n");
- n_prompt_tokens = 0;
- last_nl_pos = 0;
- generated_text = "";
- has_new_line = false;
- truncated = false;
- stop = STOP_TYPE_NONE;
- stopping_word = "";
- n_past = 0;
- n_sent_text = 0;
- task_type = SERVER_TASK_TYPE_COMPLETION;
+ n_prompt_tokens = 0;
+ last_nl_pos = 0;
+ generated_text = "";
+ has_new_line = false;
+ truncated = false;
+ stop = STOP_TYPE_NONE;
+ stopping_word = "";
+ n_past = 0;
+ n_sent_text = 0;
+ task_type = SERVER_TASK_TYPE_COMPLETION;
generated_tokens.clear();
generated_token_probs.clear();
@@ -1258,11 +1302,12 @@ struct server_slot {
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
}
- bool can_batch_with(server_slot &other_slot) {
- return is_non_causal() == other_slot.is_non_causal() && are_lora_equal(lora, other_slot.lora);
+ bool can_batch_with(server_slot & other_slot) const {
+ return is_non_causal() == other_slot.is_non_causal()
+ && are_lora_equal(lora, other_slot.lora);
}
- bool has_budget(const common_params &global_params) {
+ bool has_budget(const common_params & global_params) {
if (params.n_predict == -1 && global_params.n_predict == -1) {
return true; // limitless
}
@@ -1278,11 +1323,15 @@ struct server_slot {
return n_remaining > 0; // no budget
}
- bool is_processing() const { return state != SLOT_STATE_IDLE; }
+ bool is_processing() const {
+ return state != SLOT_STATE_IDLE;
+ }
- bool can_speculate() const { return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; }
+ bool can_speculate() const {
+ return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
+ }
- void add_token(const completion_token_output &token) {
+ void add_token(const completion_token_output & token) {
if (!is_processing()) {
SLT_WRN(*this, "%s", "slot is not processing\n");
return;
@@ -1316,14 +1365,14 @@ struct server_slot {
return timings;
}
- size_t find_stopping_strings(const std::string &text, const size_t last_token_size, bool is_full_stop) {
+ size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
size_t stop_pos = std::string::npos;
- for (const std::string &word : params.antiprompt) {
+ for (const std::string & word : params.antiprompt) {
size_t pos;
if (is_full_stop) {
- const size_t tmp = word.size() + last_token_size;
+ const size_t tmp = word.size() + last_token_size;
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
pos = text.find(word, from_pos);
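For a full stop, only the tail of the generated text that could contain a stop word overlapping the last decoded token has to be scanned, which is why the search starts at text.size() - (word.size() + last_token_size), clamped to zero. The window computation in isolation:

```cpp
#include <string>

// Start of the search window used by find_stopping_strings() for a full stop.
static size_t stop_search_start(const std::string & text, const std::string & word, size_t last_token_size) {
    const size_t tmp = word.size() + last_token_size;
    return text.size() > tmp ? text.size() - tmp : 0;
}
// e.g. 100 chars of text, a 3-char stop word and a 4-byte last token -> scanning starts at index 93
```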
@@ -1334,8 +1383,8 @@ struct server_slot {
if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
if (is_full_stop) {
- stop = STOP_TYPE_WORD;
- stopping_word = word;
+ stop = STOP_TYPE_WORD;
+ stopping_word = word;
has_next_token = false;
}
stop_pos = pos;
@@ -1346,10 +1395,10 @@ struct server_slot {
}
void print_timings() const {
- const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
+ const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
- const double t_gen = t_token_generation / n_decoded;
+ const double t_gen = t_token_generation / n_decoded;
const double n_gen_second = 1e3 / t_token_generation * n_decoded;
SLT_INF(*this,
@@ -1357,29 +1406,30 @@ struct server_slot {
"prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
" eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
" total time = %10.2f ms / %5d tokens\n",
- t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, t_token_generation,
- n_decoded, t_gen, n_gen_second, t_prompt_processing + t_token_generation,
- n_prompt_tokens_processed + n_decoded);
+ t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
+ t_token_generation, n_decoded, t_gen, n_gen_second,
+ t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
}
json to_json() const {
- return json{
- {"id", id},
- {"id_task", id_task},
- {"n_ctx", n_ctx},
- {"speculative", can_speculate()},
+ return json {
+ {"id", id},
+ {"id_task", id_task},
+ {"n_ctx", n_ctx},
+ {"speculative", can_speculate()},
{"is_processing", is_processing()},
- {"non_causal", is_non_causal()},
- {"params", params.to_json()},
- {"prompt", common_detokenize(ctx, prompt_tokens)},
+ {"non_causal", is_non_causal()},
+ {"params", params.to_json()},
+ {"prompt", common_detokenize(ctx, prompt_tokens)},
{"next_token",
- {
- {"has_next_token", has_next_token},
- {"has_new_line", has_new_line},
- {"n_remain", n_remaining},
- {"n_decoded", n_decoded},
- {"stopping_word", stopping_word},
- }},
+ {
+ {"has_next_token", has_next_token},
+ {"has_new_line", has_new_line},
+ {"n_remain", n_remaining},
+ {"n_decoded", n_decoded},
+ {"stopping_word", stopping_word},
+ }
+ },
};
}
};
@@ -1388,38 +1438,40 @@ struct server_metrics {
int64_t t_start = 0;
uint64_t n_prompt_tokens_processed_total = 0;
- uint64_t t_prompt_processing_total = 0;
- uint64_t n_tokens_predicted_total = 0;
- uint64_t t_tokens_generation_total = 0;
+ uint64_t t_prompt_processing_total = 0;
+ uint64_t n_tokens_predicted_total = 0;
+ uint64_t t_tokens_generation_total = 0;
uint64_t n_prompt_tokens_processed = 0;
- uint64_t t_prompt_processing = 0;
+ uint64_t t_prompt_processing = 0;
- uint64_t n_tokens_predicted = 0;
+ uint64_t n_tokens_predicted = 0;
uint64_t t_tokens_generation = 0;
- uint64_t n_decode_total = 0;
+ uint64_t n_decode_total = 0;
uint64_t n_busy_slots_total = 0;
- void init() { t_start = ggml_time_us(); }
+ void init() {
+ t_start = ggml_time_us();
+ }
- void on_prompt_eval(const server_slot &slot) {
+ void on_prompt_eval(const server_slot & slot) {
n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
- n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
- t_prompt_processing += slot.t_prompt_processing;
- t_prompt_processing_total += slot.t_prompt_processing;
+ n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
+ t_prompt_processing += slot.t_prompt_processing;
+ t_prompt_processing_total += slot.t_prompt_processing;
}
- void on_prediction(const server_slot &slot) {
- n_tokens_predicted_total += slot.n_decoded;
- n_tokens_predicted += slot.n_decoded;
- t_tokens_generation += slot.t_token_generation;
- t_tokens_generation_total += slot.t_token_generation;
+ void on_prediction(const server_slot & slot) {
+ n_tokens_predicted_total += slot.n_decoded;
+ n_tokens_predicted += slot.n_decoded;
+ t_tokens_generation += slot.t_token_generation;
+ t_tokens_generation_total += slot.t_token_generation;
}
- void on_decoded(const std::vector &slots) {
+ void on_decoded(const std::vector & slots) {
n_decode_total++;
- for (const auto &slot : slots) {
+ for (const auto & slot : slots) {
if (slot.is_processing()) {
n_busy_slots_total++;
}
@@ -1428,9 +1480,9 @@ struct server_metrics {
void reset_bucket() {
n_prompt_tokens_processed = 0;
- t_prompt_processing = 0;
- n_tokens_predicted = 0;
- t_tokens_generation = 0;
+ t_prompt_processing = 0;
+ n_tokens_predicted = 0;
+ t_tokens_generation = 0;
}
};
@@ -1447,7 +1499,7 @@ struct server_queue {
// callback functions
std::function<void(server_task)> callback_new_task;
- std::function<void(void)> callback_update_slots;
+ std::function<void(void)> callback_update_slots;
// Add a new task to the end of the queue
int post(server_task task, bool front = false) {
@@ -1468,9 +1520,9 @@ struct server_queue {
}
// multi-task version of post()
- int post(std::vector<server_task> &tasks, bool front = false) {
+ int post(std::vector<server_task> & tasks, bool front = false) {
std::unique_lock lock(mutex_tasks);
- for (auto &task : tasks) {
+ for (auto & task : tasks) {
if (task.id == -1) {
task.id = id++;
}
@@ -1478,7 +1530,7 @@ struct server_queue {
if (task.type == SERVER_TASK_TYPE_CANCEL) {
cleanup_pending_task(task.id_target);
}
- QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int)tasks.size(), front);
+ QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
if (front) {
queue_tasks.push_front(std::move(task));
} else {
@@ -1505,10 +1557,14 @@ struct server_queue {
}
// Register function to process a new task
- void on_new_task(std::function<void(server_task)> callback) { callback_new_task = std::move(callback); }
+ void on_new_task(std::function<void(server_task)> callback) {
+ callback_new_task = std::move(callback);
+ }
// Register the function to be called when all slots data is ready to be processed
- void on_update_slots(std::function<void(void)> callback) { callback_update_slots = std::move(callback); }
+ void on_update_slots(std::function<void(void)> callback) {
+ callback_update_slots = std::move(callback);
+ }
// Call when the state of one slot is changed, it will move one task from deferred to main queue
void pop_deferred_task() {
@@ -1571,19 +1627,26 @@ struct server_queue {
return;
}
if (queue_tasks.empty()) {
- condition_tasks.wait(lock, [&] { return (!queue_tasks.empty() || !running); });
+ condition_tasks.wait(lock, [&]{
+ return (!queue_tasks.empty() || !running);
+ });
}
}
}
}
- private:
+private:
void cleanup_pending_task(int id_target) {
// no need to lock because this is called exclusively by post()
- auto rm_func = [id_target](const server_task &task) { return task.id_target == id_target; };
- queue_tasks.erase(std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), queue_tasks.end());
- queue_tasks_deferred.erase(std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
- queue_tasks_deferred.end());
+ auto rm_func = [id_target](const server_task & task) {
+ return task.id_target == id_target;
+ };
+ queue_tasks.erase(
+ std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
+ queue_tasks.end());
+ queue_tasks_deferred.erase(
+ std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
+ queue_tasks_deferred.end());
}
};
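The cancellation cleanup above uses the standard erase–remove idiom to drop every queued task whose id_target matches the cancelled id. A minimal standalone sketch of the same idiom, using a plain std::deque of integer ids instead of the server_task types (names here are illustration only):

#include <algorithm>
#include <deque>
#include <iostream>

int main() {
    std::deque<int> queue = {1, 2, 3, 2, 4};   // pretend these are task id_target values
    const int id_target = 2;                   // id being cancelled

    auto rm_func = [id_target](int id) { return id == id_target; };

    // std::remove_if shifts the kept elements to the front and returns the new
    // logical end; erase() then trims the leftover tail in one pass.
    queue.erase(std::remove_if(queue.begin(), queue.end(), rm_func), queue.end());

    for (int id : queue) { std::cout << id << ' '; }   // prints: 1 3 4
    std::cout << '\n';
    return 0;
}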
@@ -1599,51 +1662,51 @@ struct server_response {
// add the id_task to the list of tasks waiting for response
void add_waiting_task_id(int id_task) {
- SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task,
- (int)waiting_task_ids.size());
+ SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
std::unique_lock lock(mutex_results);
waiting_task_ids.insert(id_task);
}
- void add_waiting_tasks(const std::vector<server_task> &tasks) {
+ void add_waiting_tasks(const std::vector<server_task> & tasks) {
std::unique_lock lock(mutex_results);
- for (const auto &task : tasks) {
- SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id,
- (int)waiting_task_ids.size());
+ for (const auto & task : tasks) {
+ SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
waiting_task_ids.insert(task.id);
}
}
// when the request is finished, we can remove task associated with it
void remove_waiting_task_id(int id_task) {
- SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task,
- (int)waiting_task_ids.size());
+ SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
std::unique_lock lock(mutex_results);
waiting_task_ids.erase(id_task);
// make sure to clean up all pending results
- queue_results.erase(std::remove_if(queue_results.begin(), queue_results.end(),
- [id_task](const server_task_result_ptr &res) { return res->id == id_task; }),
- queue_results.end());
+ queue_results.erase(
+ std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
+ return res->id == id_task;
+ }),
+ queue_results.end());
}
- void remove_waiting_task_ids(const std::unordered_set<int> &id_tasks) {
+ void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
std::unique_lock lock(mutex_results);
- for (const auto &id_task : id_tasks) {
- SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task,
- (int)waiting_task_ids.size());
+ for (const auto & id_task : id_tasks) {
+ SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
waiting_task_ids.erase(id_task);
}
}
// This function blocks the thread until there is a response for one of the id_tasks
- server_task_result_ptr recv(const std::unordered_set<int> &id_tasks) {
+ server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
while (true) {
std::unique_lock lock(mutex_results);
- condition_results.wait(lock, [&] { return !queue_results.empty(); });
+ condition_results.wait(lock, [&]{
+ return !queue_results.empty();
+ });
for (size_t i = 0; i < queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
@@ -1659,11 +1722,11 @@ struct server_response {
// same as recv(), but have timeout in seconds
// if timeout is reached, nullptr is returned
- server_task_result_ptr recv_with_timeout(const std::unordered_set<int> &id_tasks, int timeout) {
+ server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
while (true) {
std::unique_lock lock(mutex_results);
- for (int i = 0; i < (int)queue_results.size(); i++) {
+ for (int i = 0; i < (int) queue_results.size(); i++) {
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
server_task_result_ptr res = std::move(queue_results[i]);
queue_results.erase(queue_results.begin() + i);
@@ -1687,11 +1750,11 @@ struct server_response {
}
// Send a new result to a waiting id_task
- void send(server_task_result_ptr &&result) {
+ void send(server_task_result_ptr && result) {
SRV_DBG("sending result for task id = %d\n", result->id);
std::unique_lock lock(mutex_results);
- for (const auto &id_task : waiting_task_ids) {
+ for (const auto & id_task : waiting_task_ids) {
if (result->id == id_task) {
SRV_DBG("task id = %d pushed to result queue\n", result->id);
@@ -1710,20 +1773,20 @@ struct server_context {
common_init_result llama_init;
common_init_result llama_init_dft;
- llama_model *model = nullptr;
- llama_context *ctx = nullptr;
+ llama_model * model = nullptr;
+ llama_context * ctx = nullptr;
- const llama_vocab *vocab = nullptr;
+ const llama_vocab * vocab = nullptr;
- llama_model *model_dft = nullptr;
+ llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
llama_batch batch = {};
bool clean_kv_cache = true;
- bool add_bos_token = true;
- bool has_eos_token = false;
+ bool add_bos_token = true;
+ bool has_eos_token = false;
int32_t n_ctx; // total context for all clients / slots
@@ -1731,7 +1794,7 @@ struct server_context {
std::vector<server_slot> slots;
json default_generation_settings_for_props;
- server_queue queue_tasks;
+ server_queue queue_tasks;
server_response queue_results;
server_metrics metrics;
@@ -1743,7 +1806,7 @@ struct server_context {
~server_context() {
// Clear any sampling context
- for (server_slot &slot : slots) {
+ for (server_slot & slot : slots) {
common_sampler_free(slot.smpl);
slot.smpl = nullptr;
@@ -1759,7 +1822,7 @@ struct server_context {
llama_batch_free(batch);
}
- bool load_model(const common_params &params) {
+ bool load_model(const common_params & params) {
SRV_INF("loading model '%s'\n", params.model.c_str());
params_base = params;
@@ -1767,7 +1830,7 @@ struct server_context {
llama_init = common_init_from_params(params_base);
model = llama_init.model.get();
- ctx = llama_init.context.get();
+ ctx = llama_init.context.get();
if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
@@ -1786,15 +1849,18 @@ struct server_context {
auto params_dft = params_base;
- params_dft.devices = params_base.speculative.devices;
- params_dft.hf_file = params_base.speculative.hf_file;
- params_dft.hf_repo = params_base.speculative.hf_repo;
- params_dft.model = params_base.speculative.model;
- params_dft.model_url = params_base.speculative.model_url;
- params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel
- : params_base.speculative.n_ctx;
+ params_dft.devices = params_base.speculative.devices;
+ params_dft.hf_file = params_base.speculative.hf_file;
+ params_dft.hf_repo = params_base.speculative.hf_repo;
+ params_dft.model = params_base.speculative.model;
+ params_dft.model_url = params_base.speculative.model_url;
+ params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
- params_dft.n_parallel = 1;
+ params_dft.n_parallel = 1;
+
+ // force F16 KV cache for the draft model for extra performance
+ params_dft.cache_type_k = GGML_TYPE_F16;
+ params_dft.cache_type_v = GGML_TYPE_F16;
llama_init_dft = common_init_from_params(params_dft);
@@ -1806,8 +1872,7 @@ struct server_context {
}
if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) {
- SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n",
- params_base.speculative.model.c_str(), params_base.model.c_str());
+ SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
return false;
}
@@ -1817,10 +1882,6 @@ struct server_context {
cparams_dft = common_context_params_to_llama(params_dft);
cparams_dft.n_batch = n_ctx_dft;
- // force F16 KV cache for the draft model for extra performance
- cparams_dft.type_k = GGML_TYPE_F16;
- cparams_dft.type_v = GGML_TYPE_F16;
-
// the context is not needed - we will create one for each slot
llama_init_dft.context.reset();
}
@@ -1828,10 +1889,9 @@ struct server_context {
chat_templates = common_chat_templates_init(model, params_base.chat_template);
try {
common_chat_format_example(chat_templates.get(), params.use_jinja);
- } catch (const std::exception &e) {
- SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. "
- "This may cause the model to output suboptimal responses\n",
- __func__);
+ } catch (const std::exception & e) {
+ SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what());
+ SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_init(model, "chatml");
}
@@ -1871,7 +1931,9 @@ struct server_context {
slot.params.sampling = params_base.sampling;
- slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); };
+ slot.callback_on_release = [this](int) {
+ queue_tasks.pop_deferred_task();
+ };
slot.reset();
@@ -1881,8 +1943,7 @@ struct server_context {
default_generation_settings_for_props = slots[0].to_json();
// the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
- // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not
- // used)
+ // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
{
const int32_t n_batch = llama_n_batch(ctx);
@@ -1893,8 +1954,8 @@ struct server_context {
metrics.init();
}
- server_slot *get_slot_by_id(int id) {
- for (server_slot &slot : slots) {
+ server_slot * get_slot_by_id(int id) {
+ for (server_slot & slot : slots) {
if (slot.id == id) {
return &slot;
}
@@ -1903,15 +1964,15 @@ struct server_context {
return nullptr;
}
- server_slot *get_available_slot(const server_task &task) {
- server_slot *ret = nullptr;
+ server_slot * get_available_slot(const server_task & task) {
+ server_slot * ret = nullptr;
// find the slot that has at least n% prompt similarity
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
int lcs_len = 0;
float similarity = 0;
- for (server_slot &slot : slots) {
+ for (server_slot & slot : slots) {
// skip the slot if it is not available
if (slot.is_processing()) {
continue;
@@ -1944,7 +2005,7 @@ struct server_context {
// find the slot that has been least recently used
if (ret == nullptr) {
int64_t t_last = ggml_time_us();
- for (server_slot &slot : slots) {
+ for (server_slot & slot : slots) {
// skip the slot if it is not available
if (slot.is_processing()) {
continue;
@@ -1965,12 +2026,24 @@ struct server_context {
return ret;
}
- bool launch_slot_with_task(server_slot &slot, const server_task &task) {
+ bool can_be_detokenized(const struct llama_context * ctx, const std::vector<llama_token> & tokens) {
+ const llama_model * model = llama_get_model(ctx);
+ const llama_vocab * vocab = llama_model_get_vocab(model);
+ const int32_t n_vocab = llama_vocab_n_tokens(vocab);
+ for (const auto & token : tokens) {
+ if (token < 0 || token >= n_vocab) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ bool launch_slot_with_task(server_slot & slot, const server_task & task) {
slot.reset();
- slot.id_task = task.id;
- slot.index = task.index;
- slot.task_type = task.type;
- slot.params = std::move(task.params);
+ slot.id_task = task.id;
+ slot.index = task.index;
+ slot.task_type = task.type;
+ slot.params = std::move(task.params);
slot.prompt_tokens = std::move(task.prompt_tokens);
if (!are_lora_equal(task.params.lora, slot.lora)) {
@@ -1979,12 +2052,16 @@ struct server_context {
slot.lora = task.params.lora;
}
+ bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
+ if (!can_detokenize) {
+ send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
+ return false;
+ }
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
// Might be better to reject the request with a 400 ?
- SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict,
- slot.n_predict);
+ SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict);
slot.params.n_predict = slot.n_predict;
}
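The new can_be_detokenized() guard above simply checks that every prompt token id lies inside the model's vocabulary range before the slot is launched, so malformed requests are rejected early instead of crashing detokenization. A standalone sketch of that bounds check, using plain int32_t ids rather than the llama.cpp types (tokens_in_range is a hypothetical name used only for illustration):

#include <cstdint>
#include <vector>

// Returns false if any token id falls outside [0, n_vocab); n_vocab is assumed
// to come from the loaded model's vocabulary, as in the hunk above.
static bool tokens_in_range(const std::vector<int32_t> & tokens, int32_t n_vocab) {
    for (const int32_t token : tokens) {
        if (token < 0 || token >= n_vocab) {
            return false;
        }
    }
    return true;
}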
@@ -2022,11 +2099,11 @@ struct server_context {
SRV_DBG("%s", "clearing KV cache\n");
// clear the entire KV cache
- llama_kv_cache_clear(ctx);
+ llama_kv_self_clear(ctx);
clean_kv_cache = false;
}
- bool process_token(completion_token_output &result, server_slot &slot) {
+ bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = result.text_to_send;
slot.sampled = result.tok;
@@ -2049,7 +2126,9 @@ struct server_context {
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
if (stop_pos != std::string::npos) {
- slot.generated_text.erase(slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end());
+ slot.generated_text.erase(
+ slot.generated_text.begin() + pos + stop_pos,
+ slot.generated_text.end());
pos = std::min(slot.n_sent_text, slot.generated_text.size());
} else if (slot.has_next_token) {
stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
@@ -2078,23 +2157,13 @@ struct server_context {
// check the limits
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
- slot.stop = STOP_TYPE_LIMIT;
+ slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
}
if (slot.has_new_line) {
- // if we have already seen a new line, we stop after a certain time limit
- if (slot.params.t_max_predict_ms > 0 &&
- (ggml_time_us() - slot.t_start_generation > 1000.0f * slot.params.t_max_predict_ms)) {
- slot.stop = STOP_TYPE_LIMIT;
- slot.has_next_token = false;
-
- SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded,
- (int)slot.params.t_max_predict_ms);
- }
-
// require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
if (slot.params.n_indent > 0) {
// check the current indentation
@@ -2103,21 +2172,19 @@ struct server_context {
size_t pos = slot.last_nl_pos;
int n_indent = 0;
- while (pos < slot.generated_text.size() &&
- (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
+ while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
n_indent++;
pos++;
}
if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
- slot.stop = STOP_TYPE_LIMIT;
+ slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
// cut the last line
slot.generated_text.erase(pos, std::string::npos);
- SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded,
- n_indent);
+ SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
}
}
@@ -2135,22 +2202,28 @@ struct server_context {
// check if there is a new line in the generated text
if (result.text_to_send.find('\n') != std::string::npos) {
slot.has_new_line = true;
+
+ // if we have seen a new line, we stop after a certain time limit, but only upon another new line
+ if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+ slot.stop = STOP_TYPE_LIMIT;
+ slot.has_next_token = false;
+
+ SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+ }
}
// if context shift is disabled, we stop when it reaches the context limit
if (slot.n_past >= slot.n_ctx) {
- slot.truncated = true;
- slot.stop = STOP_TYPE_LIMIT;
+ slot.truncated = true;
+ slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false;
- SLT_DBG(slot,
- "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = "
- "%d, n_ctx = %d\n",
+ SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
}
if (llama_vocab_is_eog(vocab, result.tok)) {
- slot.stop = STOP_TYPE_EOS;
+ slot.stop = STOP_TYPE_EOS;
slot.has_next_token = false;
SLT_DBG(slot, "%s", "stopped by EOS\n");
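Among the stopping conditions touched in this hunk, the t_max_predict_ms check is moved so that it only fires when another newline is produced: generation past the time budget is therefore cut at a line boundary rather than mid-line. A rough standalone sketch of that condition, with microsecond timestamps standing in for ggml_time_us() (hit_time_limit is a hypothetical helper, not part of the server code):

#include <cstdint>

// The budget is in milliseconds, timestamps in microseconds. The caller is
// expected to invoke this only when the current token contains a newline,
// mirroring where the check sits in the hunk above.
static bool hit_time_limit(int64_t t_now_us, int64_t t_start_generation_us, double t_max_predict_ms) {
    if (t_max_predict_ms <= 0) {
        return false; // no time budget configured
    }
    return (t_now_us - t_start_generation_us) > 1000.0 * t_max_predict_ms;
}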
@@ -2159,8 +2232,8 @@ struct server_context {
const auto n_ctx_train = llama_model_n_ctx_train(model);
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
- slot.truncated = true;
- slot.stop = STOP_TYPE_LIMIT;
+ slot.truncated = true;
+ slot.stop = STOP_TYPE_LIMIT;
slot.has_next_token = false; // stop prediction
SLT_WRN(slot,
@@ -2169,18 +2242,16 @@ struct server_context {
slot.params.n_predict, n_ctx_train);
}
- SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining,
- result.tok, token_str.c_str());
+ SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
return slot.has_next_token; // continue
}
- void populate_token_probs(const server_slot &slot, completion_token_output &result, bool post_sampling,
- bool special, int idx) {
+ void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
size_t n_probs = slot.params.sampling.n_probs;
size_t n_vocab = llama_vocab_n_tokens(vocab);
if (post_sampling) {
- const auto *cur_p = common_sampler_get_candidates(slot.smpl);
+ const auto * cur_p = common_sampler_get_candidates(slot.smpl);
const size_t max_probs = cur_p->size;
// set probability for sampled token
@@ -2194,8 +2265,11 @@ struct server_context {
// set probability for top n_probs tokens
result.probs.reserve(max_probs);
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
- result.probs.push_back(
- {cur_p->data[i].id, common_token_to_piece(ctx, cur_p->data[i].id, special), cur_p->data[i].p});
+ result.probs.push_back({
+ cur_p->data[i].id,
+ common_token_to_piece(ctx, cur_p->data[i].id, special),
+ cur_p->data[i].p
+ });
}
} else {
// TODO: optimize this with min-p optimization
@@ -2213,45 +2287,49 @@ struct server_context {
// set probability for top n_probs tokens
result.probs.reserve(n_probs);
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
- result.probs.push_back({cur[i].id, common_token_to_piece(ctx, cur[i].id, special), cur[i].p});
+ result.probs.push_back({
+ cur[i].id,
+ common_token_to_piece(ctx, cur[i].id, special),
+ cur[i].p
+ });
}
}
}
- void send_error(const server_task &task, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) {
+ void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
send_error(task.id, error, type);
}
- void send_error(const server_slot &slot, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) {
+ void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
send_error(slot.id_task, error, type);
}
- void send_error(const int id_task, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER) {
+ void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
auto res = std::make_unique<server_task_result_error>();
- res->id = id_task;
+ res->id = id_task;
res->err_type = type;
- res->err_msg = error;
+ res->err_msg = error;
queue_results.send(std::move(res));
}
- void send_partial_response(server_slot &slot, const completion_token_output &tkn) {
+ void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
auto res = std::make_unique<server_task_result_cmpl_partial>();
- res->id = slot.id_task;
- res->index = slot.index;
+ res->id = slot.id_task;
+ res->index = slot.index;
res->content = tkn.text_to_send;
- res->tokens = {tkn.tok};
+ res->tokens = { tkn.tok };
- res->n_decoded = slot.n_decoded;
- res->n_prompt_tokens = slot.n_prompt_tokens;
+ res->n_decoded = slot.n_decoded;
+ res->n_prompt_tokens = slot.n_prompt_tokens;
res->post_sampling_probs = slot.params.post_sampling_probs;
- res->verbose = slot.params.verbose;
- res->oaicompat = slot.params.oaicompat;
- res->oaicompat_model = slot.params.oaicompat_model;
+ res->verbose = slot.params.verbose;
+ res->oaicompat = slot.params.oaicompat;
+ res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
// populate res.probs_output
@@ -2267,32 +2345,32 @@ struct server_context {
queue_results.send(std::move(res));
}
- void send_final_response(server_slot &slot) {
+ void send_final_response(server_slot & slot) {
auto res = std::make_unique<server_task_result_cmpl_final>();
- res->id = slot.id_task;
- res->id_slot = slot.id;
-
- res->index = slot.index;
- res->content = std::move(slot.generated_text);
- res->tokens = std::move(slot.generated_tokens);
- res->timings = slot.get_timings();
- res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
+ res->id = slot.id_task;
+ res->id_slot = slot.id;
+
+ res->index = slot.index;
+ res->content = std::move(slot.generated_text);
+ res->tokens = std::move(slot.generated_tokens);
+ res->timings = slot.get_timings();
+ res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
res->response_fields = std::move(slot.params.response_fields);
- res->truncated = slot.truncated;
- res->n_decoded = slot.n_decoded;
- res->n_prompt_tokens = slot.n_prompt_tokens;
- res->n_tokens_cached = slot.n_past;
- res->has_new_line = slot.has_new_line;
- res->stopping_word = slot.stopping_word;
- res->stop = slot.stop;
+ res->truncated = slot.truncated;
+ res->n_decoded = slot.n_decoded;
+ res->n_prompt_tokens = slot.n_prompt_tokens;
+ res->n_tokens_cached = slot.n_past;
+ res->has_new_line = slot.has_new_line;
+ res->stopping_word = slot.stopping_word;
+ res->stop = slot.stop;
res->post_sampling_probs = slot.params.post_sampling_probs;
- res->verbose = slot.params.verbose;
- res->stream = slot.params.stream;
- res->oaicompat = slot.params.oaicompat;
- res->oaicompat_model = slot.params.oaicompat_model;
- res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+ res->verbose = slot.params.verbose;
+ res->stream = slot.params.stream;
+ res->oaicompat = slot.params.oaicompat;
+ res->oaicompat_model = slot.params.oaicompat_model;
+ res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
@@ -2301,10 +2379,12 @@ struct server_context {
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
res->probs_output = std::vector<completion_token_output>(
- slot.generated_token_probs.begin(), slot.generated_token_probs.end() - safe_offset);
+ slot.generated_token_probs.begin(),
+ slot.generated_token_probs.end() - safe_offset);
} else {
- res->probs_output = std::vector<completion_token_output>(slot.generated_token_probs.begin(),
- slot.generated_token_probs.end());
+ res->probs_output = std::vector<completion_token_output>(
+ slot.generated_token_probs.begin(),
+ slot.generated_token_probs.end());
}
}
@@ -2313,11 +2393,11 @@ struct server_context {
queue_results.send(std::move(res));
}
- void send_embedding(const server_slot &slot, const llama_batch &batch) {
+ void send_embedding(const server_slot & slot, const llama_batch & batch) {
auto res = std::make_unique<server_task_result_embd>();
- res->id = slot.id_task;
- res->index = slot.index;
- res->n_tokens = slot.n_prompt_tokens;
+ res->id = slot.id_task;
+ res->index = slot.index;
+ res->n_tokens = slot.n_prompt_tokens;
res->oaicompat = slot.params.oaicompat;
const int n_embd = llama_model_n_embd(model);
@@ -2329,14 +2409,13 @@ struct server_context {
continue;
}
- const float *embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+ const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
}
if (embd == NULL) {
- SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i],
- batch.seq_id[i][0]);
+ SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
continue;
@@ -2348,7 +2427,7 @@ struct server_context {
common_embd_normalize(embd, embd_res.data(), n_embd, 2);
res->embedding.push_back(embd_res);
} else {
- res->embedding.push_back({embd, embd + n_embd});
+ res->embedding.push_back({ embd, embd + n_embd });
}
}
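send_embedding() above optionally L2-normalizes each sequence embedding (the p = 2 mode of common_embd_normalize) before returning it. A standalone sketch of that normalization step on a plain float vector (l2_normalize is an illustrative name, not the library helper):

#include <cmath>
#include <vector>

// Divide every component by the vector's Euclidean norm so the result has unit
// length; the all-zero vector is left untouched as a safeguard.
static void l2_normalize(std::vector<float> & embd) {
    double sum_sq = 0.0;
    for (const float v : embd) {
        sum_sq += (double) v * v;
    }
    const double norm = std::sqrt(sum_sq);
    if (norm == 0.0) {
        return;
    }
    for (float & v : embd) {
        v = (float) (v / norm);
    }
}

Unit-length embeddings make downstream cosine-similarity comparisons reduce to a plain dot product, which is presumably why the normalized path is the default for embedding responses.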
@@ -2357,9 +2436,9 @@ struct server_context {
queue_results.send(std::move(res));
}
- void send_rerank(const server_slot &slot, const llama_batch &batch) {
+ void send_rerank(const server_slot & slot, const llama_batch & batch) {
auto res = std::make_unique<server_task_result_rerank>();
- res->id = slot.id_task;
+ res->id = slot.id_task;
res->index = slot.index;
res->n_tokens = slot.n_prompt_tokens;
@@ -2368,14 +2447,13 @@ struct server_context {
continue;
}
- const float *embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+ const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
}
if (embd == NULL) {
- SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i],
- batch.seq_id[i][0]);
+ SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
res->score = -1e6;
continue;
@@ -2393,10 +2471,10 @@ struct server_context {
// Functions to create new task(s) and receive result(s)
//
- void cancel_tasks(const std::unordered_set<int> &id_tasks) {
+ void cancel_tasks(const std::unordered_set<int> & id_tasks) {
std::vector<server_task> cancel_tasks;
cancel_tasks.reserve(id_tasks.size());
- for (const auto &id_task : id_tasks) {
+ for (const auto & id_task : id_tasks) {
SRV_WRN("cancel task, id_task = %d\n", id_task);
server_task task(SERVER_TASK_TYPE_CANCEL);
@@ -2409,10 +2487,11 @@ struct server_context {
}
// receive the results from task(s)
- void receive_multi_results(const std::unordered_set<int> &id_tasks,
- const std::function<void(std::vector<server_task_result_ptr> &)> &result_handler,
- const std::function<void(json)> &error_handler,
- const std::function<bool()> &is_connection_closed) {
+ void receive_multi_results(
+ const std::unordered_set<int> & id_tasks,
+ const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
+ const std::function<void(json)> & error_handler,
+ const std::function<bool()> & is_connection_closed) {
std::vector<server_task_result_ptr> results(id_tasks.size());
for (int i = 0; i < (int)id_tasks.size(); i++) {
server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
@@ -2433,9 +2512,11 @@ struct server_context {
return;
}
- GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr ||
- dynamic_cast<server_task_result_embd*>(result.get()) != nullptr ||
- dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr);
+ GGML_ASSERT(
+ dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+ || dynamic_cast<server_task_result_embd*>(result.get()) != nullptr
+ || dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr
+ );
const size_t idx = result->get_index();
GGML_ASSERT(idx < results.size() && "index out of range");
results[idx] = std::move(result);
@@ -2444,10 +2525,11 @@ struct server_context {
}
// receive the results from task(s), in stream mode
- void receive_cmpl_results_stream(const std::unordered_set<int> &id_tasks,
- const std::function<bool(server_task_result_ptr &)> &result_handler,
- const std::function<void(json)> &error_handler,
- const std::function<bool()> &is_connection_closed) {
+ void receive_cmpl_results_stream(
+ const std::unordered_set<int> & id_tasks,
+ const std::function<bool(server_task_result_ptr&)> & result_handler,
+ const std::function<void(json)> & error_handler,
+ const std::function<bool()> & is_connection_closed) {
size_t n_finished = 0;
while (true) {
server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
@@ -2467,8 +2549,10 @@ struct server_context {
return;
}
- GGML_ASSERT(dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr ||
- dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr);
+ GGML_ASSERT(
+ dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+ || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+ );
if (!result_handler(result)) {
cancel_tasks(id_tasks);
break;
@@ -2488,203 +2572,208 @@ struct server_context {
void process_single_task(server_task task) {
switch (task.type) {
- case SERVER_TASK_TYPE_COMPLETION:
- case SERVER_TASK_TYPE_INFILL:
- case SERVER_TASK_TYPE_EMBEDDING:
- case SERVER_TASK_TYPE_RERANK: {
- const int id_slot = task.id_selected_slot;
-
- server_slot *slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
-
- if (slot == nullptr) {
- // if no slot is available, we defer this task for processing later
- SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
- queue_tasks.defer(task);
- break;
- }
- if (slot->is_processing()) {
- // if requested slot is unavailable, we defer this task for processing later
- SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
- queue_tasks.defer(task);
- break;
- }
-
- if (!launch_slot_with_task(*slot, task)) {
- SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
- break;
- }
- } break;
- case SERVER_TASK_TYPE_CANCEL: {
- // release slot linked with the task id
- for (auto &slot : slots) {
- if (slot.id_task == task.id_target) {
- slot.release();
- break;
- }
- }
- } break;
- case SERVER_TASK_TYPE_NEXT_RESPONSE: {
- // do nothing
- } break;
- case SERVER_TASK_TYPE_METRICS: {
- json slots_data = json::array();
+ case SERVER_TASK_TYPE_COMPLETION:
+ case SERVER_TASK_TYPE_INFILL:
+ case SERVER_TASK_TYPE_EMBEDDING:
+ case SERVER_TASK_TYPE_RERANK:
+ {
+ const int id_slot = task.id_selected_slot;
- int n_idle_slots = 0;
- int n_processing_slots = 0;
+ server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
- for (server_slot &slot : slots) {
- json slot_data = slot.to_json();
+ if (slot == nullptr) {
+ // if no slot is available, we defer this task for processing later
+ SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
+ queue_tasks.defer(task);
+ break;
+ }
+ if (slot->is_processing()) {
+ // if requested slot is unavailable, we defer this task for processing later
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
+ queue_tasks.defer(task);
+ break;
+ }
- if (slot.is_processing()) {
- n_processing_slots++;
- } else {
- n_idle_slots++;
- }
+ if (!launch_slot_with_task(*slot, task)) {
+ SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
+ break;
+ }
+ } break;
+ case SERVER_TASK_TYPE_CANCEL:
+ {
+ // release slot linked with the task id
+ for (auto & slot : slots) {
+ if (slot.id_task == task.id_target) {
+ slot.release();
+ break;
+ }
+ }
+ } break;
+ case SERVER_TASK_TYPE_NEXT_RESPONSE:
+ {
+ // do nothing
+ } break;
+ case SERVER_TASK_TYPE_METRICS:
+ {
+ json slots_data = json::array();
- slots_data.push_back(slot_data);
- }
- SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
+ int n_idle_slots = 0;
+ int n_processing_slots = 0;
- auto res = std::make_unique<server_task_result_metrics>();
- res->id = task.id;
- res->slots_data = std::move(slots_data);
- res->n_idle_slots = n_idle_slots;
- res->n_processing_slots = n_processing_slots;
- res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
- res->t_start = metrics.t_start;
+ for (server_slot & slot : slots) {
+ json slot_data = slot.to_json();
- res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
- res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);
+ if (slot.is_processing()) {
+ n_processing_slots++;
+ } else {
+ n_idle_slots++;
+ }
- res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
- res->t_prompt_processing_total = metrics.t_prompt_processing_total;
- res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
- res->t_tokens_generation_total = metrics.t_tokens_generation_total;
+ slots_data.push_back(slot_data);
+ }
+ SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
+
+ auto res = std::make_unique<server_task_result_metrics>();
+ res->id = task.id;
+ res->slots_data = std::move(slots_data);
+ res->n_idle_slots = n_idle_slots;
+ res->n_processing_slots = n_processing_slots;
+ res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
+ res->t_start = metrics.t_start;
+
+ res->kv_cache_tokens_count = llama_kv_self_n_tokens(ctx);
+ res->kv_cache_used_cells = llama_kv_self_used_cells(ctx);
+
+ res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
+ res->t_prompt_processing_total = metrics.t_prompt_processing_total;
+ res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
+ res->t_tokens_generation_total = metrics.t_tokens_generation_total;
+
+ res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
+ res->t_prompt_processing = metrics.t_prompt_processing;
+ res->n_tokens_predicted = metrics.n_tokens_predicted;
+ res->t_tokens_generation = metrics.t_tokens_generation;
+
+ res->n_decode_total = metrics.n_decode_total;
+ res->n_busy_slots_total = metrics.n_busy_slots_total;
+
+ if (task.metrics_reset_bucket) {
+ metrics.reset_bucket();
+ }
+ queue_results.send(std::move(res));
+ } break;
+ case SERVER_TASK_TYPE_SLOT_SAVE:
+ {
+ int id_slot = task.slot_action.slot_id;
+ server_slot * slot = get_slot_by_id(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ if (slot->is_processing()) {
+ // if requested slot is unavailable, we defer this task for processing later
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
+ queue_tasks.defer(task);
+ break;
+ }
- res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
- res->t_prompt_processing = metrics.t_prompt_processing;
- res->n_tokens_predicted = metrics.n_tokens_predicted;
- res->t_tokens_generation = metrics.t_tokens_generation;
+ const size_t token_count = slot->cache_tokens.size();
+ const int64_t t_start = ggml_time_us();
- res->n_decode_total = metrics.n_decode_total;
- res->n_busy_slots_total = metrics.n_busy_slots_total;
+ std::string filename = task.slot_action.filename;
+ std::string filepath = task.slot_action.filepath;
- if (task.metrics_reset_bucket) {
- metrics.reset_bucket();
- }
- queue_results.send(std::move(res));
- } break;
- case SERVER_TASK_TYPE_SLOT_SAVE: {
- int id_slot = task.slot_action.slot_id;
- server_slot *slot = get_slot_by_id(id_slot);
- if (slot == nullptr) {
- send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
- break;
- }
- if (slot->is_processing()) {
- // if requested slot is unavailable, we defer this task for processing later
- SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
- queue_tasks.defer(task);
- break;
- }
+ const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
- const size_t token_count = slot->cache_tokens.size();
- const int64_t t_start = ggml_time_us();
-
- std::string filename = task.slot_action.filename;
- std::string filepath = task.slot_action.filepath;
-
- const size_t nwrite =
- llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
-
- const int64_t t_end = ggml_time_us();
- const double t_save_ms = (t_end - t_start) / 1000.0;
-
- auto res = std::make_unique<server_task_result_slot_save_load>();
- res->id = task.id;
- res->id_slot = id_slot;
- res->filename = filename;
- res->is_save = true;
- res->n_tokens = token_count;
- res->n_bytes = nwrite;
- res->t_ms = t_save_ms;
- queue_results.send(std::move(res));
- } break;
- case SERVER_TASK_TYPE_SLOT_RESTORE: {
- int id_slot = task.slot_action.slot_id;
- server_slot *slot = get_slot_by_id(id_slot);
- if (slot == nullptr) {
- send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
- break;
- }
- if (slot->is_processing()) {
- // if requested slot is unavailable, we defer this task for processing later
- SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
- queue_tasks.defer(task);
- break;
- }
+ const int64_t t_end = ggml_time_us();
+ const double t_save_ms = (t_end - t_start) / 1000.0;
- const int64_t t_start = ggml_time_us();
+ auto res = std::make_unique<server_task_result_slot_save_load>();
+ res->id = task.id;
+ res->id_slot = id_slot;
+ res->filename = filename;
+ res->is_save = true;
+ res->n_tokens = token_count;
+ res->n_bytes = nwrite;
+ res->t_ms = t_save_ms;
+ queue_results.send(std::move(res));
+ } break;
+ case SERVER_TASK_TYPE_SLOT_RESTORE:
+ {
+ int id_slot = task.slot_action.slot_id;
+ server_slot * slot = get_slot_by_id(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ if (slot->is_processing()) {
+ // if requested slot is unavailable, we defer this task for processing later
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
+ queue_tasks.defer(task);
+ break;
+ }
- std::string filename = task.slot_action.filename;
- std::string filepath = task.slot_action.filepath;
+ const int64_t t_start = ggml_time_us();
- slot->cache_tokens.resize(slot->n_ctx);
- size_t token_count = 0;
- size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(),
- slot->cache_tokens.size(), &token_count);
- if (nread == 0) {
- slot->cache_tokens.resize(0);
- send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file",
- ERROR_TYPE_INVALID_REQUEST);
- break;
- }
- slot->cache_tokens.resize(token_count);
-
- const int64_t t_end = ggml_time_us();
- const double t_restore_ms = (t_end - t_start) / 1000.0;
-
- auto res = std::make_unique<server_task_result_slot_save_load>();
- res->id = task.id;
- res->id_slot = id_slot;
- res->filename = filename;
- res->is_save = false;
- res->n_tokens = token_count;
- res->n_bytes = nread;
- res->t_ms = t_restore_ms;
- queue_results.send(std::move(res));
- } break;
- case SERVER_TASK_TYPE_SLOT_ERASE: {
- int id_slot = task.slot_action.slot_id;
- server_slot *slot = get_slot_by_id(id_slot);
- if (slot == nullptr) {
- send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
- break;
- }
- if (slot->is_processing()) {
- // if requested slot is unavailable, we defer this task for processing later
- SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
- queue_tasks.defer(task);
- break;
- }
+ std::string filename = task.slot_action.filename;
+ std::string filepath = task.slot_action.filepath;
- // Erase token cache
- const size_t n_erased = slot->cache_tokens.size();
- llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
- slot->cache_tokens.clear();
+ slot->cache_tokens.resize(slot->n_ctx);
+ size_t token_count = 0;
+ size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
+ if (nread == 0) {
+ slot->cache_tokens.resize(0);
+ send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ slot->cache_tokens.resize(token_count);
+
+ const int64_t t_end = ggml_time_us();
+ const double t_restore_ms = (t_end - t_start) / 1000.0;
+
+ auto res = std::make_unique<server_task_result_slot_save_load>();
+ res->id = task.id;
+ res->id_slot = id_slot;
+ res->filename = filename;
+ res->is_save = false;
+ res->n_tokens = token_count;
+ res->n_bytes = nread;
+ res->t_ms = t_restore_ms;
+ queue_results.send(std::move(res));
+ } break;
+ case SERVER_TASK_TYPE_SLOT_ERASE:
+ {
+ int id_slot = task.slot_action.slot_id;
+ server_slot * slot = get_slot_by_id(id_slot);
+ if (slot == nullptr) {
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+ if (slot->is_processing()) {
+ // if requested slot is unavailable, we defer this task for processing later
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
+ queue_tasks.defer(task);
+ break;
+ }
- auto res = std::make_unique<server_task_result_slot_erase>();
- res->id = task.id;
- res->id_slot = id_slot;
- res->n_erased = n_erased;
- queue_results.send(std::move(res));
- } break;
- case SERVER_TASK_TYPE_SET_LORA: {
- params_base.lora_adapters = std::move(task.set_lora);
- auto res = std::make_unique<server_task_result_apply_lora>();
- res->id = task.id;
- queue_results.send(std::move(res));
- } break;
+ // Erase token cache
+ const size_t n_erased = slot->cache_tokens.size();
+ llama_kv_self_seq_rm(ctx, slot->id, -1, -1);
+ slot->cache_tokens.clear();
+
+ auto res = std::make_unique<server_task_result_slot_erase>();
+ res->id = task.id;
+ res->id_slot = id_slot;
+ res->n_erased = n_erased;
+ queue_results.send(std::move(res));
+ } break;
+ case SERVER_TASK_TYPE_SET_LORA:
+ {
+ params_base.lora_adapters = std::move(task.set_lora);
+ auto res = std::make_unique<server_task_result_apply_lora>();
+ res->id = task.id;
+ queue_results.send(std::move(res));
+ } break;
}
}
@@ -2693,7 +2782,7 @@ struct server_context {
{
bool all_idle = true;
- for (auto &slot : slots) {
+ for (auto & slot : slots) {
if (slot.is_processing()) {
all_idle = false;
break;
@@ -2720,7 +2809,7 @@ struct server_context {
// apply context-shift if needed
// TODO: simplify and improve
- for (server_slot &slot : slots) {
+ for (server_slot & slot : slots) {
if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
if (!params_base.ctx_shift) {
// this check is redundant (for good)
@@ -2731,15 +2820,14 @@ struct server_context {
}
// Shift context
- const int n_keep = slot.params.n_keep + add_bos_token;
- const int n_left = slot.n_past - n_keep;
+ const int n_keep = slot.params.n_keep + add_bos_token;
+ const int n_left = slot.n_past - n_keep;
const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
- SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left,
- n_discard);
+ SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
- llama_kv_cache_seq_rm(ctx, slot.id, n_keep, n_keep + n_discard);
- llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
+ llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
+ llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
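The context-shift hunk above keeps the first n_keep tokens, discards the next n_discard (by default half of n_left = n_past - n_keep), and slides the rest of the sequence back. A standalone sketch of the same bookkeeping on a plain token vector; the llama_kv_self_seq_rm/seq_add calls that shift the actual KV cache are deliberately left out, and shift_cache_tokens is an illustrative name only:

#include <cstdint>
#include <vector>

// Drop n_discard tokens after the first n_keep and compact the tail.
// Assumes 0 <= n_keep and n_keep + n_discard <= cache_tokens.size().
static void shift_cache_tokens(std::vector<int32_t> & cache_tokens, int n_keep, int n_discard) {
    for (size_t i = n_keep + n_discard; i < cache_tokens.size(); i++) {
        cache_tokens[i - n_discard] = cache_tokens[i];
    }
    cache_tokens.resize(cache_tokens.size() - n_discard);
}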
@@ -2759,15 +2847,14 @@ struct server_context {
common_batch_clear(batch);
// track if given slot can be batched with slots already in the batch
- server_slot *slot_batched = nullptr;
+ server_slot * slot_batched = nullptr;
- auto accept_special_token = [&](server_slot &slot, llama_token token) {
- return params_base.special ||
- slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
+ auto accept_special_token = [&](server_slot & slot, llama_token token) {
+ return params_base.special || slot.params.sampling.preserved_tokens.find(token) != slot.params.sampling.preserved_tokens.end();
};
// first, add sampled tokens from any ongoing sequences
- for (auto &slot : slots) {
+ for (auto & slot : slots) {
if (slot.state != SLOT_STATE_GENERATING) {
continue;
}
@@ -2781,7 +2868,7 @@ struct server_context {
slot.i_batch = batch.n_tokens;
- common_batch_add(batch, slot.sampled, slot.n_past, {slot.id}, true);
+ common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
slot.n_past += 1;
@@ -2790,16 +2877,16 @@ struct server_context {
}
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
- slot.n_ctx, slot.n_past, (int)slot.cache_tokens.size(), slot.truncated);
+ slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
}
// process in chunks of params.n_batch
- int32_t n_batch = llama_n_batch(ctx);
+ int32_t n_batch = llama_n_batch(ctx);
int32_t n_ubatch = llama_n_ubatch(ctx);
// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
- for (auto &slot : slots) {
+ for (auto & slot : slots) {
// check if we can batch this slot with the previous one
if (slot.is_processing()) {
if (!slot_batched) {
@@ -2811,7 +2898,7 @@ struct server_context {
// this slot still has a prompt to be processed
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
- auto &prompt_tokens = slot.prompt_tokens;
+ auto & prompt_tokens = slot.prompt_tokens;
// TODO: maybe move branch to outside of this loop in the future
if (slot.state == SLOT_STATE_STARTED) {
@@ -2822,21 +2909,18 @@ struct server_context {
slot.n_prompt_tokens = prompt_tokens.size();
slot.state = SLOT_STATE_PROCESSING_PROMPT;
- SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx,
- slot.params.n_keep, slot.n_prompt_tokens);
+ SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
// print prompt tokens (for debugging)
if (1) {
// first 16 tokens (avoid flooding logs)
for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
- SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i],
- common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
}
} else {
// all
- for (int i = 0; i < (int)prompt_tokens.size(); i++) {
- SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i],
- common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
}
}
@@ -2853,15 +2937,13 @@ struct server_context {
if (slot.is_non_causal()) {
if (slot.n_prompt_tokens > n_ubatch) {
slot.release();
- send_error(slot, "input is too large to process. increase the physical batch size",
- ERROR_TYPE_SERVER);
+ send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
continue;
}
if (slot.n_prompt_tokens > slot.n_ctx) {
slot.release();
- send_error(slot, "input is larger than the max context size. skipping",
- ERROR_TYPE_SERVER);
+ send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
continue;
}
} else {
@@ -2871,10 +2953,7 @@ struct server_context {
// context shift should be applied only during the generation phase
if (slot.n_prompt_tokens >= slot.n_ctx) {
slot.release();
- send_error(slot,
- "the request exceeds the available context size. try increasing the "
- "context size or enable context shift",
- ERROR_TYPE_INVALID_REQUEST);
+ send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
continue;
}
}
@@ -2888,25 +2967,23 @@ struct server_context {
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_block_size = n_left / 2;
- const int erased_blocks =
- (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+ const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
- llama_tokens new_tokens(prompt_tokens.begin(),
- prompt_tokens.begin() + slot.params.n_keep);
+ llama_tokens new_tokens(
+ prompt_tokens.begin(),
+ prompt_tokens.begin() + slot.params.n_keep);
- new_tokens.insert(new_tokens.end(),
- prompt_tokens.begin() + slot.params.n_keep +
- erased_blocks * n_block_size,
- prompt_tokens.end());
+ new_tokens.insert(
+ new_tokens.end(),
+ prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
+ prompt_tokens.end());
prompt_tokens = std::move(new_tokens);
slot.truncated = true;
slot.n_prompt_tokens = prompt_tokens.size();
- SLT_WRN(slot,
- "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n",
- slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
+ SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
@@ -2920,33 +2997,29 @@ struct server_context {
size_t head_c = slot.n_past; // cache
size_t head_p = slot.n_past; // current prompt
- SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n",
- params_base.n_cache_reuse, slot.n_past);
+ SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
- while (head_c < slot.cache_tokens.size() && head_p < prompt_tokens.size()) {
+ while (head_c < slot.cache_tokens.size() &&
+ head_p < prompt_tokens.size()) {
size_t n_match = 0;
while (head_c + n_match < slot.cache_tokens.size() &&
- head_p + n_match < prompt_tokens.size() &&
+ head_p + n_match < prompt_tokens.size() &&
slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
n_match++;
}
- if (n_match >= (size_t)params_base.n_cache_reuse) {
- SLT_INF(slot,
- "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> "
- "[%zu, %zu)\n",
- n_match, head_c, head_c + n_match, head_p, head_p + n_match);
- // for (size_t i = head_p; i < head_p + n_match; i++) {
- // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i],
- // common_token_to_piece(ctx, prompt_tokens[i]).c_str());
- // }
+ if (n_match >= (size_t) params_base.n_cache_reuse) {
+ SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+ //for (size_t i = head_p; i < head_p + n_match; i++) {
+ // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ //}
- const int64_t kv_shift = (int64_t)head_p - (int64_t)head_c;
+ const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
- llama_kv_cache_seq_rm(ctx, slot.id, head_p, head_c);
- llama_kv_cache_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
+ llama_kv_self_seq_rm (ctx, slot.id, head_p, head_c);
+ llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
for (size_t i = 0; i < n_match; i++) {
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
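The cache-reuse hunk above walks the cached tokens and the new prompt in parallel and only reuses a matching run when it is at least n_cache_reuse tokens long, shifting the kept chunk into place in the KV cache. A standalone sketch of just the matching loop, counting how many cached tokens could be kept for a new prompt; the KV-cache shifting is omitted and count_reusable is an illustrative name:

#include <cstddef>
#include <cstdint>
#include <vector>

// Count reusable cached tokens: runs of equal tokens shorter than
// n_cache_reuse are skipped, longer runs are kept. Assumes n_cache_reuse >= 1.
static size_t count_reusable(const std::vector<int32_t> & cache_tokens,
                             const std::vector<int32_t> & prompt_tokens,
                             size_t n_past, size_t n_cache_reuse) {
    size_t head_c = n_past; // position in the cache
    size_t head_p = n_past; // position in the new prompt
    size_t reused = 0;

    while (head_c < cache_tokens.size() && head_p < prompt_tokens.size()) {
        size_t n_match = 0;
        while (head_c + n_match < cache_tokens.size() &&
               head_p + n_match < prompt_tokens.size() &&
               cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
            n_match++;
        }
        if (n_match >= n_cache_reuse) {
            // a chunk this long is worth keeping; advance past it on both sides
            reused += n_match;
            head_c += n_match;
            head_p += n_match;
        } else {
            // too short to reuse; skip one cached token and try again
            head_c += n_match + 1;
        }
    }
    return reused;
}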
@@ -2967,10 +3040,7 @@ struct server_context {
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
// we have to evaluate at least 1 token to generate logits.
- SLT_WRN(slot,
- "need to evaluate at least 1 token to generate logits, n_past = %d, "
- "n_prompt_tokens = %d\n",
- slot.n_past, slot.n_prompt_tokens);
+ SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
slot.n_past--;
}
@@ -2987,9 +3057,9 @@ struct server_context {
}
// keep only the common part
- if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
+ if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
- llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
+ llama_kv_self_seq_rm(ctx, slot.id, -1, -1);
// there is no common part left
slot.n_past = 0;
@@ -3003,10 +3073,9 @@ struct server_context {
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
// without pooling, we want to output the embeddings for all the tokens in the batch
- const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING &&
- llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+ const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
- common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, {slot.id}, need_embd);
+ common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3016,8 +3085,7 @@ struct server_context {
slot.n_past++;
}
- SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n",
- slot.n_past, batch.n_tokens, (float)slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+ SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
// entire prompt has been processed
if (slot.n_past == slot.n_prompt_tokens) {
@@ -3036,7 +3104,7 @@ struct server_context {
batch.logits[batch.n_tokens - 1] = true;
slot.n_decoded = 0;
- slot.i_batch = batch.n_tokens - 1;
+ slot.i_batch = batch.n_tokens - 1;
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
}
@@ -3067,8 +3135,13 @@ struct server_context {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
llama_batch batch_view = {
- n_tokens, batch.token + i, nullptr, batch.pos + i,
- batch.n_seq_id + i, batch.seq_id + i, batch.logits + i,
+ n_tokens,
+ batch.token + i,
+ nullptr,
+ batch.pos + i,
+ batch.n_seq_id + i,
+ batch.seq_id + i,
+ batch.logits + i,
};
const int ret = llama_decode(ctx, batch_view);
@@ -3077,10 +3150,8 @@ struct server_context {
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size
- SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i "
- "= %d, n_batch = %d, ret = %d\n",
- i, n_batch, ret);
- for (auto &slot : slots) {
+ SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
+ for (auto & slot : slots) {
slot.release();
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
}
@@ -3091,15 +3162,13 @@ struct server_context {
n_batch /= 2;
i -= n_batch;
- SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing "
- "it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n",
- i, n_batch, ret);
+ SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
continue; // continue loop of n_batch
}
- for (auto &slot : slots) {
- if (slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens)) {
+ for (auto & slot : slots) {
+ if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
continue; // continue loop of slots
}
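The retry path a few lines up is easy to misread: after n_batch is halved, i is decremented by the new n_batch so that the enclosing loop's own i += n_batch puts it back on the same offset, now with a view half the size (assuming the loop advances i by n_batch, which is what that dance relies on). A self-contained sketch of that control flow, with a made-up fake_decode that only accepts small views:

#include <algorithm>
#include <cstdio>

// pretend decode: fails (returns 1) for views larger than 8 tokens
static int fake_decode(int n_view) { return n_view > 8 ? 1 : 0; }

int main() {
    const int n_tokens_total = 40;
    int n_batch = 32;

    for (int i = 0; i < n_tokens_total; i += n_batch) {
        const int n_view = std::min(n_batch, n_tokens_total - i);
        if (fake_decode(n_view) != 0) {
            if (n_batch == 1) {
                printf("giving up\n");
                return 1;
            }
            n_batch /= 2;
            i -= n_batch;   // cancelled out by the loop's i += n_batch, so the same offset is retried
            printf("retrying offset %d with n_batch = %d\n", i + n_batch, n_batch);
            continue;
        }
        printf("decoded [%d, %d)\n", i, i + n_view);
    }
    return 0;
}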
@@ -3146,9 +3215,9 @@ struct server_context {
slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
completion_token_output result;
- result.tok = id;
+ result.tok = id;
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
- result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
+ result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
if (slot.params.sampling.n_probs > 0) {
populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
@@ -3165,7 +3234,7 @@ struct server_context {
}
// do speculative decoding
- for (auto &slot : slots) {
+ for (auto & slot : slots) {
if (!slot.is_processing() || !slot.can_speculate()) {
continue;
}
@@ -3188,8 +3257,7 @@ struct server_context {
SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
if (n_draft_max < slot.params.speculative.n_min) {
- SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n",
- n_draft_max, slot.params.speculative.n_min);
+ SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
continue;
}
@@ -3197,25 +3265,25 @@ struct server_context {
llama_token id = slot.sampled;
struct common_speculative_params params_spec;
- params_spec.n_draft = n_draft_max;
- params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
- params_spec.p_min = slot.params.speculative.p_min;
+ params_spec.n_draft = n_draft_max;
+ params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+ params_spec.p_min = slot.params.speculative.p_min;
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
// ignore small drafts
- if (slot.params.speculative.n_min > (int)draft.size()) {
- SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
+ if (slot.params.speculative.n_min > (int) draft.size()) {
+ SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
continue;
}
// construct the speculation batch
common_batch_clear(slot.batch_spec);
- common_batch_add(slot.batch_spec, id, slot.n_past, {slot.id}, true);
+ common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
for (size_t i = 0; i < draft.size(); ++i) {
- common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, {slot.id}, true);
+ common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
}
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
@@ -3225,21 +3293,20 @@ struct server_context {
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
- slot.n_past += ids.size();
+ slot.n_past += ids.size();
slot.n_decoded += ids.size();
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
- llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+ llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
- result.tok = ids[i];
- result.text_to_send =
- common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
- result.prob = 1.0f; // set later
+ result.tok = ids[i];
+ result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+ result.prob = 1.0f; // set later
// TODO: set result.probs
@@ -3253,8 +3320,7 @@ struct server_context {
}
}
- SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int)ids.size() - 1, (int)draft.size(),
- slot.n_past);
+ SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
}
}
@@ -3262,14 +3328,31 @@ struct server_context {
}
json model_meta() const {
- return json{
- {"vocab_type", llama_vocab_type(vocab)}, {"n_vocab", llama_vocab_n_tokens(vocab)},
- {"n_ctx_train", llama_model_n_ctx_train(model)}, {"n_embd", llama_model_n_embd(model)},
- {"n_params", llama_model_n_params(model)}, {"size", llama_model_size(model)},
+ return json {
+ {"vocab_type", llama_vocab_type (vocab)},
+ {"n_vocab", llama_vocab_n_tokens (vocab)},
+ {"n_ctx_train", llama_model_n_ctx_train(model)},
+ {"n_embd", llama_model_n_embd (model)},
+ {"n_params", llama_model_n_params (model)},
+ {"size", llama_model_size (model)},
};
}
};
+std::function<void(int)> shutdown_handler;
+std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
+
+inline void signal_handler(int signal) {
+ if (is_terminating.test_and_set()) {
+ // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
+ // this is for better developer experience, we can remove when the server is stable enough
+ fprintf(stderr, "Received second interrupt, terminating immediately.\n");
+ exit(1);
+ }
+
+ shutdown_handler(signal);
+}
+
static void common_params_handle_model_default(std::string &model, const std::string &model_url, std::string &hf_repo,
std::string &hf_file, const std::string &hf_token) {
if (!hf_repo.empty()) {
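The signal handler added above stops gracefully on the first Ctrl+C and hard-exits on the second. A minimal sketch of how it might be wired into a host program (main and running are illustrative here, not part of this change):

#include <atomic>
#include <chrono>
#include <csignal>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <thread>

// same shape as the declarations added above
static std::function<void(int)> shutdown_handler;
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;

static void signal_handler(int sig) {
    if (is_terminating.test_and_set()) {
        fprintf(stderr, "Received second interrupt, terminating immediately.\n");
        exit(1);
    }
    shutdown_handler(sig);
}

int main() {
    static std::atomic<bool> running{true};

    shutdown_handler = [](int /*sig*/) { running = false; };  // request a graceful stop
    std::signal(SIGINT, signal_handler);

    while (running) {
        // serve requests here ...
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }
    printf("shut down cleanly\n");
    return 0;
}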
@@ -3358,7 +3441,7 @@ static void server_params_parse(json jparams, common_params &params) {
params.lookup_cache_static = json_value(jparams, "lookup_cache_static", default_params.lookup_cache_static);
params.lookup_cache_dynamic = json_value(jparams, "lookup_cache_dynamic", default_params.lookup_cache_dynamic);
params.logits_file = json_value(jparams, "logits_file", default_params.logits_file);
- // params.lora_adapters = json_value(jparams, "lora_adapter", default_params.lora_adapters);
+ //params.lora_adapters = json_value(jparams, "lora_adapter", default_params.lora_adapters);
params.embedding = json_value(jparams, "embedding", default_params.embedding);
params.escape = json_value(jparams, "escape", default_params.escape);
params.cont_batching = json_value(jparams, "cont_batching", default_params.cont_batching);
diff --git a/src/main/cpp/utils.hpp b/src/main/cpp/utils.hpp
index 603424b4..bdc19668 100644
--- a/src/main/cpp/utils.hpp
+++ b/src/main/cpp/utils.hpp
@@ -48,14 +48,13 @@ using json = nlohmann::ordered_json;
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-template <typename T> static T json_value(const json &body, const std::string &key, const T &default_value) {
+template <typename T> static T json_value(const json & body, const std::string & key, const T & default_value) {
// Fallback null to default value
if (body.contains(key) && !body.at(key).is_null()) {
try {
return body.at(key);
} catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) {
- LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(),
- json(default_value).type_name());
+ LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), json(default_value).type_name());
return default_value;
}
} else {
@@ -69,9 +68,9 @@ const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "
// tokenizer and input processing utils
//
-static bool json_is_array_of_numbers(const json &data) {
+static bool json_is_array_of_numbers(const json & data) {
if (data.is_array()) {
- for (const auto &e : data) {
+ for (const auto & e : data) {
if (!e.is_number_integer()) {
return false;
}
@@ -82,11 +81,11 @@ static bool json_is_array_of_numbers(const json &data) {
}
// does the array contain BOTH numbers & strings?
-static bool json_is_array_of_mixed_numbers_strings(const json &data) {
+static bool json_is_array_of_mixed_numbers_strings(const json & data) {
bool seen_string = false;
bool seen_number = false;
if (data.is_array()) {
- for (const auto &e : data) {
+ for (const auto & e : data) {
seen_string |= e.is_string();
seen_number |= e.is_number_integer();
if (seen_number && seen_string) {
@@ -98,14 +97,14 @@ static bool json_is_array_of_mixed_numbers_strings(const json &data) {
}
// get value by path(key1 / key2)
-static json json_get_nested_values(const std::vector<std::string> &paths, const json &js) {
+static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
json result = json::object();
- for (const std::string &path : paths) {
+ for (const std::string & path : paths) {
json current = js;
        const auto keys = string_split<std::string>(path, /*separator*/ '/');
bool valid_path = true;
- for (const std::string &k : keys) {
+ for (const std::string & k : keys) {
if (valid_path && current.is_object() && current.contains(k)) {
current = current[k];
} else {
@@ -124,15 +123,14 @@ static json json_get_nested_values(const std::vector<std::string> &paths, const
* - only string, example: "string"
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
*/
-static llama_tokens tokenize_mixed(const llama_vocab *vocab, const json &json_prompt, bool add_special,
- bool parse_special) {
+static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
llama_tokens prompt_tokens;
if (json_prompt.is_array()) {
bool first = true;
- for (const auto &p : json_prompt) {
+ for (const auto & p : json_prompt) {
if (p.is_string()) {
                auto s = p.template get<std::string>();
@@ -173,8 +171,7 @@ static llama_tokens tokenize_mixed(const llama_vocab *vocab, const json &json_pr
* - "prompt": [[12, 34, 56], [78, 90, 12]]
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
*/
-static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab *vocab, const json &json_prompt,
- bool add_special, bool parse_special) {
+static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
    std::vector<llama_tokens> result;
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
// string or mixed
@@ -185,20 +182,18 @@ static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab *vocab
} else if (json_prompt.is_array()) {
// array of prompts
result.reserve(json_prompt.size());
- for (const auto &p : json_prompt) {
+ for (const auto & p : json_prompt) {
if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
result.push_back(tokenize_mixed(vocab, p, add_special, parse_special));
} else if (json_is_array_of_numbers(p)) {
// array of tokens
                result.push_back(p.get<llama_tokens>());
} else {
- throw std::runtime_error(
- "element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
+                throw std::runtime_error("element of \"prompt\" must be a string, a list of tokens, or a list of mixed strings & tokens");
}
}
} else {
- throw std::runtime_error(
- "\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
+        throw std::runtime_error("\"prompt\" must be a string, a list of tokens, a list of mixed strings & tokens, or a list of prompts");
}
if (result.empty()) {
throw std::runtime_error("\"prompt\" must not be empty");
@@ -209,10 +204,9 @@ static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab *vocab
// return the last index of character that can form a valid string
// if the last character is potentially cut in half, return the index before the cut
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
-static size_t validate_utf8(const std::string &text) {
+static size_t validate_utf8(const std::string& text) {
size_t len = text.size();
- if (len == 0)
- return 0;
+ if (len == 0) return 0;
// Check the last few bytes to see if a multi-byte character is cut off
for (size_t i = 1; i <= 4 && i <= len; ++i) {
@@ -221,18 +215,15 @@ static size_t validate_utf8(const std::string &text) {
if ((c & 0xE0) == 0xC0) {
// 2-byte character start: 110xxxxx
// Needs at least 2 bytes
- if (i < 2)
- return len - i;
+ if (i < 2) return len - i;
} else if ((c & 0xF0) == 0xE0) {
// 3-byte character start: 1110xxxx
// Needs at least 3 bytes
- if (i < 3)
- return len - i;
+ if (i < 3) return len - i;
} else if ((c & 0xF8) == 0xF0) {
// 4-byte character start: 11110xxx
// Needs at least 4 bytes
- if (i < 4)
- return len - i;
+ if (i < 4) return len - i;
}
}
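A usage sketch for validate_utf8 (assuming this header and its project dependencies are on the include path): a streamed chunk is clamped so that a multi-byte character split at its end is held back for the next chunk.

#include <cstdio>
#include <string>

#include "utils.hpp"   // for validate_utf8 (defined above)

int main() {
    // "é" is 0xC3 0xA9 in UTF-8; drop the second byte to simulate a chunk cut mid-character
    const std::string chunk = "caf\xC3";

    const size_t n_valid = validate_utf8(chunk);          // 3: stops before the dangling 0xC3
    const std::string to_send = chunk.substr(0, n_valid); // "caf"
    const std::string held    = chunk.substr(n_valid);    // "\xC3", prepended to the next chunk

    printf("send %zu bytes, hold back %zu bytes\n", to_send.size(), held.size());
    return 0;
}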
@@ -245,7 +236,7 @@ static size_t validate_utf8(const std::string &text) {
//
// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
-static llama_tokens format_rerank(const struct llama_vocab *vocab, const llama_tokens &query, const llama_tokens &doc) {
+static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
llama_tokens result;
result.reserve(doc.size() + query.size() + 4);
@@ -260,9 +251,17 @@ static llama_tokens format_rerank(const struct llama_vocab *vocab, const llama_t
}
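A toy rendering of the rerank layout described above ([BOS]query[EOS][SEP]doc[EOS]), with made-up special-token ids in place of the vocab lookups the server performs:

#include <cstdio>
#include <vector>

int main() {
    const int BOS = 1, EOS = 2, SEP = 3;                 // illustrative ids, not a real vocab
    const std::vector<int> query = {10, 11};
    const std::vector<int> doc   = {20, 21, 22};

    std::vector<int> result;
    result.reserve(query.size() + doc.size() + 4);       // 4 special tokens, as in the reserve() above
    result.push_back(BOS);
    result.insert(result.end(), query.begin(), query.end());
    result.push_back(EOS);
    result.push_back(SEP);
    result.insert(result.end(), doc.begin(), doc.end());
    result.push_back(EOS);

    for (int t : result) printf("%d ", t);               // 1 10 11 2 3 20 21 22 2
    printf("\n");
    return 0;
}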
// format infill task
-static llama_tokens format_infill(const llama_vocab *vocab, const json &input_prefix, const json &input_suffix,
- const json &input_extra, const int n_batch, const int n_predict, const int n_ctx,
- const bool spm_infill, const llama_tokens &tokens_prompt) {
+static llama_tokens format_infill(
+ const llama_vocab * vocab,
+ const json & input_prefix,
+ const json & input_suffix,
+ const json & input_extra,
+ const int n_batch,
+ const int n_predict,
+ const int n_ctx,
+ const bool spm_infill,
+ const llama_tokens & tokens_prompt
+ ) {
// TODO: optimize this block by reducing memory allocations and movement
// use FIM repo-level pattern:
@@ -290,9 +289,9 @@ static llama_tokens format_infill(const llama_vocab *vocab, const json &input_pr
extra_tokens.push_back(llama_vocab_fim_rep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}
- for (const auto &chunk : input_extra) {
+ for (const auto & chunk : input_extra) {
// { "text": string, "filename": string }
- const std::string text = json_value(chunk, "text", std::string());
+ const std::string text = json_value(chunk, "text", std::string());
const std::string filename = json_value(chunk, "filename", std::string("tmp"));
if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
@@ -302,8 +301,7 @@ static llama_tokens format_infill(const llama_vocab *vocab, const json &input_pr
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else {
// chunk separator in binary form to avoid confusing the AI
- static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70,
- 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+ static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false);
extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
@@ -322,21 +320,19 @@ static llama_tokens format_infill(const llama_vocab *vocab, const json &input_pr
}
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-    const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3 * (n_batch / 4));
-    const int n_suffix_take =
-        std::min<int>(tokens_suffix.size(), std::max<int>(0, (n_batch / 4) - (2 + tokens_prompt.size())));
+    const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4));
+    const int n_suffix_take = std::min<int>(tokens_suffix.size(), std::max<int>(0, (n_batch/4) - (2 + tokens_prompt.size())));
- SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take,
- (n_prefix_take + n_suffix_take));
+ SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take));
// fill the rest of the context with extra chunks
-    const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch)-2 * n_predict), extra_tokens.size());
+    const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
tokens_suffix.resize(n_suffix_take);
tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab));
- tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
+ tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab));
auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
@@ -346,7 +342,7 @@ static llama_tokens format_infill(const llama_vocab *vocab, const json &input_pr
embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
}
- SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int)extra_tokens.size());
+ SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
// put the extra context before the FIM prefix
embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
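A worked example of the FIM budget split above, with assumed sizes (n_batch = 2048, n_ctx = 8192, n_predict = 256) and made-up token counts; it mirrors the std::min<int>/std::max<int> calls in the header:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    const int n_batch   = 2048;
    const int n_ctx     = 8192;
    const int n_predict = 256;

    // made-up token counts for prefix, suffix, prompt and repo-level extra chunks
    const size_t n_prefix = 4000, n_suffix = 900, n_prompt = 10, n_extra = 20000;

    // prefix and suffix split the batch 3:1; the suffix budget also subtracts
    // 2 + n_prompt, presumably leaving room for FIM special tokens and the prompt
    const int n_prefix_take = std::min<int>(n_prefix, 3 * (n_batch / 4));                                      // 1536
    const int n_suffix_take = std::min<int>(n_suffix, std::max<int>(0, (n_batch / 4) - (2 + (int) n_prompt))); // 500
    // whatever context is not reserved for the batch and generation goes to the extra chunks
    const int n_extra_take  = std::min<int>(std::max<int>(0, n_ctx - n_batch - 2 * n_predict), n_extra);       // 5632

    printf("prefix %d, suffix %d, extra %d\n", n_prefix_take, n_suffix_take, n_extra_take);
    return 0;
}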
@@ -361,13 +357,16 @@ static llama_tokens format_infill(const llama_vocab *vocab, const json &input_pr
// base64 utils (TODO: move to common in the future)
//
-static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
- "abcdefghijklmnopqrstuvwxyz"
- "0123456789+/";
+static const std::string base64_chars =
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "abcdefghijklmnopqrstuvwxyz"
+ "0123456789+/";
-static inline bool is_base64(uint8_t c) { return (isalnum(c) || (c == '+') || (c == '/')); }
+static inline bool is_base64(uint8_t c) {
+ return (isalnum(c) || (c == '+') || (c == '/'));
+}
-static inline std::vector<uint8_t> base64_decode(const std::string &encoded_string) {
+static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string) {
int i = 0;
int j = 0;
int in_ = 0;
@@ -380,16 +379,15 @@ static inline std::vector<uint8_t> base64_decode(const std::string &encoded_stri
std::vector