Feat/batching on output max size (#56)

dorbanianas authored Nov 15, 2024
1 parent 2164511 commit 25958bb
Showing 22 changed files with 337 additions and 362 deletions.
50 changes: 31 additions & 19 deletions src/core/config/config.cpp
@@ -48,24 +48,35 @@ void Config::setup_default_models_config(duckdb::Connection &con, std::string &s
auto result = con.Query("SELECT table_name FROM information_schema.tables WHERE table_schema = '" + schema_name +
"' AND table_name = '" + table_name + "';");
if (result->RowCount() == 0) {
con.Query("CREATE TABLE " + schema_name + "." + table_name +
con.Query("LOAD JSON;"
"CREATE TABLE " +
schema_name + "." + table_name +
" ("
"model_name VARCHAR NOT NULL PRIMARY KEY,"
"model VARCHAR,"
"provider_name VARCHAR NOT NULL,"
"max_tokens INTEGER NOT NULL"
"model_name VARCHAR NOT NULL PRIMARY KEY, "
"model VARCHAR NOT NULL, "
"provider_name VARCHAR NOT NULL, "
"model_args JSON NOT NULL"
");");

con.Query("INSERT INTO " + schema_name + "." + table_name +
" (model_name, model, provider_name, max_tokens) VALUES "
"('default', 'gpt-4o-mini', 'openai', 128000),"
"('gpt-4o-mini', 'gpt-4o-mini', 'openai', 128000),"
"('gpt-4o', 'gpt-4o', 'openai', 128000),"
"('text-embedding-3-large', 'text-embedding-3-large', 'openai', " +
std::to_string(Config::default_max_tokens) +
"),"
"('text-embedding-3-small', 'text-embedding-3-small', 'openai', " +
std::to_string(Config::default_max_tokens) + ");");
con.Query(
"INSERT INTO " + schema_name + "." + table_name +
" (model_name, model, provider_name, model_args) VALUES "
"('default', 'gpt-4o-mini', 'openai', '{\"context_window\": 128000, \"max_output_tokens\": 16384}'),"
"('gpt-4o-mini', 'gpt-4o-mini', 'openai', '{\"context_window\": 128000, \"max_output_tokens\": 16384}'),"
"('gpt-4o', 'gpt-4o', 'openai', '{\"context_window\": 128000, \"max_output_tokens\": 16384}'),"
"('text-embedding-3-large', 'text-embedding-3-large', 'openai', "
"'{\"context_window\": " +
std::to_string(Config::default_context_window) +
", "
"\"max_output_tokens\": " +
std::to_string(Config::default_max_output_tokens) +
"}'),"
"('text-embedding-3-small', 'text-embedding-3-small', 'openai', "
"'{\"context_window\": " +
std::to_string(Config::default_context_window) +
", "
"\"max_output_tokens\": " +
std::to_string(Config::default_max_output_tokens) + "}');");
}
}

@@ -75,13 +86,14 @@ void Config::setup_user_defined_models_config(duckdb::Connection &con, std::stri
auto result = con.Query("SELECT table_name FROM information_schema.tables WHERE table_schema = '" + schema_name +
"' AND table_name = '" + table_name + "';");
if (result->RowCount() == 0) {
con.Query("CREATE TABLE " + schema_name + "." + table_name +
con.Query("LOAD JSON;"
"CREATE TABLE " +
schema_name + "." + table_name +
" ("
"model_name VARCHAR NOT NULL,"
"model_name VARCHAR NOT NULL PRIMARY KEY,"
"model VARCHAR,"
"provider_name VARCHAR NOT NULL,"
"max_tokens INTEGER NOT NULL,"
"PRIMARY KEY (model_name, provider_name)"
"model_args JSON NOT NULL"
");");
}
}
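
For context (a sketch, not code from this commit): both config tables now keep per-model limits in a model_args JSON column (context_window, max_output_tokens) instead of a single max_tokens integer. Below is a minimal way a caller could read those limits back out with DuckDB and nlohmann::json; the table name FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE and the GetModelLimits helper are assumptions for illustration, not identifiers taken from this hunk.

#include <nlohmann/json.hpp>
#include <stdexcept>
#include <string>
#include <utility>
#include "duckdb.hpp"

// Hypothetical helper (illustration only): fetch the two limits the batching logic relies on.
std::pair<int, int> GetModelLimits(duckdb::Connection &con, const std::string &model_name) {
    // Table name is assumed; this hunk only shows the generic schema_name/table_name variables.
    auto result = con.Query("SELECT model_args FROM flockmtl_config.FLOCKMTL_MODEL_DEFAULT_INTERNAL_TABLE "
                            "WHERE model_name = '" + model_name + "';");
    if (result->RowCount() == 0) {
        throw std::runtime_error("Model not found: " + model_name);
    }
    // model_args is stored as JSON text, e.g. {"context_window": 128000, "max_output_tokens": 16384}
    auto model_args = nlohmann::json::parse(result->GetValue(0, 0).ToString());
    return {model_args["context_window"].get<int>(), model_args["max_output_tokens"].get<int>()};
}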
2 changes: 1 addition & 1 deletion src/core/functions/CMakeLists.txt
@@ -2,5 +2,5 @@ add_subdirectory(scalar)
add_subdirectory(aggregate)

set(EXTENSION_SOURCES
${EXTENSION_SOURCES} ${CMAKE_CURRENT_SOURCE_DIR}/prompt_builder.cpp
${EXTENSION_SOURCES} ${CMAKE_CURRENT_SOURCE_DIR}/batch_response_builder.cpp
PARENT_SCOPE)
65 changes: 20 additions & 45 deletions src/core/functions/aggregate/llm_agg.cpp
@@ -1,5 +1,5 @@
#include <flockmtl/core/functions/aggregate/llm_agg.hpp>
#include <flockmtl/core/functions/prompt_builder.hpp>
#include <flockmtl/core/functions/batch_response_builder.hpp>
#include "flockmtl/core/module.hpp"
#include "flockmtl/core/model_manager/model_manager.hpp"
#include "flockmtl/core/model_manager/tiktoken.hpp"
@@ -49,58 +49,33 @@ int LlmFirstOrLast::GetFirstOrLastTupleId(const nlohmann::json &tuples) {
data["tuples"] = tuples;
data["search_query"] = search_query;
auto prompt = env.render(llm_first_or_last_template, data);

auto response = ModelManager::CallComplete(prompt, LlmAggOperation::model_details);
return response["selected"].get<int>();
}

nlohmann::json LlmFirstOrLast::Evaluate(nlohmann::json &tuples) {
int num_tuples;
vector<int> num_tuples_per_batch;
int num_used_tokens;
int batch_size;
int batch_index;

while (tuples.size() > 1) {
num_tuples = tuples.size();
num_used_tokens = 0;
batch_size = 0;
batch_index = 0;

for (int i = 0; i < num_tuples; i++) {
num_used_tokens += Tiktoken::GetNumTokens(tuples[i].dump());
batch_size++;

if (num_used_tokens >= available_tokens) {
num_tuples_per_batch.push_back(batch_size);
batch_index++;
num_used_tokens = 0;
batch_size = 0;
} else if (i == num_tuples - 1) {
num_tuples_per_batch.push_back(batch_size);
batch_index++;
}
}

auto responses = nlohmann::json::array();
auto num_batches = batch_index;
auto accumulated_tuples_tokens = 0u;
auto batch_tuples = nlohmann::json::array();
int start_index = 0;

for (auto i = 0; i < num_batches; i++) {
auto start_index = i * num_tuples_per_batch[i];
auto end_index = start_index + num_tuples_per_batch[i];
auto batch = nlohmann::json::array();

for (auto j = start_index; j < end_index; j++) {
batch.push_back(tuples[j]);
do {
accumulated_tuples_tokens = Tiktoken::GetNumTokens(batch_tuples.dump());
while (accumulated_tuples_tokens < available_tokens && start_index < tuples.size()) {
auto num_tokens = Tiktoken::GetNumTokens(tuples[start_index].dump());
if (accumulated_tuples_tokens + num_tokens > available_tokens) {
break;
}

auto result_idx = GetFirstOrLastTupleId(batch);
responses.push_back(batch[result_idx]);
batch_tuples.push_back(tuples[start_index]);
accumulated_tuples_tokens += num_tokens;
start_index++;
}
tuples = responses;
};
auto result_idx = GetFirstOrLastTupleId(batch_tuples);
batch_tuples.clear();
batch_tuples.push_back(tuples[result_idx]);
} while (start_index < tuples.size());

return tuples[0]["content"];
return batch_tuples[0]["content"];
}

// Static member initialization
@@ -178,8 +153,8 @@ void LlmAggOperation::FinalizeResults(Vector &states, AggregateInputData &aggr_i
tuples_with_ids.push_back(tuple_with_id);
}

LlmFirstOrLast llm_first_or_last(LlmAggOperation::model_details.model, Config::default_max_tokens, search_query,
llm_prompt_template);
LlmFirstOrLast llm_first_or_last(LlmAggOperation::model_details.model, Config::default_context_window,
search_query, llm_prompt_template);
auto response = llm_first_or_last.Evaluate(tuples_with_ids);
result.SetValue(idx, response.dump());
}
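
For context (a sketch, not code from this commit): the rewritten Evaluate drops the precomputed per-pass batch sizes and instead greedily fills one batch up to the token budget, asks the model for the first/last tuple of that batch, and carries only the winner forward until every tuple has been seen. The standalone loop below shows that pattern with stand-in callbacks for Tiktoken::GetNumTokens and the LLM call; all names here are illustrative.

#include <nlohmann/json.hpp>
#include <cstddef>
#include <functional>
#include <string>

// Sketch of the greedy, token-budgeted "keep the winner" loop.
nlohmann::json SelectWithBudget(nlohmann::json tuples, int available_tokens,
                                const std::function<int(const std::string &)> &count_tokens,
                                const std::function<int(const nlohmann::json &)> &pick_index) {
    auto batch = nlohmann::json::array();
    std::size_t next = 0;
    do {
        int used = count_tokens(batch.dump()); // tokens already spent on the carried winner
        while (next < tuples.size()) {
            int cost = count_tokens(tuples[next].dump());
            if (used + cost > available_tokens) {
                break; // this batch is full
            }
            batch.push_back(tuples[next]);
            used += cost;
            next++;
        }
        auto winner = batch[pick_index(batch)]; // one LLM call per batch
        batch.clear();
        batch.push_back(winner); // the winner competes again in the next round
    } while (next < tuples.size());
    return batch[0];
}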
36 changes: 19 additions & 17 deletions src/core/functions/aggregate/llm_reduce.cpp
@@ -4,8 +4,9 @@
#include <flockmtl/core/config/config.hpp>
#include <flockmtl/core/functions/aggregate.hpp>
#include <flockmtl/core/functions/aggregate/llm_agg.hpp>
#include <flockmtl/core/functions/prompt_builder.hpp>
#include <flockmtl/core/functions/batch_response_builder.hpp>
#include <templates/llm_reduce_prompt_template.hpp>
#include <flockmtl/core/model_manager/tiktoken.hpp>

namespace flockmtl {
namespace core {
@@ -53,27 +54,28 @@ class LlmReduce {
};

nlohmann::json ReduceLoop(vector<nlohmann::json> &tuples) {
auto accumulated_rows_tokens = 0u;
auto window_tuples = nlohmann::json::array();
int signed start_index = tuples.size() - 1;
auto accumulated_tuples_tokens = 0u;
auto batch_tuples = nlohmann::json::array();
int start_index = 0;

do {
accumulated_rows_tokens = Tiktoken::GetNumTokens(window_tuples.dump());
while (available_tokens - accumulated_rows_tokens > 0 && start_index >= 0) {
accumulated_tuples_tokens = Tiktoken::GetNumTokens(batch_tuples.dump());
while (accumulated_tuples_tokens < available_tokens && start_index < tuples.size()) {
auto num_tokens = Tiktoken::GetNumTokens(tuples[start_index].dump());
if (accumulated_rows_tokens + num_tokens > available_tokens) {
if (accumulated_tuples_tokens + num_tokens > available_tokens) {
break;
}
window_tuples.push_back(tuples[start_index]);
accumulated_rows_tokens += num_tokens;
start_index--;
batch_tuples.push_back(tuples[start_index]);
accumulated_tuples_tokens += num_tokens;
start_index++;
}
auto response = Reduce(window_tuples);
window_tuples.clear();
window_tuples.push_back(response);
} while (start_index >= 0);
auto response = Reduce(batch_tuples);
batch_tuples.clear();
batch_tuples.push_back(response);
accumulated_tuples_tokens = 0;
} while (start_index < tuples.size());

return window_tuples[0];
return batch_tuples[0];
}
};

@@ -132,7 +134,7 @@ struct LlmReduceOperation {
auto target_state = state_map[target_ptr];

auto template_str = string(llm_reduce_prompt_template);
LlmReduce llm_reduce(LlmReduceOperation::model_details.model, Config::default_max_tokens, reduce_query,
LlmReduce llm_reduce(LlmReduceOperation::model_details.model, Config::default_context_window, reduce_query,
template_str, LlmReduceOperation::model_details);
auto result = llm_reduce.ReduceLoop(source_state->value);
target_state->Update(result);
@@ -149,7 +151,7 @@ struct LlmReduceOperation {
auto state = state_map[state_ptr];

auto template_str = string(llm_reduce_prompt_template);
LlmReduce llm_reduce(LlmReduceOperation::model_details.model, Config::default_max_tokens, reduce_query,
LlmReduce llm_reduce(LlmReduceOperation::model_details.model, Config::default_context_window, reduce_query,
template_str, LlmReduceOperation::model_details);
auto response = llm_reduce.ReduceLoop(state->value);
result.SetValue(idx, response.dump());
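
The rewritten ReduceLoop applies the same token-budgeted batching as Evaluate, with one difference worth spelling out: the summary returned by Reduce() is pushed back into batch_tuples, so each pass folds the previous result together with the next slice of input. A compact sketch of that fold (illustrative names, simplified to a fixed per-batch count instead of a token budget), not code from this commit:

#include <nlohmann/json.hpp>
#include <cstddef>
#include <functional>
#include <vector>

// Sketch: each batch collapses to one summary, and that summary seeds the next batch.
nlohmann::json FoldWithBudget(const std::vector<nlohmann::json> &tuples, std::size_t per_batch,
                              const std::function<nlohmann::json(const nlohmann::json &)> &reduce) {
    auto batch = nlohmann::json::array();
    std::size_t next = 0;
    do {
        while (next < tuples.size() && batch.size() < per_batch) { // stand-in for the token budget
            batch.push_back(tuples[next++]);
        }
        auto summary = reduce(batch); // one LLM call per batch
        batch.clear();
        batch.push_back(summary); // carry the summary into the next pass
    } while (next < tuples.size());
    return batch[0];
}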
4 changes: 2 additions & 2 deletions src/core/functions/aggregate/llm_rerank.cpp
@@ -1,6 +1,6 @@
#include <nlohmann/json.hpp>
#include <inja/inja.hpp>
#include <flockmtl/core/functions/prompt_builder.hpp>
#include <flockmtl/core/functions/batch_response_builder.hpp>
#include "flockmtl/core/module.hpp"
#include "templates/llm_rerank_prompt_template.hpp"
#include <flockmtl/common.hpp>
@@ -112,7 +112,7 @@ void LlmAggOperation::RerankerFinalize(Vector &states, AggregateInputData &aggr_
tuples_with_ids.push_back(state->value[i]);
}

LlmReranker llm_reranker(LlmAggOperation::model_details.model, Config::default_max_tokens,
LlmReranker llm_reranker(LlmAggOperation::model_details.model, Config::default_context_window,
LlmAggOperation::search_query, llm_rerank_prompt_template_str);

auto reranked_tuples = llm_reranker.SlidingWindowRerank(tuples_with_ids);
103 changes: 103 additions & 0 deletions src/core/functions/batch_response_builder.cpp
@@ -0,0 +1,103 @@
#include <inja/inja.hpp>
#include <flockmtl/core/functions/batch_response_builder.hpp>
#include <flockmtl/core/model_manager/tiktoken.hpp>

namespace flockmtl {
namespace core {

std::vector<nlohmann::json> CastVectorOfStructsToJson(Vector &struct_vector, int size) {
vector<nlohmann::json> vector_json;
for (auto i = 0; i < size; i++) {
nlohmann::json json;
for (auto j = 0; j < StructType::GetChildCount(struct_vector.GetType()); j++) {
auto key = StructType::GetChildName(struct_vector.GetType(), j);
auto value = StructValue::GetChildren(struct_vector.GetValue(i))[j].ToString();
json[key] = value;
}
vector_json.push_back(json);
}
return vector_json;
}

nlohmann::json Complete(const nlohmann::json &tuples, const std::string &user_prompt, const std::string &llm_template,
const ModelDetails &model_details) {
inja::Environment env;
nlohmann::json data;
data["user_prompt"] = user_prompt;
data["tuples"] = tuples;
auto prompt = env.render(llm_template, data);

auto response = ModelManager::CallComplete(prompt, model_details);

return response["tuples"];
};

nlohmann::json BatchAndComplete(std::vector<nlohmann::json> &tuples, Connection &con, std::string user_prompt_name,
const std::string &llm_template, const ModelDetails &model_details) {

auto query_result =
con.Query("SELECT prompt FROM flockmtl_config.FLOCKMTL_PROMPT_INTERNAL_TABLE WHERE prompt_name = '" +
user_prompt_name + "'");
if (query_result->RowCount() == 0) {
throw std::runtime_error("Prompt not found");
}
auto user_prompt = query_result->GetValue(0, 0).ToString();

int num_tokens_meta_and_user_prompt = 0;
num_tokens_meta_and_user_prompt += Tiktoken::GetNumTokens(user_prompt);
num_tokens_meta_and_user_prompt += Tiktoken::GetNumTokens(llm_template);
int available_tokens = model_details.context_window - num_tokens_meta_and_user_prompt;

auto responses = nlohmann::json::array();

if (available_tokens < 0) {
throw std::runtime_error("The total number of tokens in the prompt exceeds the model's maximum token limit");
} else {

auto accumulated_tuples_tokens = 0u;
auto batch_tuples = nlohmann::json::array();
auto batch_size = tuples.size();
int start_index = 0;

do {
while (accumulated_tuples_tokens < available_tokens && start_index < tuples.size() &&
batch_tuples.size() < batch_size) {
auto num_tokens = Tiktoken::GetNumTokens(tuples[start_index].dump());
if (accumulated_tuples_tokens + num_tokens > available_tokens) {
break;
}
batch_tuples.push_back(tuples[start_index]);
accumulated_tuples_tokens += num_tokens;
start_index++;
}

nlohmann::json response;
try {
response = Complete(batch_tuples, user_prompt, llm_template, model_details);
} catch (const LengthExceededError &e) {
batch_tuples.clear();
accumulated_tuples_tokens = 0;
auto new_batch_size = int(batch_size * 0.1);
batch_size = new_batch_size == 0 ? 1 : new_batch_size; // shrink to 10%, but keep at least one tuple
start_index = 0;
continue;
}
auto output_tokens_per_tuple = Tiktoken::GetNumTokens(response.dump()) / batch_tuples.size();

batch_size = model_details.max_output_tokens / output_tokens_per_tuple;
batch_tuples.clear();
accumulated_tuples_tokens = 0;

for (const auto &tuple : response) {
responses.push_back(tuple);
}

} while (start_index < tuples.size());
}

return responses;
}

} // namespace core
} // namespace flockmtl
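
This file is the core of the feature named in the commit title: batch size is driven by the model's output budget, not only its context window. After each successful call, the next batch size becomes max_output_tokens divided by the observed output tokens per tuple; for example, with max_output_tokens = 16384 and responses averaging roughly 160 tokens per tuple, the next batch would target 102 tuples (integer division). If the provider raises LengthExceededError, the batch shrinks to 10% of its previous size (at least one tuple) and processing restarts from the first tuple. The 160-token figure is purely illustrative, not a number taken from this commit.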