Skip to content

Commit

Permalink
access to secrets for azure and openai (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
dorbanianas authored Nov 19, 2024
1 parent aaac286 commit 1f019df
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/core/functions/batch_response_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ PromptDetails CreatePromptDetails(Connection &con, const nlohmann::json prompt_d
}

nlohmann::json Complete(const nlohmann::json &tuples, const std::string &user_prompt, ScalarFunctionType function_type,
const ModelDetails &model_details) {
ModelDetails &model_details) {
nlohmann::json data;
auto tuples_markdown = ConstructMarkdownArrayTuples(tuples);
auto prompt = ScalarPromptTemplate::GetPrompt(user_prompt, tuples_markdown, function_type);
Expand All @@ -98,7 +98,7 @@ nlohmann::json Complete(const nlohmann::json &tuples, const std::string &user_pr
};

nlohmann::json BatchAndComplete(std::vector<nlohmann::json> &tuples, Connection &con, std::string user_prompt,
ScalarFunctionType function_type, const ModelDetails &model_details) {
ScalarFunctionType function_type, ModelDetails &model_details) {
auto llm_template = ScalarPromptTemplate::GetPromptTemplate(function_type);

int num_tokens_meta_and_user_pormpt = 0;
Expand Down
67 changes: 51 additions & 16 deletions src/core/model_manager/model_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <string>
#include <stdexcept>
#include <nlohmann/json.hpp>
#include <flockmtl/core/module.hpp>

namespace flockmtl {
namespace core {
Expand Down Expand Up @@ -134,9 +135,12 @@ nlohmann::json ModelManager::AwsBedrockCallComplete(const std::string &prompt, c
nlohmann::json ModelManager::OpenAICallComplete(const std::string &prompt, const ModelDetails &model_details,
const bool json_response) {

// Get API key from the environment variable
auto key = openai::OpenAI::get_openai_api_key();
openai::start(key);
auto api_key = model_details.secret;
if (api_key.empty()) {
api_key = openai::OpenAI::get_openai_api_key();
}

openai::start(api_key);

// Create a JSON request payload with the provided parameters
nlohmann::json request_payload = {{"model", model_details.model},
Expand Down Expand Up @@ -186,8 +190,11 @@ nlohmann::json ModelManager::OpenAICallComplete(const std::string &prompt, const

nlohmann::json ModelManager::AzureCallComplete(const std::string &prompt, const ModelDetails &model_details,
const bool json_response) {
// Get API key from the environment variable
auto api_key = AzureModelManager::get_azure_api_key();
auto api_key = model_details.secret;
if (api_key.empty()) {
api_key = AzureModelManager::get_azure_api_key();
}

auto resource_name = AzureModelManager::get_azure_resource_name();
auto api_version = AzureModelManager::get_azure_api_version();

Expand Down Expand Up @@ -236,7 +243,7 @@ nlohmann::json ModelManager::AzureCallComplete(const std::string &prompt, const
return content_str;
}

nlohmann::json ModelManager::CallComplete(const std::string &prompt, const ModelDetails &model_details,
nlohmann::json ModelManager::CallComplete(const std::string &prompt, ModelDetails &model_details,
const bool json_response) {

// Check if the provider is in the list of supported provider
Expand Down Expand Up @@ -290,9 +297,12 @@ nlohmann::json ModelManager::AwsBedrockCallEmbedding(const std::vector<string> &
}

nlohmann::json ModelManager::OpenAICallEmbedding(const std::vector<string> &inputs, const ModelDetails &model_details) {
// Get API key from the environment variable
auto key = openai::OpenAI::get_openai_api_key();
openai::start(key);
auto api_key = model_details.secret;
if (api_key.empty()) {
api_key = openai::OpenAI::get_openai_api_key();
}

openai::start(api_key);

// Create a JSON request payload with the provided parameters
nlohmann::json request_payload = {
Expand All @@ -318,8 +328,10 @@ nlohmann::json ModelManager::OpenAICallEmbedding(const std::vector<string> &inpu
}

nlohmann::json ModelManager::AzureCallEmbedding(const std::vector<string> &inputs, const ModelDetails &model_details) {
// Get API key from the environment variable
auto api_key = AzureModelManager::get_azure_api_key();
auto api_key = model_details.secret;
if (api_key.empty()) {
api_key = AzureModelManager::get_azure_api_key();
}
auto resource_name = AzureModelManager::get_azure_resource_name();
auto api_version = AzureModelManager::get_azure_api_version();

Expand Down Expand Up @@ -350,7 +362,7 @@ nlohmann::json ModelManager::AzureCallEmbedding(const std::vector<string> &input
return embeddings;
}

nlohmann::json ModelManager::CallEmbedding(const std::vector<string> &inputs, const ModelDetails &model_details) {
nlohmann::json ModelManager::CallEmbedding(const std::vector<string> &inputs, ModelDetails &model_details) {

// Check if the provider is in the list of supported provider
auto provider = GetProviderType(model_details.provider_name);
Expand All @@ -375,10 +387,32 @@ nlohmann::json ModelManager::CallEmbedding(const std::vector<string> &inputs, co
return result.second;
}

std::pair<bool, nlohmann::json> ModelManager::CallCompleteProvider(const std::string &prompt,
const ModelDetails &model_details,
const bool json_response) {
// Fetch the stored secret for a provider from the internal config table.
// Returns an empty string when no secret row exists; callers visible in this
// file then fall back to environment-variable lookups (e.g. the OpenAI/Azure
// key getters).
std::string ModelManager::GetProviderSecret(const SupportedProviders &provider) {

    // Resolve the enum to its canonical provider name; empty means unsupported.
    const auto provider_name = GetProviderName(provider);
    if (provider_name.empty()) {
        throw std::runtime_error("Provider not found");
    }

    const std::string query = "SELECT secret FROM "
                              "flockmtl_config.FLOCKMTL_SECRET_INTERNAL_TABLE "
                              "WHERE provider = '" +
                              provider_name + "'";

    auto result = CoreModule::GetConnection().Query(query);

    // No row stored for this provider -> no secret configured.
    if (result->RowCount() == 0) {
        return std::string{};
    }
    return result->GetValue(0, 0).ToString();
}

std::pair<bool, nlohmann::json>
ModelManager::CallCompleteProvider(const std::string &prompt, ModelDetails &model_details, const bool json_response) {
auto provider = GetProviderType(model_details.provider_name);
model_details.secret = GetProviderSecret(provider);
switch (provider) {
case FLOCKMTL_OPENAI:
return {true, OpenAICallComplete(prompt, model_details, json_response)};
Expand All @@ -394,8 +428,9 @@ std::pair<bool, nlohmann::json> ModelManager::CallCompleteProvider(const std::st
}

std::pair<bool, nlohmann::json> ModelManager::CallEmbeddingProvider(const std::vector<std::string> &inputs,
const ModelDetails &model_details) {
ModelDetails &model_details) {
auto provider = GetProviderType(model_details.provider_name);
model_details.secret = GetProviderSecret(provider);
switch (provider) {
case FLOCKMTL_OPENAI:
return {true, OpenAICallEmbedding(inputs, model_details)};
Expand Down
11 changes: 7 additions & 4 deletions src/core/parser/query/secret_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ void SecretParser::ParseCreateSecret(Tokenizer &tokenizer, std::unique_ptr<Query
if (token.type != TokenType::KEYWORD || (value != "OPENAI" && value != "AZURE")) {
throw std::runtime_error("Expected 'OPENAI' keyword after 'SECRET'.");
}
auto provider = value;

auto provider = StringUtil::Lower(value);

token = tokenizer.NextToken();
if (token.type != TokenType::SYMBOL || token.value != "=") {
Expand Down Expand Up @@ -80,7 +81,8 @@ void SecretParser::ParseDeleteSecret(Tokenizer &tokenizer, std::unique_ptr<Query
if (token.type != TokenType::KEYWORD || (value != "OPENAI" && value != "AZURE")) {
throw std::runtime_error("Expected 'OPENAI' keyword after 'SECRET'.");
}
auto provider = value;

auto provider = StringUtil::Lower(value);

token = tokenizer.NextToken();
if (token.type == TokenType::SYMBOL || token.value == ";") {
Expand All @@ -103,7 +105,8 @@ void SecretParser::ParseUpdateSecret(Tokenizer &tokenizer, std::unique_ptr<Query
if (token.type != TokenType::KEYWORD || (value != "OPENAI" && value != "AZURE")) {
throw std::runtime_error("Expected 'OPENAI' keyword after 'SECRET'.");
}
auto provider = value;

auto provider = StringUtil::Lower(value);

token = tokenizer.NextToken();
if (token.type != TokenType::SYMBOL || token.value != "=") {
Expand Down Expand Up @@ -143,7 +146,7 @@ void SecretParser::ParseGetSecret(Tokenizer &tokenizer, std::unique_ptr<QuerySta
if (token.type != TokenType::KEYWORD || (value != "OPENAI" && value != "AZURE")) {
throw std::runtime_error("Expected 'OPENAI' or 'AZURE' keyword after 'SECRET'.");
}
auto provider = value;
auto provider = StringUtil::Lower(value);

token = tokenizer.NextToken();
if (token.type == TokenType::SYMBOL || token.value == ";") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ std::string ConstructMarkdownArrayTuples(const nlohmann::json &tuples);
PromptDetails CreatePromptDetails(Connection &con, const nlohmann::json prompt_details_json);

nlohmann::json Complete(const nlohmann::json &tuples, const std::string &user_prompt, ScalarFunctionType function_type,
const ModelDetails &model_details);
ModelDetails &model_details);

nlohmann::json BatchAndComplete(std::vector<nlohmann::json> &tuples, Connection &con, std::string user_prompt_name,
ScalarFunctionType function_type, const ModelDetails &model_details);
ScalarFunctionType function_type, ModelDetails &model_details);

} // namespace core
} // namespace flockmtl
14 changes: 9 additions & 5 deletions src/include/flockmtl/core/model_manager/model_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <tuple>
#include <vector>
#include <string>
#include <flockmtl/core/model_manager/supported_providers.hpp>

namespace flockmtl {
namespace core {
Expand All @@ -24,21 +25,24 @@ struct ModelDetails {
int context_window;
int max_output_tokens;
float temperature;
std::string secret;
};

struct ModelManager {
public:
static ModelDetails CreateModelDetails(Connection &con, const nlohmann::json &model_json);

static nlohmann::json CallComplete(const std::string &prompt, const ModelDetails &model_details,
static nlohmann::json CallComplete(const std::string &prompt, ModelDetails &model_details,
const bool json_response = true);

static nlohmann::json CallEmbedding(const std::vector<string> &inputs, const ModelDetails &model_details);
static nlohmann::json CallEmbedding(const std::vector<string> &inputs, ModelDetails &model_details);

private:
static std::tuple<std::string, int32_t, int32_t> GetQueriedModel(Connection &con, const std::string &model_name,
const std::string &provider_name);

static std::string GetProviderSecret(const SupportedProviders &provider);

static nlohmann::json OpenAICallComplete(const std::string &prompt, const ModelDetails &model_details,
const bool json_response);

Expand All @@ -59,11 +63,11 @@ struct ModelManager {

static nlohmann::json AwsBedrockCallEmbedding(const std::vector<string> &inputs, const ModelDetails &model_details);

static std::pair<bool, nlohmann::json>
CallCompleteProvider(const std::string &prompt, const ModelDetails &model_details, const bool json_response);
static std::pair<bool, nlohmann::json> CallCompleteProvider(const std::string &prompt, ModelDetails &model_details,
const bool json_response);

static std::pair<bool, nlohmann::json> CallEmbeddingProvider(const std::vector<string> &inputs,
const ModelDetails &model_details);
ModelDetails &model_details);
};

} // namespace core
Expand Down
15 changes: 15 additions & 0 deletions src/include/flockmtl/core/model_manager/supported_providers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,18 @@ inline SupportedProviders GetProviderType(std::string provider) {

return FLOCKMTL_UNSUPPORTED_PROVIDER;
}

// Inverse of GetProviderType: map a SupportedProviders value back to its
// canonical provider-name constant. Unknown/unsupported providers yield "".
inline std::string GetProviderName(SupportedProviders provider) {
    if (provider == FLOCKMTL_OPENAI) {
        return OPENAI;
    }
    if (provider == FLOCKMTL_AZURE) {
        return AZURE;
    }
    if (provider == FLOCKMTL_OLLAMA) {
        return OLLAMA;
    }
    if (provider == FLOCKMTL_AWS_BEDROCK) {
        return BEDROCK;
    }
    return "";
}

0 comments on commit 1f019df

Please sign in to comment.