diff --git a/eval/eval.py b/eval/eval.py index 43a3521..785fe6f 100644 --- a/eval/eval.py +++ b/eval/eval.py @@ -168,8 +168,17 @@ def get_all_minimal_queries(query: str) -> "list[str]": left = query[:start] column_str = ", ".join(column_tuple) right = query[end + 1 :] + g_column_str = column_str + len_tuple = len(column_tuple) + # check if the column str contains columns defined with alias AS ... + if " as " in column_str.lower() and "group by {}" in right.lower(): + g_column_str = "" + for i, column in enumerate(column_tuple): + as_index = column.lower().find(" as ") + 4 + g_column_str += column[as_index:] if as_index - 3 else column + g_column_str += ", " if i != len_tuple - 1 else "" # change group by size dynamically if necessary - right = right.replace("GROUP BY {}", f"GROUP BY {column_str}") + right = right.replace("GROUP BY {}", f"GROUP BY {g_column_str}") result_queries.append(left + column_str + right) return result_queries diff --git a/runners/anthropic_runner.py b/runners/anthropic_runner.py index aa9f571..0779bc8 100644 --- a/runners/anthropic_runner.py +++ b/runners/anthropic_runner.py @@ -15,6 +15,7 @@ from utils.llm import chat_anthropic import json + def generate_prompt( prompt_file, question,