Skip to content

Commit

Permalink
change the aggregate functions' interface for prompt details
Browse files Browse the repository at this point in the history
  • Loading branch information
dorbanianas committed Nov 14, 2024
1 parent 731f472 commit 757bbca
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 18 deletions.
28 changes: 21 additions & 7 deletions src/core/functions/aggregate/llm_agg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,18 @@ void LlmAggOperation::Initialize(const AggregateFunction &, data_ptr_t state_p)

void LlmAggOperation::Operation(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states,
idx_t count) {
search_query = inputs[0].GetValue(0).ToString();

if (inputs[1].GetType().id() != LogicalTypeId::STRUCT) {
if (inputs[0].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for model details");
}

auto model_details_json = CastVectorOfStructsToJson(inputs[1], 1)[0];
auto model_details_json = CastVectorOfStructsToJson(inputs[0], 1)[0];
LlmAggOperation::model_details = ModelManager::CreateModelDetails(CoreModule::GetConnection(), model_details_json);

if (inputs[1].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt details");
}
auto prompt_details_json = CastVectorOfStructsToJson(inputs[1], 1)[0];
search_query = CreatePromptDetails(CoreModule::GetConnection(), prompt_details_json).prompt;

if (inputs[2].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt inputs");
}
Expand Down Expand Up @@ -174,10 +177,21 @@ void LlmAggOperation::FirstOrLastFinalize<FirstOrLast::FIRST>(Vector &states, Ag

void LlmAggOperation::SimpleUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count,
data_ptr_t state_p, idx_t count) {
search_query = inputs[0].GetValue(0).ToString();
auto model_details_json = CastVectorOfStructsToJson(inputs[1], 1)[0];
if (inputs[0].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for model details");
}
auto model_details_json = CastVectorOfStructsToJson(inputs[0], 1)[0];
LlmAggOperation::model_details = ModelManager::CreateModelDetails(CoreModule::GetConnection(), model_details_json);

if (inputs[1].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt details");
}
auto prompt_details_json = CastVectorOfStructsToJson(inputs[1], 1)[0];
search_query = CreatePromptDetails(CoreModule::GetConnection(), prompt_details_json).prompt;

if (inputs[2].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt inputs");
}
auto tuples = CastVectorOfStructsToJson(inputs[2], count);

auto state_map_p = reinterpret_cast<LlmAggState *>(state_p);
Expand Down
2 changes: 1 addition & 1 deletion src/core/functions/aggregate/llm_first.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace core {

void CoreAggregateFunctions::RegisterLlmFirstFunction(DatabaseInstance &db) {
auto string_concat =
AggregateFunction("llm_first", {LogicalType::VARCHAR, LogicalType::ANY, LogicalType::ANY}, LogicalType::JSON(),
AggregateFunction("llm_first", {LogicalType::ANY, LogicalType::ANY, LogicalType::ANY}, LogicalType::JSON(),
AggregateFunction::StateSize<LlmAggState>, LlmAggOperation::Initialize,
LlmAggOperation::Operation, LlmAggOperation::Combine,
LlmAggOperation::FirstOrLastFinalize<FirstOrLast::FIRST>, LlmAggOperation::SimpleUpdate);
Expand Down
2 changes: 1 addition & 1 deletion src/core/functions/aggregate/llm_last.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace core {

void CoreAggregateFunctions::RegisterLlmLastFunction(DatabaseInstance &db) {
auto string_concat =
AggregateFunction("llm_last", {LogicalType::VARCHAR, LogicalType::ANY, LogicalType::ANY}, LogicalType::JSON(),
AggregateFunction("llm_last", {LogicalType::ANY, LogicalType::ANY, LogicalType::ANY}, LogicalType::JSON(),
AggregateFunction::StateSize<LlmAggState>, LlmAggOperation::Initialize,
LlmAggOperation::Operation, LlmAggOperation::Combine,
LlmAggOperation::FirstOrLastFinalize<FirstOrLast::LAST>, LlmAggOperation::SimpleUpdate);
Expand Down
30 changes: 22 additions & 8 deletions src/core/functions/aggregate/llm_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,19 @@ struct LlmReduceOperation {

static void Operation(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states,
idx_t count) {
reduce_query = inputs[0].GetValue(0).ToString();

if (inputs[1].GetType().id() != LogicalTypeId::STRUCT) {
if (inputs[0].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for model details");
}

auto model_details_json = CastVectorOfStructsToJson(inputs[1], 1)[0];
auto model_details_json = CastVectorOfStructsToJson(inputs[0], 1)[0];
LlmReduceOperation::model_details =
ModelManager::CreateModelDetails(CoreModule::GetConnection(), model_details_json);

if (inputs[1].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt details");
}
auto prompt_details_json = CastVectorOfStructsToJson(inputs[1], 1)[0];
reduce_query = CreatePromptDetails(CoreModule::GetConnection(), prompt_details_json).prompt;

if (inputs[2].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt inputs");
}
Expand Down Expand Up @@ -160,11 +163,22 @@ struct LlmReduceOperation {

static void SimpleUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count,
data_ptr_t state_p, idx_t count) {
reduce_query = inputs[0].GetValue(0).ToString();
auto model_details_json = CastVectorOfStructsToJson(inputs[1], 1)[0];
if (inputs[0].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for model details");
}
auto model_details_json = CastVectorOfStructsToJson(inputs[0], 1)[0];
LlmReduceOperation::model_details =
ModelManager::CreateModelDetails(CoreModule::GetConnection(), model_details_json);

if (inputs[1].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt details");
}
auto prompt_details_json = CastVectorOfStructsToJson(inputs[1], 1)[0];
reduce_query = CreatePromptDetails(CoreModule::GetConnection(), prompt_details_json).prompt;

if (inputs[2].GetType().id() != LogicalTypeId::STRUCT) {
throw std::runtime_error("Expected a struct type for prompt inputs");
}
auto tuples = CastVectorOfStructsToJson(inputs[2], count);

auto state_map_p = reinterpret_cast<LlmAggState *>(state_p);
Expand All @@ -187,7 +201,7 @@ std::unordered_map<void *, std::shared_ptr<LlmAggState>> LlmReduceOperation::sta

void CoreAggregateFunctions::RegisterLlmReduceFunction(DatabaseInstance &db) {
auto string_concat = AggregateFunction(
"llm_reduce", {LogicalType::VARCHAR, LogicalType::ANY, LogicalType::ANY}, LogicalType::VARCHAR,
"llm_reduce", {LogicalType::ANY, LogicalType::ANY, LogicalType::ANY}, LogicalType::VARCHAR,
AggregateFunction::StateSize<LlmAggState>, LlmReduceOperation::Initialize, LlmReduceOperation::Operation,
LlmReduceOperation::Combine, LlmReduceOperation::Finalize, LlmReduceOperation::SimpleUpdate);

Expand Down
2 changes: 1 addition & 1 deletion src/core/functions/aggregate/llm_rerank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ void LlmAggOperation::RerankerFinalize(Vector &states, AggregateInputData &aggr_

void CoreAggregateFunctions::RegisterLlmRerankFunction(DatabaseInstance &db) {
auto string_concat = AggregateFunction(
"llm_rerank", {LogicalType::VARCHAR, LogicalType::ANY, LogicalType::ANY}, LogicalType::JSON(),
"llm_rerank", {LogicalType::ANY, LogicalType::ANY, LogicalType::ANY}, LogicalType::JSON(),
AggregateFunction::StateSize<LlmAggState>, LlmAggOperation::Initialize, LlmAggOperation::Operation,
LlmAggOperation::Combine, LlmAggOperation::RerankerFinalize, LlmAggOperation::SimpleUpdate);

Expand Down

0 comments on commit 757bbca

Please sign in to comment.