From 097605b9b33310cf4249f8a30c62ed294b0b97c4 Mon Sep 17 00:00:00 2001
From: Anas Dorbani
Date: Sun, 3 Nov 2024 22:04:08 +0100
Subject: [PATCH] add the batching to the llm_embedding

---
 src/core/functions/scalar/llm_embedding.cpp   |  8 +++---
 src/core/model_manager/model_manager.cpp      | 28 +++++++++++--------
 .../core/model_manager/model_manager.hpp      |  6 ++--
 3 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/src/core/functions/scalar/llm_embedding.cpp b/src/core/functions/scalar/llm_embedding.cpp
index d45ef792..b010d201 100644
--- a/src/core/functions/scalar/llm_embedding.cpp
+++ b/src/core/functions/scalar/llm_embedding.cpp
@@ -19,17 +19,17 @@ static void LlmEmbeddingScalarFunction(DataChunk &args, ExpressionState &state,
     auto model_details_json = CoreScalarParsers::Struct2Json(args.data[1], 1)[0];
     auto model_details = ModelManager::CreateModelDetails(con, model_details_json);
 
-    auto embeddings = nlohmann::json::array();
+    vector<std::string> prepared_inputs;
     for (auto &row : inputs) {
         std::string concat_input;
         for (auto &item : row.items()) {
             concat_input += item.value().get<std::string>() + " ";
         }
-
-        auto element_embedding = ModelManager::CallEmbedding(concat_input, model_details);
-        embeddings.push_back(element_embedding);
+        prepared_inputs.push_back(concat_input);
     }
 
+    auto embeddings = ModelManager::CallEmbedding(prepared_inputs, model_details);
+
     for (size_t index = 0; index < embeddings.size(); index++) {
         vector embedding;
         for (auto &value : embeddings[index]) {
diff --git a/src/core/model_manager/model_manager.cpp b/src/core/model_manager/model_manager.cpp
index c341a1e5..26fa86cd 100644
--- a/src/core/model_manager/model_manager.cpp
+++ b/src/core/model_manager/model_manager.cpp
@@ -208,7 +208,7 @@ nlohmann::json ModelManager::CallComplete(const std::string &prompt, const Model
     }
 }
 
-nlohmann::json ModelManager::OpenAICallEmbedding(const std::string &input, const ModelDetails &model_details) {
+nlohmann::json ModelManager::OpenAICallEmbedding(const vector<std::string> &inputs, const ModelDetails &model_details) {
     // Get API key from the environment variable
     auto key = openai::OpenAI::get_openai_api_key();
     openai::start(key);
@@ -216,7 +216,7 @@ nlohmann::json ModelManager::OpenAICallEmbedding(const std::string &input, const
     // Create a JSON request payload with the provided parameters
     nlohmann::json request_payload = {
         {"model", model_details.model},
-        {"input", input},
+        {"input", inputs},
     };
 
     // Make a request to the OpenAI API
@@ -230,12 +230,15 @@ nlohmann::json ModelManager::OpenAICallEmbedding(const std::string &input, const
         // Add error handling code here
     }
 
-    auto embedding = completion["data"][0]["embedding"];
+    auto embeddings = nlohmann::json::array();
+    for (auto &item : completion["data"]) {
+        embeddings.push_back(item["embedding"]);
+    }
 
-    return embedding;
+    return embeddings;
 }
 
-nlohmann::json ModelManager::AzureCallEmbedding(const std::string &input, const ModelDetails &model_details) {
+nlohmann::json ModelManager::AzureCallEmbedding(const vector<std::string> &inputs, const ModelDetails &model_details) {
     // Get API key from the environment variable
     auto api_key = AzureModelManager::get_azure_api_key();
     auto resource_name = AzureModelManager::get_azure_resource_name();
@@ -247,7 +250,7 @@ nlohmann::json ModelManager::AzureCallEmbedding(const std::string &input, const
     // Create a JSON request payload with the provided parameters
     nlohmann::json request_payload = {
         {"model", model_details.model},
-        {"input", input},
+        {"input", inputs},
     };
 
     // Make a request to the Azure API
@@ -261,12 +264,15 @@ nlohmann::json ModelManager::AzureCallEmbedding(const std::string &input, const
         // Add error handling code here
     }
 
-    auto embedding = completion["data"][0]["embedding"];
+    auto embeddings = nlohmann::json::array();
+    for (auto &item : completion["data"]) {
+        embeddings.push_back(item["embedding"]);
+    }
 
-    return embedding;
+    return embeddings;
 }
 
-nlohmann::json ModelManager::CallEmbedding(const std::string &input, const ModelDetails &model_details) {
+nlohmann::json ModelManager::CallEmbedding(const vector<std::string> &inputs, const ModelDetails &model_details) {
 
     // Check if the provided model is in the list of supported models
     if (supported_embedding_models.find(model_details.model) == supported_embedding_models.end()) {
@@ -286,9 +292,9 @@ nlohmann::json ModelManager::CallEmbedding(const std::string &input, const Model
 
     if (model_details.provider_name == "openai" || model_details.provider_name == "default" ||
         model_details.provider_name.empty()) {
-        return OpenAICallEmbedding(input, model_details);
+        return OpenAICallEmbedding(inputs, model_details);
     } else {
-        return AzureCallEmbedding(input, model_details);
+        return AzureCallEmbedding(inputs, model_details);
     }
 }
 
diff --git a/src/include/flockmtl/core/model_manager/model_manager.hpp b/src/include/flockmtl/core/model_manager/model_manager.hpp
index f7907551..144b3fac 100644
--- a/src/include/flockmtl/core/model_manager/model_manager.hpp
+++ b/src/include/flockmtl/core/model_manager/model_manager.hpp
@@ -22,7 +22,7 @@ struct ModelManager {
     static nlohmann::json CallComplete(const std::string &prompt, const ModelDetails &model_details,
                                        const bool json_response = true);
 
-    static nlohmann::json CallEmbedding(const std::string &input, const ModelDetails &model_details);
+    static nlohmann::json CallEmbedding(const vector<std::string> &inputs, const ModelDetails &model_details);
 
 private:
     static std::pair GetQueriedModel(Connection &con, const std::string &model_name,
@@ -34,9 +34,9 @@ struct ModelManager {
     static nlohmann::json AzureCallComplete(const std::string &prompt, const ModelDetails &model_details,
                                             const bool json_response);
 
-    static nlohmann::json OpenAICallEmbedding(const std::string &input, const ModelDetails &model_details);
+    static nlohmann::json OpenAICallEmbedding(const vector<std::string> &inputs, const ModelDetails &model_details);
 
-    static nlohmann::json AzureCallEmbedding(const std::string &input, const ModelDetails &model_details);
+    static nlohmann::json AzureCallEmbedding(const vector<std::string> &inputs, const ModelDetails &model_details);
 };
 
 } // namespace core