From 1f019dff891e8c9ea4e474fe58d40179b5f7d2dd Mon Sep 17 00:00:00 2001
From: Anas DORBANI <95044293+dorbanianas@users.noreply.github.com>
Date: Tue, 19 Nov 2024 03:52:47 +0100
Subject: [PATCH] access to secrets for azure and openai (#66)

---
 src/core/functions/batch_response_builder.cpp |  4 +-
 src/core/model_manager/model_manager.cpp      | 67 ++++++++++++++-----
 src/core/parser/query/secret_parser.cpp       | 11 +--
 .../core/functions/batch_response_builder.hpp |  4 +-
 .../core/model_manager/model_manager.hpp      | 14 ++--
 .../model_manager/supported_providers.hpp     | 15 +++++
 6 files changed, 86 insertions(+), 29 deletions(-)

diff --git a/src/core/functions/batch_response_builder.cpp b/src/core/functions/batch_response_builder.cpp
index 40de512a..4f0535d8 100644
--- a/src/core/functions/batch_response_builder.cpp
+++ b/src/core/functions/batch_response_builder.cpp
@@ -89,7 +89,7 @@ PromptDetails CreatePromptDetails(Connection &con, const nlohmann::json prompt_d
 }
 
 nlohmann::json Complete(const nlohmann::json &tuples, const std::string &user_prompt, ScalarFunctionType function_type,
-                        const ModelDetails &model_details) {
+                        ModelDetails &model_details) {
     nlohmann::json data;
     auto tuples_markdown = ConstructMarkdownArrayTuples(tuples);
     auto prompt = ScalarPromptTemplate::GetPrompt(user_prompt, tuples_markdown, function_type);
@@ -98,7 +98,7 @@ nlohmann::json Complete(const nlohmann::json &tuples, const std::string &user_pr
 };
 
 nlohmann::json BatchAndComplete(std::vector<nlohmann::json> &tuples, Connection &con, std::string user_prompt,
-                                ScalarFunctionType function_type, const ModelDetails &model_details) {
+                                ScalarFunctionType function_type, ModelDetails &model_details) {
    auto llm_template = ScalarPromptTemplate::GetPromptTemplate(function_type);
 
     int num_tokens_meta_and_user_pormpt = 0;
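The dropped const qualifiers above are load-bearing: the provider dispatchers in model_manager.cpp below write the resolved secret back into the ModelDetails they receive, so every caller on this path has to pass a mutable reference. A minimal sketch of that pattern, using hypothetical stand-ins rather than flockmtl's real declarations:

#include <string>

// Hypothetical stand-in for flockmtl's ModelDetails; only the new field matters here.
struct ModelDetails {
    std::string provider_name;
    std::string secret; // filled in lazily by the provider dispatch
};

// Stand-in for ModelManager::GetProviderSecret (illustrative value, not a real lookup).
std::string GetProviderSecret(const std::string &provider) {
    return provider == "openai" ? "stored-secret" : "";
}

// The dispatcher caches the looked-up secret in the struct, which is why
// `const ModelDetails &` becomes `ModelDetails &` all the way up the call chain.
void CallCompleteProvider(ModelDetails &model_details) {
    model_details.secret = GetProviderSecret(model_details.provider_name);
}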
diff --git a/src/core/model_manager/model_manager.cpp b/src/core/model_manager/model_manager.cpp
index 102b4df2..f51bb515 100644
--- a/src/core/model_manager/model_manager.cpp
+++ b/src/core/model_manager/model_manager.cpp
@@ -12,6 +12,7 @@
 #include
 #include
 #include
+#include
 
 namespace flockmtl {
 namespace core {
@@ -134,9 +135,12 @@ nlohmann::json ModelManager::AwsBedrockCallComplete(const std::string &prompt, c
 
 nlohmann::json ModelManager::OpenAICallComplete(const std::string &prompt, const ModelDetails &model_details,
                                                 const bool json_response) {
-    // Get API key from the environment variable
-    auto key = openai::OpenAI::get_openai_api_key();
-    openai::start(key);
+    auto api_key = model_details.secret;
+    if (api_key.empty()) {
+        api_key = openai::OpenAI::get_openai_api_key();
+    }
+
+    openai::start(api_key);
 
     // Create a JSON request payload with the provided parameters
     nlohmann::json request_payload = {{"model", model_details.model},
@@ -186,8 +190,11 @@ nlohmann::json ModelManager::OpenAICallComplete(const std::string &prompt, const
 
 nlohmann::json ModelManager::AzureCallComplete(const std::string &prompt, const ModelDetails &model_details,
                                                const bool json_response) {
-    // Get API key from the environment variable
-    auto api_key = AzureModelManager::get_azure_api_key();
+    auto api_key = model_details.secret;
+    if (api_key.empty()) {
+        api_key = AzureModelManager::get_azure_api_key();
+    }
+
     auto resource_name = AzureModelManager::get_azure_resource_name();
     auto api_version = AzureModelManager::get_azure_api_version();
 
@@ -236,7 +243,7 @@ nlohmann::json ModelManager::AzureCallComplete(const std::string &prompt, const
     return content_str;
 }
 
-nlohmann::json ModelManager::CallComplete(const std::string &prompt, const ModelDetails &model_details,
+nlohmann::json ModelManager::CallComplete(const std::string &prompt, ModelDetails &model_details,
                                           const bool json_response) {
 
     // Check if the provider is in the list of supported provider
@@ -290,9 +297,12 @@ nlohmann::json ModelManager::AwsBedrockCallEmbedding(const std::vector &
 }
 
 nlohmann::json ModelManager::OpenAICallEmbedding(const std::vector<std::string> &inputs, const ModelDetails &model_details) {
-    // Get API key from the environment variable
-    auto key = openai::OpenAI::get_openai_api_key();
-    openai::start(key);
+    auto api_key = model_details.secret;
+    if (api_key.empty()) {
+        api_key = openai::OpenAI::get_openai_api_key();
+    }
+
+    openai::start(api_key);
 
     // Create a JSON request payload with the provided parameters
     nlohmann::json request_payload = {
@@ -318,8 +328,10 @@ nlohmann::json ModelManager::OpenAICallEmbedding(const std::vector &inpu
 }
 
 nlohmann::json ModelManager::AzureCallEmbedding(const std::vector<std::string> &inputs, const ModelDetails &model_details) {
-    // Get API key from the environment variable
-    auto api_key = AzureModelManager::get_azure_api_key();
+    auto api_key = model_details.secret;
+    if (api_key.empty()) {
+        api_key = AzureModelManager::get_azure_api_key();
+    }
     auto resource_name = AzureModelManager::get_azure_resource_name();
     auto api_version = AzureModelManager::get_azure_api_version();
 
@@ -350,7 +362,7 @@ nlohmann::json ModelManager::AzureCallEmbedding(const std::vector &input
     return embeddings;
 }
 
-nlohmann::json ModelManager::CallEmbedding(const std::vector<std::string> &inputs, const ModelDetails &model_details) {
+nlohmann::json ModelManager::CallEmbedding(const std::vector<std::string> &inputs, ModelDetails &model_details) {
 
     // Check if the provider is in the list of supported provider
     auto provider = GetProviderType(model_details.provider_name);
@@ -375,10 +387,32 @@ nlohmann::json ModelManager::CallEmbedding(const std::vector &inputs, co
     return result.second;
 }
 
-std::pair<bool, nlohmann::json> ModelManager::CallCompleteProvider(const std::string &prompt,
-                                                                   const ModelDetails &model_details,
-                                                                   const bool json_response) {
+std::string ModelManager::GetProviderSecret(const SupportedProviders &provider) {
+
+    auto provider_name = GetProviderName(provider);
+
+    if (provider_name.empty()) {
+        throw std::runtime_error("Provider not found");
+    }
+
+    auto query = "SELECT secret FROM "
+                 "flockmtl_config.FLOCKMTL_SECRET_INTERNAL_TABLE "
+                 "WHERE provider = '" +
+                 provider_name + "'";
+
+    auto query_result = CoreModule::GetConnection().Query(query);
+
+    if (query_result->RowCount() == 0) {
+        return "";
+    }
+
+    return query_result->GetValue(0, 0).ToString();
+}
+
+std::pair<bool, nlohmann::json>
+ModelManager::CallCompleteProvider(const std::string &prompt, ModelDetails &model_details, const bool json_response) {
     auto provider = GetProviderType(model_details.provider_name);
+    model_details.secret = GetProviderSecret(provider);
     switch (provider) {
     case FLOCKMTL_OPENAI:
         return {true, OpenAICallComplete(prompt, model_details, json_response)};
@@ -394,8 +428,9 @@ std::pair ModelManager::CallCompleteProvider(const std::st
 }
 
 std::pair<bool, nlohmann::json> ModelManager::CallEmbeddingProvider(const std::vector<std::string> &inputs,
-                                                                    const ModelDetails &model_details) {
+                                                                    ModelDetails &model_details) {
     auto provider = GetProviderType(model_details.provider_name);
+    model_details.secret = GetProviderSecret(provider);
     switch (provider) {
     case FLOCKMTL_OPENAI:
         return {true, OpenAICallEmbedding(inputs, model_details)};
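Net effect of the model_manager.cpp changes: OpenAI and Azure calls now resolve their API key in two steps, preferring the secret that CallCompleteProvider/CallEmbeddingProvider copied into model_details.secret from flockmtl_config.FLOCKMTL_SECRET_INTERNAL_TABLE, and falling back to the old environment-variable helpers only when no secret is stored. A self-contained sketch of that order; ResolveApiKey is not a flockmtl function, and the environment-variable name is illustrative:

#include <cstdlib>
#include <string>

// Resolution order introduced by this patch, reduced to one function.
std::string ResolveApiKey(const std::string &stored_secret, const char *env_var) {
    if (!stored_secret.empty()) {
        return stored_secret; // 1. a secret from the internal secret table wins
    }
    const char *from_env = std::getenv(env_var); // 2. otherwise the old env-var path
    return from_env ? std::string(from_env) : std::string();
}

// e.g. ResolveApiKey(model_details.secret, "OPENAI_API_KEY") mirrors what
// OpenAICallComplete now does via openai::OpenAI::get_openai_api_key().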
diff --git a/src/core/parser/query/secret_parser.cpp b/src/core/parser/query/secret_parser.cpp
index 218a18c8..b1b38194 100644
--- a/src/core/parser/query/secret_parser.cpp
+++ b/src/core/parser/query/secret_parser.cpp
@@ -44,7 +44,8 @@ void SecretParser::ParseCreateSecret(Tokenizer &tokenizer, std::unique_ptr

diff --git a/src/include/flockmtl/core/functions/batch_response_builder.hpp b/src/include/flockmtl/core/functions/batch_response_builder.hpp
--- a/src/include/flockmtl/core/functions/batch_response_builder.hpp
+++ b/src/include/flockmtl/core/functions/batch_response_builder.hpp
@@ ... @@
 nlohmann::json BatchAndComplete(std::vector<nlohmann::json> &tuples, Connection &con, std::string user_prompt_name,
-                                ScalarFunctionType function_type, const ModelDetails &model_details);
+                                ScalarFunctionType function_type, ModelDetails &model_details);
 
 } // namespace core
 } // namespace flockmtl
\ No newline at end of file

diff --git a/src/include/flockmtl/core/model_manager/model_manager.hpp b/src/include/flockmtl/core/model_manager/model_manager.hpp
index 42270560..3e5130a7 100644
--- a/src/include/flockmtl/core/model_manager/model_manager.hpp
+++ b/src/include/flockmtl/core/model_manager/model_manager.hpp
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include
 
 namespace flockmtl {
 namespace core {
@@ -24,21 +25,24 @@ struct ModelDetails {
     int context_window;
     int max_output_tokens;
     float temperature;
+    std::string secret;
 };
 
 struct ModelManager {
 public:
     static ModelDetails CreateModelDetails(Connection &con, const nlohmann::json &model_json);
 
-    static nlohmann::json CallComplete(const std::string &prompt, const ModelDetails &model_details,
+    static nlohmann::json CallComplete(const std::string &prompt, ModelDetails &model_details,
                                        const bool json_response = true);
 
-    static nlohmann::json CallEmbedding(const std::vector<std::string> &inputs, const ModelDetails &model_details);
+    static nlohmann::json CallEmbedding(const std::vector<std::string> &inputs, ModelDetails &model_details);
 
 private:
     static std::tuple GetQueriedModel(Connection &con, const std::string &model_name,
                                       const std::string &provider_name);
 
+    static std::string GetProviderSecret(const SupportedProviders &provider);
+
     static nlohmann::json OpenAICallComplete(const std::string &prompt, const ModelDetails &model_details,
                                              const bool json_response);
 
@@ -59,11 +63,11 @@ struct ModelManager {
     static nlohmann::json AwsBedrockCallEmbedding(const std::vector<std::string> &inputs,
                                                   const ModelDetails &model_details);
 
-    static std::pair<bool, nlohmann::json>
-    CallCompleteProvider(const std::string &prompt, const ModelDetails &model_details, const bool json_response);
+    static std::pair<bool, nlohmann::json> CallCompleteProvider(const std::string &prompt, ModelDetails &model_details,
+                                                                const bool json_response);
 
     static std::pair<bool, nlohmann::json> CallEmbeddingProvider(const std::vector<std::string> &inputs,
-                                                                 const ModelDetails &model_details);
+                                                                 ModelDetails &model_details);
 };
 
 } // namespace core

diff --git a/src/include/flockmtl/core/model_manager/supported_providers.hpp b/src/include/flockmtl/core/model_manager/supported_providers.hpp
index 1dbd0c2c..4a0a22d0 100644
--- a/src/include/flockmtl/core/model_manager/supported_providers.hpp
+++ b/src/include/flockmtl/core/model_manager/supported_providers.hpp
@@ -30,3 +30,18 @@ inline SupportedProviders GetProviderType(std::string provider) {
 
     return FLOCKMTL_UNSUPPORTED_PROVIDER;
 }
+
+inline std::string GetProviderName(SupportedProviders provider) {
+    switch (provider) {
+    case FLOCKMTL_OPENAI:
+        return OPENAI;
+    case FLOCKMTL_AZURE:
+        return AZURE;
+    case FLOCKMTL_OLLAMA:
+        return OLLAMA;
+    case FLOCKMTL_AWS_BEDROCK:
+        return BEDROCK;
+    default:
+        return "";
+    }
+}
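The new GetProviderName is the inverse of the existing GetProviderType, and GetProviderSecret depends on it to build the WHERE clause of its lookup; an unsupported provider maps to an empty string, which GetProviderSecret turns into a "Provider not found" error. A self-contained sketch of the round trip, assuming (not confirmed by the patch) that the OPENAI and AZURE constants hold the lowercase provider strings:

#include <iostream>
#include <string>

enum SupportedProviders { FLOCKMTL_OPENAI, FLOCKMTL_AZURE, FLOCKMTL_UNSUPPORTED_PROVIDER };

// Assumed values of the header's provider-name constants.
const std::string OPENAI = "openai";
const std::string AZURE = "azure";

SupportedProviders GetProviderType(const std::string &provider) {
    if (provider == OPENAI) return FLOCKMTL_OPENAI;
    if (provider == AZURE) return FLOCKMTL_AZURE;
    return FLOCKMTL_UNSUPPORTED_PROVIDER;
}

std::string GetProviderName(SupportedProviders provider) {
    switch (provider) {
    case FLOCKMTL_OPENAI: return OPENAI;
    case FLOCKMTL_AZURE: return AZURE;
    default: return ""; // unsupported providers yield an empty name
    }
}

int main() {
    // Any non-empty name is exactly what the secret lookup matches against
    // in FLOCKMTL_SECRET_INTERNAL_TABLE's provider column.
    std::cout << GetProviderName(GetProviderType("azure")) << "\n"; // prints "azure"
}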