Skip to content

Commit

Permalink
update model table to have a single model_args column in JSON format
Browse files Browse the repository at this point in the history
  • Loading branch information
dorbanianas committed Nov 12, 2024
1 parent 93903cb commit a200ef3
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 65 deletions.
51 changes: 30 additions & 21 deletions src/core/config/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,35 @@ void Config::setup_default_models_config(duckdb::Connection &con, std::string &s
auto result = con.Query("SELECT table_name FROM information_schema.tables WHERE table_schema = '" + schema_name +
"' AND table_name = '" + table_name + "';");
if (result->RowCount() == 0) {
con.Query("CREATE TABLE " + schema_name + "." + table_name +
con.Query("LOAD JSON;"
"CREATE TABLE " +
schema_name + "." + table_name +
" ("
"model_name VARCHAR NOT NULL PRIMARY KEY,"
"model VARCHAR NOT NULL,"
"provider_name VARCHAR NOT NULL,"
"context_window INTEGER NOT NULL,"
"max_output_tokens INTEGER NOT NULL"
"model_name VARCHAR NOT NULL PRIMARY KEY, "
"model VARCHAR NOT NULL, "
"provider_name VARCHAR NOT NULL, "
"model_args JSON NOT NULL"
");");

con.Query("INSERT INTO " + schema_name + "." + table_name +
" (model_name, model, provider_name, context_window, max_output_tokens) VALUES "
"('default', 'gpt-4o-mini', 'openai', 128000, 16384),"
"('gpt-4o-mini', 'gpt-4o-mini', 'openai', 128000, 16384),"
"('gpt-4o', 'gpt-4o', 'openai', 128000, 16384),"
"('text-embedding-3-large', 'text-embedding-3-large', 'openai', " +
std::to_string(Config::default_context_window) + ", " +
std::to_string(Config::default_max_output_tokens) +
"),"
"('text-embedding-3-small', 'text-embedding-3-small', 'openai', " +
std::to_string(Config::default_context_window) + ", " +
std::to_string(Config::default_max_output_tokens) + ");");
con.Query(
"INSERT INTO " + schema_name + "." + table_name +
" (model_name, model, provider_name, model_args) VALUES "
"('default', 'gpt-4o-mini', 'openai', '{\"context_window\": 128000, \"max_output_tokens\": 16384}'),"
"('gpt-4o-mini', 'gpt-4o-mini', 'openai', '{\"context_window\": 128000, \"max_output_tokens\": 16384}'),"
"('gpt-4o', 'gpt-4o', 'openai', '{\"context_window\": 128000, \"max_output_tokens\": 16384}'),"
"('text-embedding-3-large', 'text-embedding-3-large', 'openai', "
"'{\"context_window\": " +
std::to_string(Config::default_context_window) +
", "
"\"max_output_tokens\": " +
std::to_string(Config::default_max_output_tokens) +
"}'),"
"('text-embedding-3-small', 'text-embedding-3-small', 'openai', "
"'{\"context_window\": " +
std::to_string(Config::default_context_window) +
", "
"\"max_output_tokens\": " +
std::to_string(Config::default_max_output_tokens) + "}');");
}
}

Expand All @@ -78,13 +86,14 @@ void Config::setup_user_defined_models_config(duckdb::Connection &con, std::stri
auto result = con.Query("SELECT table_name FROM information_schema.tables WHERE table_schema = '" + schema_name +
"' AND table_name = '" + table_name + "';");
if (result->RowCount() == 0) {
con.Query("CREATE TABLE " + schema_name + "." + table_name +
con.Query("LOAD JSON;"
"CREATE TABLE " +
schema_name + "." + table_name +
" ("
"model_name VARCHAR NOT NULL PRIMARY KEY,"
"model VARCHAR,"
"provider_name VARCHAR NOT NULL,"
"context_window INTEGER NOT NULL,"
"max_output_tokens INTEGER NOT NULL"
"model_args JSON NOT NULL"
");");
}
}
Expand Down
11 changes: 7 additions & 4 deletions src/core/model_manager/model_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <memory>
#include <string>
#include <stdexcept>
#include <nlohmann/json.hpp>

namespace flockmtl {
namespace core {
Expand Down Expand Up @@ -56,7 +57,7 @@ std::tuple<std::string, int32_t, int32_t> ModelManager::GetQueriedModel(Connecti
return {model_name, Config::default_context_window, Config::default_max_output_tokens};
}

std::string query = "SELECT model, context_window, max_output_tokens FROM "
std::string query = "SELECT model, model_args FROM "
"flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE "
"WHERE model_name = '" +
model_name + "'";
Expand All @@ -67,7 +68,7 @@ std::tuple<std::string, int32_t, int32_t> ModelManager::GetQueriedModel(Connecti
auto query_result = con.Query(query);

if (query_result->RowCount() == 0) {
query_result = con.Query("SELECT model, context_window, max_output_tokens FROM "
query_result = con.Query("SELECT model, model_args FROM "
"flockmtl_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE WHERE model_name = '" +
model_name + "'");

Expand All @@ -76,8 +77,10 @@ std::tuple<std::string, int32_t, int32_t> ModelManager::GetQueriedModel(Connecti
}
}

return {query_result->GetValue(0, 0).ToString(), query_result->GetValue(1, 0).GetValue<int32_t>(),
query_result->GetValue(2, 0).GetValue<int32_t>()};
auto model = query_result->GetValue(0, 0).ToString();
auto model_args = nlohmann::json::parse(query_result->GetValue(1, 0).ToString());

return {query_result->GetValue(0, 0).ToString(), model_args["context_window"], model_args["max_output_tokens"]};
}

nlohmann::json ModelManager::OllamaCallComplete(const std::string &prompt, const ModelDetails &model_details,
Expand Down
55 changes: 23 additions & 32 deletions src/core/parser/query/model_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,18 @@ void ModelParser::ParseCreateModel(Tokenizer &tokenizer, std::unique_ptr<QuerySt
}

token = tokenizer.NextToken();
if (token.type != TokenType::NUMBER || token.value.empty()) {
throw std::runtime_error("Expected integer value for context_window.");
if (token.type != TokenType::JSON || token.value.empty()) {
throw std::runtime_error("Expected json value for the model_args.");
}
int context_window = std::stoi(token.value);

token = tokenizer.NextToken();
if (token.type != TokenType::SYMBOL || token.value != ",") {
throw std::runtime_error("Expected comma ',' after context_window.");
auto model_args = nlohmann::json::parse(token.value);
const std::set<std::string> expected_keys = {"context_window", "max_output_tokens"};
std::set<std::string> json_keys;
for (auto it = model_args.begin(); it != model_args.end(); ++it) {
json_keys.insert(it.key());
}

token = tokenizer.NextToken();
if (token.type != TokenType::NUMBER || token.value.empty()) {
throw std::runtime_error("Expected integer value for max_output_tokens.");
if (json_keys != expected_keys) {
throw std::runtime_error("Expected keys: context_window, max_output_tokens in model_args.");
}
int max_output_tokens = std::stoi(token.value);

token = tokenizer.NextToken();
if (token.type != TokenType::PARENTHESIS || token.value != ")") {
Expand All @@ -103,8 +100,7 @@ void ModelParser::ParseCreateModel(Tokenizer &tokenizer, std::unique_ptr<QuerySt
create_statement->model_name = model_name;
create_statement->model = model;
create_statement->provider_name = provider_name;
create_statement->context_window = context_window;
create_statement->max_output_tokens = max_output_tokens;
create_statement->model_args = model_args;
statement = std::move(create_statement);
} else {
throw std::runtime_error("Unexpected characters after the closing parenthesis. Only a semicolon is allowed.");
Expand Down Expand Up @@ -180,21 +176,18 @@ void ModelParser::ParseUpdateModel(Tokenizer &tokenizer, std::unique_ptr<QuerySt
}

token = tokenizer.NextToken();
if (token.type != TokenType::NUMBER || token.value.empty()) {
throw std::runtime_error("Expected integer value for new context_window.");
if (token.type != TokenType::JSON || token.value.empty()) {
throw std::runtime_error("Expected json value for the model_args.");
}
int new_context_window = std::stoi(token.value);

token = tokenizer.NextToken();
if (token.type != TokenType::SYMBOL || token.value != ",") {
throw std::runtime_error("Expected comma ',' after context_window.");
auto new_model_args = nlohmann::json::parse(token.value);
const std::set<std::string> expected_keys = {"context_window", "max_output_tokens"};
std::set<std::string> json_keys;
for (auto it = new_model_args.begin(); it != new_model_args.end(); ++it) {
json_keys.insert(it.key());
}

token = tokenizer.NextToken();
if (token.type != TokenType::NUMBER || token.value.empty()) {
throw std::runtime_error("Expected integer value for new max_output_tokens.");
if (json_keys != expected_keys) {
throw std::runtime_error("Expected keys: context_window, max_output_tokens in model_args.");
}
int new_max_output_tokens = std::stoi(token.value);

token = tokenizer.NextToken();
if (token.type != TokenType::PARENTHESIS || token.value != ")") {
Expand All @@ -207,8 +200,7 @@ void ModelParser::ParseUpdateModel(Tokenizer &tokenizer, std::unique_ptr<QuerySt
update_statement->new_model = new_model;
update_statement->model_name = model_name;
update_statement->provider_name = provider_name;
update_statement->new_context_window = new_context_window;
update_statement->new_max_output_tokens = new_max_output_tokens;
update_statement->new_model_args = new_model_args;
statement = std::move(update_statement);
} else {
throw std::runtime_error("Unexpected characters after the closing parenthesis. Only a semicolon is allowed.");
Expand Down Expand Up @@ -250,9 +242,9 @@ std::string ModelParser::ToSQL(const QueryStatement &statement) const {
case StatementType::CREATE_MODEL: {
const auto &create_stmt = static_cast<const CreateModelStatement &>(statement);
sql << "INSERT INTO flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE(model_name, model, "
"provider_name, context_window, max_output_tokens) VALUES ('"
"provider_name, model_args) VALUES ('"
<< create_stmt.model_name << "', '" << create_stmt.model << "', '" << create_stmt.provider_name << "', '"
<< create_stmt.context_window << "', '" << create_stmt.max_output_tokens << "');";
<< create_stmt.model_args << "');";
break;
}
case StatementType::DELETE_MODEL: {
Expand All @@ -266,8 +258,7 @@ std::string ModelParser::ToSQL(const QueryStatement &statement) const {
sql << "UPDATE flockmtl_config.FLOCKMTL_MODEL_USER_DEFINED_INTERNAL_TABLE SET "
<< "model = '" << update_stmt.new_model << "', "
<< "provider_name = '" << update_stmt.provider_name << "', "
<< "context_window = " << update_stmt.new_context_window << ", "
<< "max_output_tokens = " << update_stmt.new_max_output_tokens << ", "
<< "model_args = '" << update_stmt.new_model_args << "', "
<< "WHERE model_name = '" << update_stmt.model_name << "';";
break;
}
Expand Down
26 changes: 25 additions & 1 deletion src/core/parser/tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,27 @@ Token Tokenizer::ParseStringLiteral() {
return {TokenType::STRING_LITERAL, value};
}

// Scan a raw JSON object token starting at the current position.
// Consumes characters from the opening '{' through its matching closing '}'
// and returns the whole span (braces included) as a TokenType::JSON token.
// Braces that appear inside double-quoted JSON strings (including escaped
// quotes such as \" ) are ignored for nesting purposes, so inputs like
// {"note": "}"} are tokenized correctly.
// Throws std::runtime_error if the current character is not '{' or if the
// input ends before the object is closed.
Token Tokenizer::ParseJson() {
    if (query_[position_] != '{') {
        throw std::runtime_error("JSON should start with a curly brace.");
    }
    int start = position_++;
    int brace_count = 1;
    bool in_string = false; // currently inside a double-quoted JSON string
    bool escaped = false;   // previous char was a backslash inside a string
    while (position_ < query_.size() && brace_count > 0) {
        char ch = query_[position_];
        if (in_string) {
            if (escaped) {
                escaped = false; // this char is escaped; treat it literally
            } else if (ch == '\\') {
                escaped = true;
            } else if (ch == '"') {
                in_string = false;
            }
        } else if (ch == '"') {
            in_string = true;
        } else if (ch == '{') {
            ++brace_count;
        } else if (ch == '}') {
            --brace_count;
        }
        ++position_;
    }
    if (brace_count > 0) {
        throw std::runtime_error("Unterminated JSON.");
    }
    std::string value = query_.substr(start, position_ - start);
    return {TokenType::JSON, value};
}

// Parse a keyword (word made of letters)
Token Tokenizer::ParseKeyword() {
int start = position_;
Expand Down Expand Up @@ -81,6 +102,8 @@ Token Tokenizer::GetNextToken() {
char ch = query_[position_];
if (ch == '\'') {
return ParseStringLiteral();
} else if (ch == '{') {
return ParseJson();
} else if (std::isalpha(ch)) {
return ParseKeyword();
} else if (ch == ';' || ch == ',') {
Expand All @@ -106,6 +129,8 @@ std::string TokenTypeToString(TokenType type) {
return "KEYWORD";
case TokenType::STRING_LITERAL:
return "STRING_LITERAL";
case TokenType::JSON:
return "JSON";
case TokenType::SYMBOL:
return "SYMBOL";
case TokenType::NUMBER:
Expand All @@ -115,7 +140,6 @@ std::string TokenTypeToString(TokenType type) {
case TokenType::END_OF_FILE:
return "END_OF_FILE";
case TokenType::UNKNOWN:
return "UNKNOWN";
default:
return "UNKNOWN";
}
Expand Down
7 changes: 3 additions & 4 deletions src/include/flockmtl/core/parser/query/model_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <memory>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

namespace flockmtl {

Expand All @@ -21,8 +22,7 @@ class CreateModelStatement : public QueryStatement {
std::string model_name;
std::string model;
std::string provider_name;
int context_window;
int max_output_tokens;
nlohmann::json model_args;
};

class DeleteModelStatement : public QueryStatement {
Expand All @@ -42,8 +42,7 @@ class UpdateModelStatement : public QueryStatement {
std::string model_name;
std::string new_model;
std::string provider_name;
int new_context_window;
int new_max_output_tokens;
nlohmann::json new_model_args;
};

class GetModelStatement : public QueryStatement {
Expand Down
7 changes: 4 additions & 3 deletions src/include/flockmtl/core/parser/tokenizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <string>

// Define token types
enum class TokenType { KEYWORD, STRING_LITERAL, SYMBOL, NUMBER, PARENTHESIS, END_OF_FILE, UNKNOWN };
enum class TokenType { KEYWORD, STRING_LITERAL, JSON, SYMBOL, NUMBER, PARENTHESIS, END_OF_FILE, UNKNOWN };

// Token structure
struct Token {
Expand All @@ -27,8 +27,9 @@ class Tokenizer {
Token ParseStringLiteral();
Token ParseKeyword();
Token ParseSymbol();
Token ParseNumber(); // Added for handling numbers
Token ParseParenthesis(); // Added for handling parentheses
Token ParseNumber();
Token ParseParenthesis();
Token ParseJson();
Token GetNextToken();
};

Expand Down

0 comments on commit a200ef3

Please sign in to comment.