Refactor and Improve Model Manager Implementation (#70)
dorbanianas authored Nov 25, 2024
1 parent 5536204 commit 9f813a6
Showing 43 changed files with 646 additions and 891 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -1,4 +1,5 @@
add_subdirectory(core)
add_subdirectory(model_manager)
add_subdirectory(prompt_manager)

set(EXTENSION_SOURCES
6 changes: 2 additions & 4 deletions src/core/CMakeLists.txt
@@ -1,9 +1,7 @@
add_subdirectory(functions)
add_subdirectory(model_manager)
add_subdirectory(parser)
add_subdirectory(config)

set(EXTENSION_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/module.cpp
${EXTENSION_SOURCES}
PARENT_SCOPE)
${CMAKE_CURRENT_SOURCE_DIR}/module.cpp ${EXTENSION_SOURCES}
PARENT_SCOPE)
27 changes: 14 additions & 13 deletions src/core/functions/aggregate/llm_agg.cpp
@@ -1,8 +1,8 @@
#include <flockmtl/core/functions/aggregate/llm_agg.hpp>
#include <flockmtl/core/functions/batch_response_builder.hpp>
#include "flockmtl/core/module.hpp"
#include "flockmtl/core/model_manager/model_manager.hpp"
#include "flockmtl/core/model_manager/tiktoken.hpp"
#include "flockmtl/model_manager/model.hpp"
#include "flockmtl/model_manager/tiktoken.hpp"
#include <flockmtl/core/config/config.hpp>
#include <vector>

@@ -19,13 +19,13 @@ void LlmAggState::Combine(const LlmAggState &source) {
}
}

LlmFirstOrLast::LlmFirstOrLast(std::string &model, int model_context_size, std::string &search_query,
AggregateFunctionType function_type)
: model(model), model_context_size(model_context_size), search_query(search_query), function_type(function_type) {
LlmFirstOrLast::LlmFirstOrLast(Model &model, std::string &search_query, AggregateFunctionType function_type)
: model(model), search_query(search_query), function_type(function_type) {

llm_first_or_last_template = PromptManager::GetTemplate(function_type);
auto num_tokens_meta_and_search_query = calculateFixedTokens();

auto model_context_size = model.GetModelDetails().context_window;
if (num_tokens_meta_and_search_query > model_context_size) {
throw std::runtime_error("Fixed tokens exceed model context size");
}
@@ -44,7 +44,7 @@ int LlmFirstOrLast::GetFirstOrLastTupleId(const nlohmann::json &tuples) {
nlohmann::json data;
auto markdown_tuples = ConstructMarkdownArrayTuples(tuples);
auto prompt = PromptManager::Render(search_query, markdown_tuples, function_type);
auto response = ModelManager::CallComplete(prompt, LlmAggOperation::model_details);
auto response = model.CallComplete(prompt);
return response["selected"].get<int>();
}

@@ -77,7 +77,7 @@ nlohmann::json LlmFirstOrLast::Evaluate(nlohmann::json &tuples) {
}

// Static member initialization
ModelDetails LlmAggOperation::model_details {};
Model &LlmAggOperation::model(*(new Model(nlohmann::json())));
std::string LlmAggOperation::search_query;

std::unordered_map<void *, std::shared_ptr<LlmAggState>> LlmAggOperation::state_map;
@@ -98,13 +98,14 @@ void LlmAggOperation::Operation(Vector inputs[], AggregateInputData &aggr_input_
throw std::runtime_error("Expected a struct type for model details");
}
auto model_details_json = CastVectorOfStructsToJson(inputs[0], 1)[0];
LlmAggOperation::model_details = ModelManager::CreateModelDetails(CoreModule::GetConnection(), model_details_json);
LlmAggOperation::model = Model(model_details_json);

if (inputs[1].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt details");
}
auto prompt_details_json = CastVectorOfStructsToJson(inputs[1], 1)[0];
search_query = CreatePromptDetails(CoreModule::GetConnection(), prompt_details_json).prompt;
auto connection = CoreModule::GetConnection();
search_query = CreatePromptDetails(connection, prompt_details_json).prompt;

if (inputs[2].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt inputs");
@@ -153,8 +154,7 @@ void LlmAggOperation::FinalizeResults(Vector &states, AggregateInputData &aggr_i
tuples_with_ids.push_back(tuple_with_id);
}

LlmFirstOrLast llm_first_or_last(LlmAggOperation::model_details.model, Config::default_context_window,
search_query, function_type);
LlmFirstOrLast llm_first_or_last(LlmAggOperation::model, search_query, function_type);
auto response = llm_first_or_last.Evaluate(tuples_with_ids);
result.SetValue(idx, response.dump());
}
@@ -180,13 +180,14 @@ void LlmAggOperation::SimpleUpdate(Vector inputs[], AggregateInputData &aggr_inp
throw std::runtime_error("Expected a struct type for model details");
}
auto model_details_json = CastVectorOfStructsToJson(inputs[0], 1)[0];
LlmAggOperation::model_details = ModelManager::CreateModelDetails(CoreModule::GetConnection(), model_details_json);
LlmAggOperation::model = Model(model_details_json);

if (inputs[1].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt details");
}
auto prompt_details_json = CastVectorOfStructsToJson(inputs[1], 1)[0];
search_query = CreatePromptDetails(CoreModule::GetConnection(), prompt_details_json).prompt;
auto connection = CoreModule::GetConnection();
search_query = CreatePromptDetails(connection, prompt_details_json).prompt;

if (inputs[2].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt inputs");
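Across llm_agg.cpp the refactor swaps the old static pair (ModelManager::CreateModelDetails(...) plus ModelManager::CallComplete(prompt, model_details)) for a Model object constructed straight from the model-details JSON. The call sites above rely on only a small surface: construction from JSON, GetModelDetails().context_window, and CallComplete(prompt). A minimal header-style sketch of that assumed surface follows; it is hypothetical, not the actual flockmtl header, and any field or parameter name not visible in this commit is a guess.

// Sketch only: the Model interface these call sites appear to assume.
// Fields beyond model, context_window, and max_output_tokens (the ones
// visible in this commit) are illustrative guesses.
#include <nlohmann/json.hpp>
#include <string>

struct ModelDetails {
    std::string model;       // model name, as in the old model_details.model
    int context_window;      // token budget checked by the aggregate functions
    int max_output_tokens;   // used to re-derive batch sizes in BatchAndComplete
};

class Model {
public:
    // Built from the SQL function's model-details struct after it is cast to JSON.
    explicit Model(const nlohmann::json &model_json);

    // Replaces ModelManager::CreateModelDetails(): resolves and caches the details.
    const ModelDetails &GetModelDetails() const;

    // Replaces ModelManager::CallComplete(prompt, model_details); the second
    // argument mirrors the `false` passed in llm_complete.cpp and is assumed
    // to toggle structured JSON output.
    nlohmann::json CallComplete(const std::string &prompt, bool json_response = true);
};
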
35 changes: 15 additions & 20 deletions src/core/functions/aggregate/llm_reduce.cpp
@@ -5,19 +5,17 @@
#include <flockmtl/core/functions/aggregate/llm_agg.hpp>
#include <flockmtl/core/functions/batch_response_builder.hpp>
#include "flockmtl/prompt_manager/prompt_manager.hpp"
#include <flockmtl/core/model_manager/tiktoken.hpp>
#include <flockmtl/model_manager/tiktoken.hpp>

namespace flockmtl {
namespace core {

class LlmReduce {
public:
std::string model;
int model_context_size;
Model &model;
std::string reduce_query;
std::string llm_reduce_template;
int available_tokens;
ModelDetails model_details;
AggregateFunctionType function_type;

int calculateFixedTokens() const {
@@ -27,14 +25,13 @@ class LlmReduce {
return num_tokens_meta_and_reduce_query;
}

LlmReduce(std::string &model, int model_context_size, std::string &reduce_query, ModelDetails model_details)
: model(model), model_context_size(model_context_size), reduce_query(reduce_query),
model_details(model_details) {
LlmReduce(Model &model, std::string &reduce_query) : model(model), reduce_query(reduce_query) {

function_type = AggregateFunctionType::REDUCE;
llm_reduce_template = PromptManager::GetTemplate(function_type);
auto num_tokens_meta_and_reduce_query = calculateFixedTokens();

auto model_context_size = model.GetModelDetails().context_window;
if (num_tokens_meta_and_reduce_query > model_context_size) {
throw std::runtime_error("Fixed tokens exceed model context size");
}
@@ -46,7 +43,7 @@ class LlmReduce {
nlohmann::json data;
auto markdown_tuples = ConstructMarkdownArrayTuples(tuples);
auto prompt = PromptManager::Render(reduce_query, markdown_tuples, function_type);
auto response = ModelManager::CallComplete(prompt, model_details);
auto response = model.CallComplete(prompt);
return response["output"];
};

@@ -78,7 +75,7 @@ class LlmReduce {
};

struct LlmReduceOperation {
static ModelDetails model_details;
static Model &model;
static std::string reduce_query;
static std::unordered_map<void *, std::shared_ptr<LlmAggState>> state_map;

@@ -98,14 +95,14 @@ struct LlmReduceOperation {
throw std::runtime_error("Expected a struct type for model details");
}
auto model_details_json = CastVectorOfStructsToJson(inputs[0], 1)[0];
LlmReduceOperation::model_details =
ModelManager::CreateModelDetails(CoreModule::GetConnection(), model_details_json);
LlmReduceOperation::model = Model(model_details_json);

if (inputs[1].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt details");
}
auto prompt_details_json = CastVectorOfStructsToJson(inputs[1], 1)[0];
reduce_query = CreatePromptDetails(CoreModule::GetConnection(), prompt_details_json).prompt;
auto connection = CoreModule::GetConnection();
reduce_query = CreatePromptDetails(connection, prompt_details_json).prompt;

if (inputs[2].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt inputs");
@@ -134,8 +131,7 @@ struct LlmReduceOperation {
auto source_state = state_map[source_ptr];
auto target_state = state_map[target_ptr];

LlmReduce llm_reduce(LlmReduceOperation::model_details.model, Config::default_context_window, reduce_query,
LlmReduceOperation::model_details);
LlmReduce llm_reduce(LlmReduceOperation::model, reduce_query);
auto result = llm_reduce.ReduceLoop(source_state->value);
target_state->Update(result);
}
@@ -150,8 +146,7 @@ struct LlmReduceOperation {
auto state_ptr = states_vector[idx];
auto state = state_map[state_ptr];

LlmReduce llm_reduce(LlmReduceOperation::model_details.model, Config::default_context_window, reduce_query,
LlmReduceOperation::model_details);
LlmReduce llm_reduce(LlmReduceOperation::model, reduce_query);
auto response = llm_reduce.ReduceLoop(state->value);
result.SetValue(idx, response.dump());
}
@@ -163,14 +158,14 @@ struct LlmReduceOperation {
throw std::runtime_error("Expected a struct type for model details");
}
auto model_details_json = CastVectorOfStructsToJson(inputs[0], 1)[0];
LlmReduceOperation::model_details =
ModelManager::CreateModelDetails(CoreModule::GetConnection(), model_details_json);
LlmReduceOperation::model = Model(model_details_json);

if (inputs[1].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt details");
}
auto prompt_details_json = CastVectorOfStructsToJson(inputs[1], 1)[0];
reduce_query = CreatePromptDetails(CoreModule::GetConnection(), prompt_details_json).prompt;
auto connection = CoreModule::GetConnection();
reduce_query = CreatePromptDetails(connection, prompt_details_json).prompt;

if (inputs[2].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt inputs");
@@ -189,7 +184,7 @@ struct LlmReduceOperation {
static bool IgnoreNull() { return true; }
};

ModelDetails LlmReduceOperation::model_details;
Model &LlmReduceOperation::model(*(new Model(nlohmann::json())));
std::string LlmReduceOperation::reduce_query;
std::unordered_map<void *, std::shared_ptr<LlmAggState>> LlmReduceOperation::state_map;

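As in llm_agg.cpp, the aggregate keeps its per-query state in static members. The new line Model &LlmReduceOperation::model(*(new Model(nlohmann::json()))); binds that static reference once to a heap-allocated placeholder that is never freed, and the later LlmReduceOperation::model = Model(model_details_json); statements copy-assign a freshly constructed Model through the reference rather than rebinding it. A small standalone illustration of that C++ pattern, with hypothetical names rather than flockmtl code:

#include <nlohmann/json.hpp>
#include <cassert>

// Stand-in for the real flockmtl Model; only enough to show the pattern.
struct Model {
    nlohmann::json config;
    explicit Model(nlohmann::json j) : config(std::move(j)) {}
};

struct Op {
    static Model &model;  // shared by every invocation of the aggregate
};

// Bound exactly once at static-initialization time; the placeholder object
// is intentionally leaked so the reference always has a valid referent.
Model &Op::model(*(new Model(nlohmann::json())));

int main() {
    // Assigning through the reference replaces the placeholder's contents;
    // the reference itself is never rebound.
    Op::model = Model(nlohmann::json{{"model_name", "example-model"}});
    assert(Op::model.config["model_name"] == "example-model");
    return 0;
}
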
13 changes: 6 additions & 7 deletions src/core/functions/aggregate/llm_rerank.cpp
@@ -3,9 +3,9 @@
#include "flockmtl/core/module.hpp"
#include <flockmtl/common.hpp>
#include <flockmtl/core/functions/aggregate.hpp>
#include <flockmtl/core/model_manager/model_manager.hpp>
#include <flockmtl/model_manager/model.hpp>
#include <flockmtl_extension.hpp>
#include <flockmtl/core/model_manager/tiktoken.hpp>
#include <flockmtl/model_manager/tiktoken.hpp>
#include <flockmtl/core/functions/aggregate/llm_agg.hpp>
#include "flockmtl/prompt_manager/prompt_manager.hpp"
#include <flockmtl/core/config/config.hpp>
@@ -14,14 +14,14 @@
namespace flockmtl {
namespace core {

LlmReranker::LlmReranker(std::string &model, int model_context_size, std::string &search_query)
: model(model), model_context_size(model_context_size), search_query(search_query) {
LlmReranker::LlmReranker(Model &model, std::string &search_query) : model(model), search_query(search_query) {

function_type = AggregateFunctionType::RERANK;
llm_reranking_template = PromptManager::GetTemplate(function_type);

auto num_tokens_meta_and_search_query = CalculateFixedTokens();

auto model_context_size = model.GetModelDetails().context_window;
if (num_tokens_meta_and_search_query > model_context_size) {
throw std::runtime_error("Fixed tokens exceed model context size");
}
@@ -87,7 +87,7 @@ vector<int> LlmReranker::LlmRerankWithSlidingWindow(const nlohmann::json &tuples
nlohmann::json data;
auto markdown_tuples = ConstructMarkdownArrayTuples(tuples);
auto prompt = PromptManager::Render(search_query, markdown_tuples, function_type);
auto response = ModelManager::CallComplete(prompt, LlmAggOperation::model_details);
auto response = model.CallComplete(prompt);
return response["ranking"].get<vector<int>>();
};

@@ -105,8 +105,7 @@ void LlmAggOperation::RerankerFinalize(Vector &states, AggregateInputData &aggr_
tuples_with_ids.push_back(state->value[i]);
}

LlmReranker llm_reranker(LlmAggOperation::model_details.model, Config::default_context_window,
LlmAggOperation::search_query);
LlmReranker llm_reranker(LlmAggOperation::model, LlmAggOperation::search_query);

auto reranked_tuples = llm_reranker.SlidingWindowRerank(tuples_with_ids);

16 changes: 8 additions & 8 deletions src/core/functions/batch_response_builder.cpp
@@ -1,6 +1,6 @@
#include <flockmtl/core/module.hpp>
#include <flockmtl/core/functions/batch_response_builder.hpp>
#include <flockmtl/core/model_manager/tiktoken.hpp>
#include <flockmtl/model_manager/tiktoken.hpp>
#include "flockmtl/prompt_manager/prompt_manager.hpp"

namespace flockmtl {
@@ -90,22 +90,22 @@ PromptDetails CreatePromptDetails(Connection &con, const nlohmann::json prompt_d
}

nlohmann::json Complete(const nlohmann::json &tuples, const std::string &user_prompt, ScalarFunctionType function_type,
ModelDetails &model_details) {
Model &model) {
nlohmann::json data;
auto tuples_markdown = ConstructMarkdownArrayTuples(tuples);
auto prompt = PromptManager::Render(user_prompt, tuples_markdown, function_type);
auto response = ModelManager::CallComplete(prompt, model_details);
auto response = model.CallComplete(prompt);
return response["tuples"];
};

nlohmann::json BatchAndComplete(std::vector<nlohmann::json> &tuples, Connection &con, std::string user_prompt,
ScalarFunctionType function_type, ModelDetails &model_details) {
ScalarFunctionType function_type, Model &model) {
auto llm_template = PromptManager::GetTemplate(function_type);

int num_tokens_meta_and_user_pormpt = 0;
num_tokens_meta_and_user_pormpt += Tiktoken::GetNumTokens(user_prompt);
num_tokens_meta_and_user_pormpt += Tiktoken::GetNumTokens(llm_template);
int available_tokens = model_details.context_window - num_tokens_meta_and_user_pormpt;
int available_tokens = model.GetModelDetails().context_window - num_tokens_meta_and_user_pormpt;

auto responses = nlohmann::json::array();

@@ -133,8 +133,8 @@ nlohmann::json BatchAndComplete(std::vector<nlohmann::json> &tuples, Connection

nlohmann::json response;
try {
response = Complete(batch_tuples, user_prompt, function_type, model_details);
} catch (const LengthExceededError &e) {
response = Complete(batch_tuples, user_prompt, function_type, model);
} catch (const ExceededMaxOutputTokensError &e) {
batch_tuples.clear();
accumulated_tuples_tokens = 0;
auto new_batch_size = int(batch_size * 0.1);
@@ -145,7 +145,7 @@ nlohmann::json BatchAndComplete(std::vector<nlohmann::json> &tuples, Connection
}
auto output_tokens_per_tuple = Tiktoken::GetNumTokens(response.dump()) / batch_tuples.size();

batch_size = model_details.max_output_tokens / output_tokens_per_tuple;
batch_size = model.GetModelDetails().max_output_tokens / output_tokens_per_tuple;
batch_tuples.clear();
accumulated_tuples_tokens = 0;

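BatchAndComplete now reads both limits from the Model object: the context window bounds how many input tokens a batch may carry, and max_output_tokens is used to re-derive the batch size from the output tokens observed per tuple, with the visible error path shrinking the batch to roughly 10% of its size when an ExceededMaxOutputTokensError is thrown. A compact, self-contained sketch of that budget arithmetic, using a crude stand-in for Tiktoken::GetNumTokens; this is an approximation of the idea, not the actual flockmtl code.

#include <algorithm>
#include <string>

namespace sketch {

// Crude stand-in for Tiktoken::GetNumTokens (roughly 4 characters per token).
inline int GetNumTokens(const std::string &text) {
    return std::max<int>(1, static_cast<int>(text.size()) / 4);
}

struct BatchBudget {
    int available_tokens;  // room left for tuple payloads in each prompt
    int batch_size;        // tuples to send per model call
};

// Mirrors the arithmetic visible in this hunk: the user prompt and the
// rendered template are "fixed" tokens, the rest of the context window is
// available for tuples, and the batch size is re-derived from how many
// output tokens each tuple produced in the previous call.
inline BatchBudget ComputeBudget(const std::string &user_prompt,
                                 const std::string &llm_template,
                                 int context_window, int max_output_tokens,
                                 int observed_output_tokens_per_tuple) {
    const int fixed_tokens = GetNumTokens(user_prompt) + GetNumTokens(llm_template);
    BatchBudget budget;
    budget.available_tokens = context_window - fixed_tokens;
    budget.batch_size =
        std::max(1, max_output_tokens / std::max(1, observed_output_tokens_per_tuple));
    return budget;
}

}  // namespace sketch
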
2 changes: 1 addition & 1 deletion src/core/functions/scalar/fusion_relative.cpp
@@ -1,6 +1,6 @@
#include <flockmtl/common.hpp>
#include <flockmtl/core/functions/scalar.hpp>
#include <flockmtl/core/model_manager/model_manager.hpp>
#include <flockmtl/model_manager/model.hpp>
#include <flockmtl/core/parser/llm_response.hpp>
#include <flockmtl/core/parser/scalar.hpp>
#include <flockmtl_extension.hpp>
9 changes: 4 additions & 5 deletions src/core/functions/scalar/llm_complete.cpp
@@ -3,7 +3,7 @@
#include <flockmtl/core/functions/batch_response_builder.hpp>
#include <flockmtl/common.hpp>
#include <flockmtl/core/functions/scalar.hpp>
#include <flockmtl/core/model_manager/model_manager.hpp>
#include <flockmtl/model_manager/model.hpp>
#include <flockmtl/core/parser/llm_response.hpp>
#include <flockmtl/core/parser/scalar.hpp>
#include <flockmtl/core/config/config.hpp>
@@ -19,20 +19,19 @@ static void LlmCompleteScalarFunction(DataChunk &args, ExpressionState &state, V
CoreScalarParsers::LlmCompleteScalarParser(args);

auto model_details_json = CoreScalarParsers::Struct2Json(args.data[0], 1)[0];
auto model_details = ModelManager::CreateModelDetails(con, model_details_json);
Model model(model_details_json);
auto prompt_details_json = CoreScalarParsers::Struct2Json(args.data[1], 1)[0];
auto prompt_details = CreatePromptDetails(con, prompt_details_json);

if (args.ColumnCount() == 2) {
auto template_str = prompt_details.prompt;
auto response = ModelManager::CallComplete(template_str, model_details, false);
auto response = model.CallComplete(template_str, false);

result.SetValue(0, response.dump());
} else {
auto tuples = CoreScalarParsers::Struct2Json(args.data[2], args.size());

auto responses =
BatchAndComplete(tuples, con, prompt_details.prompt, ScalarFunctionType::COMPLETE, model_details);
auto responses = BatchAndComplete(tuples, con, prompt_details.prompt, ScalarFunctionType::COMPLETE, model);

auto index = 0;
Vector vec(LogicalType::VARCHAR, args.size());