Skip to content

Commit

Permalink
fix llm_rerank index accessing and gpt-4o-mini ranked list issue (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
dorbanianas authored Oct 31, 2024
1 parent ec40d84 commit 0a9f966
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 36 deletions.
53 changes: 24 additions & 29 deletions src/core/functions/aggregate/llm_rerank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,15 @@ nlohmann::json LlmReranker::SlidingWindowRerank(nlohmann::json &tuples) {
auto batch_size = 0u;
auto window_tuples = nlohmann::json::array();
auto start_index = num_tuples - 1;
auto reranked_tuples = nlohmann::json::array();
auto half_batch = 0;
auto next_tuples = nlohmann::json::array();

do {
window_tuples.clear();
window_tuples = std::move(next_tuples);
next_tuples.clear();
batch_size = half_batch;
accumulated_rows_tokens = Tiktoken::GetNumTokens(window_tuples.dump());
while (available_tokens - accumulated_rows_tokens > 0 && start_index >= 0) {
auto num_tokens = Tiktoken::GetNumTokens(tuples[start_index].dump());
if (accumulated_rows_tokens + num_tokens > available_tokens) {
Expand All @@ -51,32 +57,24 @@ nlohmann::json LlmReranker::SlidingWindowRerank(nlohmann::json &tuples) {
start_index--;
}

auto ranked_indices = LlmRerankWithSlidingWindow(window_tuples);

auto half_batch = batch_size / 2;
auto next_tuples = nlohmann::json::array();
for (auto i = 0; i < batch_size; i++) {
if (i < half_batch) {
next_tuples.push_back(window_tuples[i]);
} else {
reranked_tuples.push_back(window_tuples[i]);
}
auto indexed_tuples = nlohmann::json::array();
for (auto i = 0; i < window_tuples.size(); i++) {
auto indexed_tuple = nlohmann::json::object();
indexed_tuple["id"] = i;
indexed_tuple["content"] = window_tuples[i];
indexed_tuples.push_back(indexed_tuple);
}

window_tuples.clear();
window_tuples = std::move(next_tuples);
batch_size = half_batch;
accumulated_rows_tokens = Tiktoken::GetNumTokens(window_tuples.dump());
} while (start_index >= 0);

reranked_tuples.insert(reranked_tuples.end(), window_tuples.begin(), window_tuples.end());
auto ranked_indices = LlmRerankWithSlidingWindow(indexed_tuples);

nlohmann::json results;
for (auto i = num_tuples - 1; i >= num_tuples / 2; i--) {
results.push_back(reranked_tuples[i]["content"]);
}
half_batch = batch_size / 2;
next_tuples = nlohmann::json::array();
for (auto i = 0; i < half_batch; i++) {
next_tuples.push_back(window_tuples[ranked_indices[i]]);
}
} while (start_index >= 0);

return results;
return next_tuples;
}

int LlmReranker::CalculateFixedTokens() const {
Expand All @@ -86,7 +84,7 @@ int LlmReranker::CalculateFixedTokens() const {
return num_tokens_meta_and_search_query;
}

nlohmann::json LlmReranker::LlmRerankWithSlidingWindow(const nlohmann::json &tuples) {
vector<int> LlmReranker::LlmRerankWithSlidingWindow(const nlohmann::json &tuples) {
inja::Environment env;
nlohmann::json data;
data["tuples"] = tuples;
Expand All @@ -95,7 +93,7 @@ nlohmann::json LlmReranker::LlmRerankWithSlidingWindow(const nlohmann::json &tup

auto response = ModelManager::CallComplete(prompt, LlmAggOperation::model_details);

return response["ranking"];
return response["ranking"].get<vector<int>>();
};

void LlmAggOperation::RerankerFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count,
Expand All @@ -111,10 +109,7 @@ void LlmAggOperation::RerankerFinalize(Vector &states, AggregateInputData &aggr_

auto tuples_with_ids = nlohmann::json::array();
for (auto i = 0; i < state->value.size(); i++) {
auto tuple_with_id = nlohmann::json::object();
tuple_with_id["id"] = i;
tuple_with_id["content"] = state->value[i];
tuples_with_ids.push_back(tuple_with_id);
tuples_with_ids.push_back(state->value[i]);
}

LlmReranker llm_reranker(LlmAggOperation::model_details.model, Config::default_max_tokens,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class LlmReranker {

int CalculateFixedTokens() const;

nlohmann::json LlmRerankWithSlidingWindow(const nlohmann::json &tuples);
vector<int> LlmRerankWithSlidingWindow(const nlohmann::json &tuples);
};

} // namespace core
Expand Down
14 changes: 8 additions & 6 deletions src/include/templates/llm_rerank_prompt_template.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
#define LLM_RERANK_PROMPT_TEMPLATE_H

constexpr auto llm_rerank_prompt_template = R"(
You are RankLLM, an intelligent assistant that can rank passages based on their relevancy to the query.
You are RankLLM, an intelligent assistant that ranks passages based on their relevance to a given query.
I will provide you with some tuples, each indicated by a numerical identifier []. Rank the passages based on their relevance to the search query:
Below is a set of tuples, each with a unique numerical identifier. Rank all of these tuples based on their relevance to the provided search query, and return a JSON list that includes all tuple IDs in descending order of relevance.
Tuples:
Expand All @@ -14,14 +14,16 @@ I will provide you with some tuples, each indicated by a numerical identifier []
Search Query: {{search_query}}
Rank the tuples above based on their relevance to the search query. All the passages should be included and listed using identifiers, in descending order of relevance. The output format should be in JSON list [id_1, ..., id_n], e.g., [22, 33, ..., 3], Only respond with the ranking results, do not say any word or explain.
Response Format:
Return your response in JSON format **only** as follows, listing all tuple IDs in descending order of relevance:
```json
{
"ranking": [id_1, ..., id_n]
"ranking": [id1, id2, id3, ...]
}
```
No additional text or explanation should be provided; only return the ranking results with all tuple IDs included.
MAKE SURE YOU RETURN THE FULL LIST. EACH ITEM HAS TO BE RETURNED IN THE LIST.
)";

#endif // LLM_RERANK_PROMPT_TEMPLATE_H

0 comments on commit 0a9f966

Please sign in to comment.