Skip to content

Commit

Permalink
Enhanced String Manipulation with duckdb_fmt::format (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
dorbanianas authored Nov 25, 2024
1 parent 745aa84 commit fa4adbc
Show file tree
Hide file tree
Showing 17 changed files with 296 additions and 259 deletions.
2 changes: 1 addition & 1 deletion .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ AllowShortFunctionsOnASingleLine: All
CompactNamespaces: false
---
DerivePointerAlignment: false
PointerAlignment: Right
PointerAlignment: Left
AlignConsecutiveMacros: true
AlignTrailingComments: true
AllowAllArgumentsOnNextLine: true
Expand Down
172 changes: 83 additions & 89 deletions src/core/config/config.cpp
Original file line number Diff line number Diff line change
@@ -1,31 +1,19 @@
#include "flockmtl/core/config/config.hpp"

#include <iostream>

namespace flockmtl {
namespace core {

// Name of the schema holding the extension's internal tables; Configure() reads
// this value as its schema name.
std::string Config::get_schema_name() {
    return std::string("flockmtl_config");
}

// Name of the table that setup_default_models_config() creates and seeds.
std::string Config::get_default_models_table_name() {
    return std::string("FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE");
}

// Name of the table that setup_user_defined_models_config() creates.
std::string Config::get_user_defined_models_table_name() {
    return std::string("FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE");
}

// Name of the internal prompts table.
std::string Config::get_prompts_table_name() {
    return std::string("FLOCKMTL_PROMPT_INTERNAL_TABLE");
}

// Name of the table that ConfigSecretTable() creates.
std::string Config::get_secrets_table_name() {
    return std::string("FLOCKMTL_SECRET_INTERNAL_TABLE");
}

void Config::Configure(duckdb::DatabaseInstance &db) {
void Config::Configure(duckdb::DatabaseInstance& db) {
std::string schema = Config::get_schema_name();
duckdb::Connection con(db);
con.BeginTransaction();
Expand All @@ -38,107 +26,113 @@ void Config::Configure(duckdb::DatabaseInstance &db) {
con.Commit();
}

// Creates the schema `schema_name` when information_schema.schemata has no row
// for it; does nothing otherwise.
//
// con          connection both statements are executed on.
// schema_name  schema to look up and, if missing, create.
void Config::ConfigSchema(duckdb::Connection& con, std::string& schema_name) {
    const auto lookup_sql = duckdb_fmt::format(" SELECT * "
                                               " FROM information_schema.schemata "
                                               " WHERE schema_name = '{}'; ",
                                               schema_name);
    auto existing = con.Query(lookup_sql);
    if (existing->RowCount() != 0) {
        // A row came back, so the schema is already there.
        return;
    }

    const auto create_sql = duckdb_fmt::format("CREATE SCHEMA {};", schema_name);
    con.Query(create_sql);
}

// Creates the default-models table inside `schema_name` and seeds it with the
// built-in model rows, unless information_schema.tables already lists the table.
//
// con          connection every statement is executed on.
// schema_name  schema the table lives in.
//
// The table has four columns: model_name (primary key), model, provider_name and
// model_args (JSON). The two embedding rows take their context window and output
// token limits from Config::default_context_window and
// Config::default_max_output_tokens; the chat-model rows use literal values.
void Config::setup_default_models_config(duckdb::Connection& con, std::string& schema_name) {
    const std::string table_name = Config::get_default_models_table_name();

    // Check whether the table already exists; an empty result means it has to be created.
    auto result = con.Query(duckdb_fmt::format(" SELECT table_name "
                                               " FROM information_schema.tables "
                                               " WHERE table_schema = '{}' "
                                               " AND table_name = '{}'; ",
                                               schema_name, table_name));
    if (result->RowCount() == 0) {
        // LOAD JSON runs in the same call, ahead of the CREATE TABLE that uses the JSON type.
        con.Query(duckdb_fmt::format(" LOAD JSON; "
                                     " CREATE TABLE {}.{} ( "
                                     " model_name VARCHAR NOT NULL PRIMARY KEY, "
                                     " model VARCHAR NOT NULL, "
                                     " provider_name VARCHAR NOT NULL, "
                                     " model_args JSON NOT NULL "
                                     " ); ",
                                     schema_name, table_name));

        // Literal JSON braces are doubled ({{ and }}) so the formatter emits them verbatim
        // instead of treating them as placeholders. Every VALUES tuple must be closed with
        // ')' before the separating comma, and the statement ends with ';'.
        con.Query(duckdb_fmt::format(
            " INSERT INTO {}.{} (model_name, model, provider_name, model_args) "
            " VALUES "
            " ('default', 'gpt-4o-mini', 'openai', '{{\"context_window\": 128000, \"max_output_tokens\": 16384}}'),"
            " ('gpt-4o-mini', 'gpt-4o-mini', 'openai', '{{\"context_window\": 128000, \"max_output_tokens\": 16384}}'),"
            " ('gpt-4o', 'gpt-4o', 'openai', '{{\"context_window\": 128000, \"max_output_tokens\": 16384}}'),"
            " ('text-embedding-3-large', 'text-embedding-3-large', 'openai', "
            " '{{\"context_window\": {}, \"max_output_tokens\": {}}}'),"
            " ('text-embedding-3-small', 'text-embedding-3-small', 'openai', "
            " '{{\"context_window\": {}, \"max_output_tokens\": {}}}');",
            schema_name, table_name, Config::default_context_window, Config::default_max_output_tokens,
            Config::default_context_window, Config::default_max_output_tokens));
    }
}

// Creates the user-defined-models table inside `schema_name` unless
// information_schema.tables already lists it. No rows are inserted.
//
// con          connection every statement is executed on.
// schema_name  schema the table lives in.
void Config::setup_user_defined_models_config(duckdb::Connection& con, std::string& schema_name) {
    const std::string table_name = Config::get_user_defined_models_table_name();

    const auto lookup_sql = duckdb_fmt::format(" SELECT table_name "
                                               " FROM information_schema.tables "
                                               " WHERE table_schema = '{}' "
                                               " AND table_name = '{}'; ",
                                               schema_name, table_name);
    auto existing = con.Query(lookup_sql);
    if (existing->RowCount() != 0) {
        // Table is already present; nothing to create.
        return;
    }

    // LOAD JSON runs in the same call, ahead of the CREATE TABLE that uses the JSON type.
    const auto create_sql = duckdb_fmt::format(" LOAD JSON; "
                                               " CREATE TABLE {}.{} ( "
                                               " model_name VARCHAR NOT NULL PRIMARY KEY, "
                                               " model VARCHAR NOT NULL, "
                                               " provider_name VARCHAR NOT NULL, "
                                               " model_args JSON NOT NULL"
                                               " ); ",
                                               schema_name, table_name);
    con.Query(create_sql);
}

// Sets up both model tables in `schema_name`: the default-models table first,
// then the user-defined-models table.
void Config::ConfigModelTable(duckdb::Connection& con, std::string& schema_name) {
    Config::setup_default_models_config(con, schema_name);
    Config::setup_user_defined_models_config(con, schema_name);
}

// Creates the prompts table inside `schema_name` and inserts the 'hello-world'
// sample prompt, unless information_schema.tables already lists the table.
//
// con          connection every statement is executed on.
// schema_name  schema the table lives in.
//
// Rows are keyed by (prompt_name, version); version defaults to 1 and update_at
// defaults to CURRENT_TIMESTAMP.
void Config::ConfigPromptTable(duckdb::Connection& con, std::string& schema_name) {
    // Use the shared getter rather than repeating the literal, matching the other setup functions.
    const std::string table_name = Config::get_prompts_table_name();

    auto result = con.Query(duckdb_fmt::format(" SELECT table_name "
                                               " FROM information_schema.tables "
                                               " WHERE table_schema = '{}' "
                                               " AND table_name = '{}'; ",
                                               schema_name, table_name));
    if (result->RowCount() == 0) {
        con.Query(duckdb_fmt::format(" CREATE TABLE {}.{} ( "
                                     " prompt_name VARCHAR NOT NULL, "
                                     " prompt VARCHAR NOT NULL, "
                                     " update_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, "
                                     " version INT DEFAULT 1, "
                                     " PRIMARY KEY (prompt_name, version) "
                                     " ); ",
                                     schema_name, table_name));
        con.Query(duckdb_fmt::format(" INSERT INTO {}.{} (prompt_name, prompt) "
                                     " VALUES ('hello-world', 'Tell me hello world'); ",
                                     schema_name, table_name));
    }
}

// Creates the secrets table (provider as primary key, plus secret) inside
// `schema_name` unless information_schema.tables already lists it.
//
// con          connection every statement is executed on.
// schema_name  schema the table lives in.
void Config::ConfigSecretTable(duckdb::Connection& con, std::string& schema_name) {
    const std::string table_name = get_secrets_table_name();

    const auto lookup_sql = duckdb_fmt::format(" SELECT table_name "
                                               " FROM information_schema.tables "
                                               " WHERE table_schema = '{}' "
                                               " AND table_name = '{}'; ",
                                               schema_name, table_name);
    auto existing = con.Query(lookup_sql);
    if (existing->RowCount() != 0) {
        // Table is already present; nothing to create.
        return;
    }

    const auto create_sql = duckdb_fmt::format(" CREATE TABLE {}.{} ( "
                                               " provider VARCHAR NOT NULL PRIMARY KEY, "
                                               " secret VARCHAR NOT NULL "
                                               " ); ",
                                               schema_name, table_name);
    con.Query(create_sql);
}

Expand Down
72 changes: 40 additions & 32 deletions src/custom_parser/query/model_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

namespace flockmtl {

void ModelParser::Parse(const std::string &query, std::unique_ptr<QueryStatement> &statement) {
void ModelParser::Parse(const std::string& query, std::unique_ptr<QueryStatement>& statement) {
Tokenizer tokenizer(query);
Token token = tokenizer.NextToken();
std::string value = StringUtil::Upper(token.value);
Expand All @@ -29,7 +29,7 @@ void ModelParser::Parse(const std::string &query, std::unique_ptr<QueryStatement
}
}

void ModelParser::ParseCreateModel(Tokenizer &tokenizer, std::unique_ptr<QueryStatement> &statement) {
void ModelParser::ParseCreateModel(Tokenizer& tokenizer, std::unique_ptr<QueryStatement>& statement) {
Token token = tokenizer.NextToken();
std::string value = StringUtil::Upper(token.value);
if (token.type != TokenType::KEYWORD || value != "MODEL") {
Expand Down Expand Up @@ -106,7 +106,7 @@ void ModelParser::ParseCreateModel(Tokenizer &tokenizer, std::unique_ptr<QuerySt
}
}

void ModelParser::ParseDeleteModel(Tokenizer &tokenizer, std::unique_ptr<QueryStatement> &statement) {
void ModelParser::ParseDeleteModel(Tokenizer& tokenizer, std::unique_ptr<QueryStatement>& statement) {
Token token = tokenizer.NextToken();
std::string value = StringUtil::Upper(token.value);
if (token.type != TokenType::KEYWORD || value != "MODEL") {
Expand All @@ -129,7 +129,7 @@ void ModelParser::ParseDeleteModel(Tokenizer &tokenizer, std::unique_ptr<QuerySt
}
}

void ModelParser::ParseUpdateModel(Tokenizer &tokenizer, std::unique_ptr<QueryStatement> &statement) {
void ModelParser::ParseUpdateModel(Tokenizer& tokenizer, std::unique_ptr<QueryStatement>& statement) {
Token token = tokenizer.NextToken();
std::string value = StringUtil::Upper(token.value);
if (token.type != TokenType::KEYWORD || value != "MODEL") {
Expand Down Expand Up @@ -206,7 +206,7 @@ void ModelParser::ParseUpdateModel(Tokenizer &tokenizer, std::unique_ptr<QuerySt
}
}

void ModelParser::ParseGetModel(Tokenizer &tokenizer, std::unique_ptr<QueryStatement> &statement) {
void ModelParser::ParseGetModel(Tokenizer& tokenizer, std::unique_ptr<QueryStatement>& statement) {
Token token = tokenizer.NextToken();
std::string value = StringUtil::Upper(token.value);
if (token.type != TokenType::KEYWORD || (value != "MODEL" && value != "MODELS")) {
Expand Down Expand Up @@ -234,55 +234,63 @@ void ModelParser::ParseGetModel(Tokenizer &tokenizer, std::unique_ptr<QueryState
}
}

std::string ModelParser::ToSQL(const QueryStatement &statement) const {
std::ostringstream sql;
std::string ModelParser::ToSQL(const QueryStatement& statement) const {
std::string query;

switch (statement.type) {
case StatementType::CREATE_MODEL: {
const auto &create_stmt = static_cast<const CreateModelStatement &>(statement);
sql << "INSERT INTO flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE(model_name, model, "
"provider_name, model_args) VALUES ('"
<< create_stmt.model_name << "', '" << create_stmt.model << "', '" << create_stmt.provider_name << "', '"
<< create_stmt.model_args << "');";
const auto& create_stmt = static_cast<const CreateModelStatement&>(statement);
query = duckdb_fmt::format(" INSERT INTO "
" flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE "
" (model_name, model, provider_name, model_args) "
" VALUES ('{}', '{}', '{}', '{}');",
create_stmt.model_name, create_stmt.model, create_stmt.provider_name,
create_stmt.model_args.dump());
break;
}
case StatementType::DELETE_MODEL: {
const auto &delete_stmt = static_cast<const DeleteModelStatement &>(statement);
sql << "DELETE FROM flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE WHERE model_name = '"
<< delete_stmt.model_name << "';";
const auto& delete_stmt = static_cast<const DeleteModelStatement&>(statement);
query = duckdb_fmt::format(" DELETE FROM "
" flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE "
" WHERE model_name = '{}';",
delete_stmt.model_name);
break;
}
case StatementType::UPDATE_MODEL: {
const auto &update_stmt = static_cast<const UpdateModelStatement &>(statement);
sql << "UPDATE flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE SET "
<< "model = '" << update_stmt.new_model << "', "
<< "provider_name = '" << update_stmt.provider_name << "', "
<< "model_args = '" << update_stmt.new_model_args << "', "
<< "WHERE model_name = '" << update_stmt.model_name << "';";
const auto& update_stmt = static_cast<const UpdateModelStatement&>(statement);
query = duckdb_fmt::format(" UPDATE flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE "
" SET model = '{}', provider_name = '{}', "
" model_args = '{}' WHERE model_name = '{}'; ",
update_stmt.new_model, update_stmt.provider_name, update_stmt.new_model_args.dump(),
update_stmt.model_name);
break;
}
case StatementType::GET_MODEL: {
const auto &get_stmt = static_cast<const GetModelStatement &>(statement);
sql << "SELECT * FROM flockmtl_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE WHERE model_name = '"
<< get_stmt.model_name << "'"
<< " UNION ALL "
<< "SELECT * FROM flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE WHERE model_name = '"
<< get_stmt.model_name << "'"
<< ";";
const auto& get_stmt = static_cast<const GetModelStatement&>(statement);
query = duckdb_fmt::format(" SELECT * "
" FROM flockmtl_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE "
" WHERE model_name = '{}' "
" UNION ALL "
" SELECT * "
" FROM flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE "
" WHERE model_name = '{}';",
get_stmt.model_name, get_stmt.model_name);
break;
}

case StatementType::GET_ALL_MODEL: {
sql << "SELECT * FROM flockmtl_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE "
<< "UNION ALL "
<< "SELECT * FROM flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE;";
query = duckdb_fmt::format(" SELECT * "
" FROM flockmtl_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE "
" UNION ALL "
" SELECT * "
" FROM flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE; ");
break;
}
default:
throw std::runtime_error("Unknown statement type.");
}

return sql.str();
return query;
}

} // namespace flockmtl
Loading

0 comments on commit fa4adbc

Please sign in to comment.