From 8d3712d19715aa46a57f32a9452e1f8e772acc10 Mon Sep 17 00:00:00 2001 From: codestory Date: Wed, 29 Jan 2025 18:58:23 +0000 Subject: [PATCH 1/2] feat: sync local changes --- runners/anthropic_runner.py | 367 +++++++++------------- runners/api_runner.py | 598 +++++++++++++++++------------------- runners/base_runner.py | 173 +++++++++++ runners/bedrock_runner.py | 324 +++++++++---------- runners/deepseek_runner.py | 338 ++++++++++---------- runners/gemini_runner.py | 373 ++++++++++------------ runners/hf_runner.py | 440 +++++++++++++------------- runners/llama_cpp_runner.py | 265 ++++++++-------- runners/mistral_runner.py | 439 +++++++++++++------------- runners/mlx_runner.py | 269 ++++++++-------- runners/openai_runner.py | 386 +++++++++-------------- runners/together_runner.py | 339 ++++++++++---------- runners/vllm_runner.py | 412 +++++++++++++------------ 13 files changed, 2331 insertions(+), 2392 deletions(-) create mode 100644 runners/base_runner.py diff --git a/runners/anthropic_runner.py b/runners/anthropic_runner.py index 2081afb..34aa791 100644 --- a/runners/anthropic_runner.py +++ b/runners/anthropic_runner.py @@ -1,241 +1,154 @@ import os -from time import time -from concurrent.futures import ThreadPoolExecutor, as_completed - import pandas as pd -import sqlparse from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor, as_completed -from eval.eval import compare_query_results -from utils.creds import db_creds_all -from utils.dialects import convert_postgres_ddl_to_dialect -from utils.gen_prompt import to_prompt_schema +from utils.llm import chat_anthropic from utils.questions import prepare_questions_df from utils.reporting import upload_results -from utils.llm import chat_anthropic - - -def generate_prompt( - prompt_file, - question, - db_name, - db_type, - instructions="", - k_shot_prompt="", - glossary="", - table_metadata_string="", - prev_invalid_sql="", - prev_error_msg="", - public_data=True, - shuffle=True, -): - if public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - from defog_data_private.metadata import dbs - import defog_data_private.supplementary as sup - - with open(prompt_file, "r") as f: - prompt = f.read() - - if table_metadata_string == "": - md = dbs[db_name]["table_metadata"] - pruned_metadata_ddl = to_prompt_schema(md, shuffle) - pruned_metadata_ddl = convert_postgres_ddl_to_dialect( - postgres_ddl=pruned_metadata_ddl, - to_dialect=db_type, - db_name=db_name, - ) - column_join = sup.columns_join.get(db_name, {}) - join_list = [] - for values in column_join.values(): - if isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - else: - col_1, col_2 = values[0] - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) - else: - join_str = "" - pruned_metadata_str = pruned_metadata_ddl + join_str - else: - pruned_metadata_str = table_metadata_string +from utils.creds import db_creds_all +from eval.eval import compare_query_results +from runners.base_runner import BaseRunner, generate_prompt - prompt = prompt.format( - user_question=question, - db_type=db_type, - instructions=instructions, - table_metadata_string=pruned_metadata_str, - k_shot_prompt=k_shot_prompt, - glossary=glossary, - 
prev_invalid_sql=prev_invalid_sql, - prev_error_msg=prev_error_msg, - ) - return prompt +class AnthropicRunner(BaseRunner): + def _call_llm(self, prompt, model_name, temperature=0.0): + """Call Anthropic API, handling both JSON and text prompts.""" + if isinstance(prompt, list): # JSON prompt format + messages = prompt + else: # Text prompt format + messages = [{"role": "user", "content": prompt}] + return chat_anthropic(messages=messages, model=model_name, temperature=temperature) -def process_row(row, model_name, args): - start_time = time() - prompt = generate_prompt( - prompt_file=args.prompt_file[0], - question=row["question"], - db_name=row["db_name"], - db_type=args.db_type, - instructions=row["instructions"], - k_shot_prompt=row["k_shot_prompt"], - glossary=row["glossary"], - table_metadata_string=row["table_metadata_string"], - prev_invalid_sql=row["prev_invalid_sql"], - prev_error_msg=row["prev_error_msg"], - public_data=not args.use_private_data, - shuffle=args.shuffle_metadata, - ) - messages = [{"role": "user", "content": prompt}] - try: - response = chat_anthropic(messages=messages, model=model_name, temperature=0.0) - generated_query = ( - response.content.split("```sql", 1)[-1].split("```", 1)[0].strip() - ) + def _extract_query(self, response_content): + """Extract SQL query from response.""" try: - generated_query = sqlparse.format( - generated_query, reindent=True, keyword_case="upper" - ) + return response_content.split("```sql", 1)[-1].split("```", 1)[0].strip() except: - pass - return { - "query": generated_query, - "reason": "", - "err": "", - "latency_seconds": time() - start_time, - "tokens_used": response.input_tokens + response.output_tokens, - } - except Exception as e: - return { - "query": "", - "reason": "", - "err": f"GENERATION ERROR: {str(e)}", - "latency_seconds": time() - start_time, - "tokens_used": 0, - } - - -def run_anthropic_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - output_file_list = args.output_file - num_questions = args.num_questions - k_shot = args.k_shot - db_type = args.db_type - cot_table_alias = args.cot_table_alias - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - print(f"Using prompt file {prompt_file}") - print("Preparing questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" - ) - question_query_df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, cot_table_alias - ) - input_rows = question_query_df.to_dict("records") - output_rows = [] - with ThreadPoolExecutor(args.parallel_threads) as executor: - futures = [] - for row in input_rows: - generated_query_fut = executor.submit( - process_row, - row=row, - model_name=args.model, - args=args, + # Fallback to extract anything that looks like SQL + return response_content.split(";")[0].strip() + ";" + + def run_eval(self, args): + """Anthropic-specific evaluation logic.""" + questions_file_list = args.questions_file + prompt_file_list = args.prompt_file + output_file_list = args.output_file + + for questions_file, prompt_file, output_file in zip( + questions_file_list, prompt_file_list, output_file_list + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + print( + f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" + ) + question_query_df = prepare_questions_df( + 
questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + ) + input_rows = question_query_df.to_dict("records") + output_rows = [] + + with ThreadPoolExecutor(args.parallel_threads) as executor: + futures = [] + for row in input_rows: + generated_query_fut = executor.submit( + self.process_row, + row=row, + model_name=args.model, + args=args, + ) + futures.append(generated_query_fut) + + total_tried = 0 + total_correct = 0 + for f in (pbar := tqdm(as_completed(futures), total=len(futures))): + total_tried += 1 + i = futures.index(f) + row = input_rows[i] + result_dict = f.result() + query_gen = result_dict["query"] + reason = result_dict["reason"] + err = result_dict["err"] + # save custom metrics + if "latency_seconds" in result_dict: + row["latency_seconds"] = result_dict["latency_seconds"] + if "tokens_used" in result_dict: + row["tokens_used"] = result_dict["tokens_used"] + row["generated_query"] = query_gen + row["reason"] = reason + row["error_msg"] = err + # save failures into relevant columns in the dataframe + if "GENERATION ERROR" in err: + row["error_query_gen"] = 1 + else: + expected_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + try: + is_correct = compare_query_results( + query_gold=expected_query, + query_gen=query_gen, + db_name=db_name, + db_type=db_type, + question=row["question"], + query_category=row["query_category"], + db_creds=db_creds_all[db_type], + ) + if is_correct: + total_correct += 1 + row["is_correct"] = 1 + row["error_msg"] = "" + else: + row["is_correct"] = 0 + row["error_msg"] = "INCORRECT RESULTS" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"EXECUTION ERROR: {str(e)}" + output_rows.append(row) + pbar.set_description( + f"Accuracy: {round(total_correct/total_tried * 100, 2)}% ({total_correct}/{total_tried})" + ) + + # save results to csv + output_df = pd.DataFrame(output_rows) + output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + if "prompt" in output_df.columns: + del output_df["prompt"] + # get num rows, mean correct, mean error_db_exec for each query_category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("is_correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), ) - futures.append(generated_query_fut) - - total_tried = 0 - total_correct = 0 - for f in (pbar := tqdm(as_completed(futures), total=len(futures))): - total_tried += 1 - i = futures.index(f) - row = input_rows[i] - result_dict = f.result() - query_gen = result_dict["query"] - reason = result_dict["reason"] - err = result_dict["err"] - # save custom metrics - if "latency_seconds" in result_dict: - row["latency_seconds"] = result_dict["latency_seconds"] - if "tokens_used" in result_dict: - row["tokens_used"] = result_dict["tokens_used"] - row["generated_query"] = query_gen - row["reason"] = reason - row["error_msg"] = err - # save failures into relevant columns in the dataframe - if "GENERATION ERROR" in err: - row["error_query_gen"] = 1 - elif "TIMEOUT" in err: - row["timeout"] = 1 - else: - expected_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - try: - is_correct = compare_query_results( - query_gold=expected_query, - query_gen=query_gen, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[db_type], - question=row["question"], - query_category=row["query_category"], - decimal_points=args.decimal_points, - ) - if is_correct: - total_correct += 1 - row["is_correct"] 
= 1 - row["error_msg"] = "" - else: - row["is_correct"] = 0 - row["error_msg"] = "INCORRECT RESULTS" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"EXECUTION ERROR: {str(e)}" - output_rows.append(row) - pbar.set_description( - f"Accuracy: {round(total_correct/total_tried * 100, 2)}% ({total_correct}/{total_tried})" + .reset_index() + ) + print(agg_stats) + # get directory of output_file and create if not exist + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + output_df.to_csv(output_file, index=False, float_format="%.2f") + + # get average rate of correct results + avg_subset = output_df["correct"].sum() / len(output_df) + print(f"Average correct rate: {avg_subset:.2f}") + + results = output_df.to_dict("records") + + # upload results + with open(prompt_file, "r") as f: + prompt = f.read() + if args.upload_url is not None: + upload_results( + results=results, + url=args.upload_url, + runner_type="anthropic", + prompt=prompt, + args=args, ) - # save results to csv - output_df = pd.DataFrame(output_rows) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - output_df.to_csv(output_file, index=False, float_format="%.2f") - - # get average rate of correct results - avg_subset = output_df["is_correct"].sum() / len(output_df) - print(f"Average correct rate: {avg_subset:.2f}") - - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="anthropic", - prompt=prompt, - args=args, - ) +def run_anthropic_eval(args): + runner = AnthropicRunner() + runner.run_eval(args) \ No newline at end of file diff --git a/runners/api_runner.py b/runners/api_runner.py index 6b2afdf..7cf0507 100644 --- a/runners/api_runner.py +++ b/runners/api_runner.py @@ -1,356 +1,322 @@ -import json import os +import json from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Optional -from eval.eval import compare_query_results import pandas as pd +import sqlparse +import re +import requests +from tqdm import tqdm +from time import time + +from runners.base_runner import BaseRunner from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time -import requests +from eval.eval import compare_query_results from utils.reporting import upload_results -import sqlparse -import re def clean_generated_query(query: str): - """ - Clean up the generated query by - - formatting the query using sqlparse - - fixing common problems in LLM-powered query generation with post-processing heuristics - - KNOWN ISSUES: the division fix will only work with Postgres/Redshift/Snowflake/Databricks. It might not work with other databases. - """ - + """Clean up the generated query with post-processing heuristics.""" query = sqlparse.format(query, reindent_aligned=True) - - # if the string `< =` is present, replace it with `<=`. Similarly for `> =` and `>=` query = query.replace("< =", "<=").replace("> =", ">=") - - # if the string ` / NULLIF (` is present, replace it with `/ NULLIF ( 1.0 * `. 
- # This is a fix for ensuring that the denominator is always a float in division operations. query = query.replace("/ NULLIF (", "/ NULLIF (1.0 * ") - - # remove extra spaces around brackets especially for MySQL - query = re.sub(r"\s*\(\s*", "(", query) # Remove spaces before and after '(' - query = re.sub(r"\s*\)", ")", query) # Remove spaces before ')' - + query = re.sub(r"\s*\(\s*", "(", query) + query = re.sub(r"\s*\)", ")", query) return query -def mk_vllm_json( - prompt, num_beams, logprobs=False, sql_lora_path=None, sql_lora_name=None -): - payload = { - "prompt": prompt, - "n": 1, - "use_beam_search": num_beams > 1, - "best_of": num_beams, - "temperature": 0, - "stop": [";", "```"], - "max_tokens": 4000, - "seed": 42, - "sql_lora_path": sql_lora_path, - "sql_lora_name": sql_lora_name, - } - if logprobs: - payload["logprobs"] = 2 - return payload - - -def mk_tgi_json(prompt, num_beams): - # see swagger docs for /generate for the full list of parameters: - # https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/generate - return { - "inputs": prompt, - "parameters": { - "best_of": num_beams, - "do_sample": num_beams > 1, - "return_full_text": False, - "max_new_tokens": 1024, - }, - } +class APIRunner(BaseRunner): + def __init__(self): + super().__init__() + self.api_url = None + self.api_type = None - -def process_row( - row, - api_url: str, - api_type: str, - num_beams: int, - decimal_points: int, - logprobs: bool = False, - sql_lora_path: Optional[str] = None, - sql_lora_name: Optional[str] = None, -): - start_time = time() - if api_type == "tgi": - json_data = mk_tgi_json(row["prompt"], num_beams) - elif api_type == "vllm": - json_data = mk_vllm_json( - row["prompt"], num_beams, logprobs, sql_lora_path, sql_lora_name - ) - else: - # add any custom JSON data here, e.g. 
for a custom API - json_data = { - "prompt": row["prompt"], + def _mk_vllm_json(self, prompt, num_beams, logprobs=False, sql_lora_path=None, sql_lora_name=None): + payload = { + "prompt": prompt, "n": 1, "use_beam_search": num_beams > 1, "best_of": num_beams, "temperature": 0, "stop": [";", "```"], "max_tokens": 4000, + "seed": 42, + "sql_lora_path": sql_lora_path, + "sql_lora_name": sql_lora_name, } - try: - r = requests.post( - api_url, - json=json_data, - timeout=200, - ) - except: - row["generated_query"] = "" - row["exact_match"] = 0 - row["correct"] = 0 - row["error_db_exec"] = 1 - row["error_msg"] = "API TIMEOUT" - row["tokens_used"] = None if logprobs: - row["logprobs"] = [] + payload["logprobs"] = 2 + return payload + + def _mk_tgi_json(self, prompt, num_beams): + return { + "inputs": prompt, + "parameters": { + "best_of": num_beams, + "do_sample": num_beams > 1, + "return_full_text": False, + "max_new_tokens": 1024, + }, + } - return row - end_time = time() - logprobs = [] - if api_type == "tgi": - # we do not return the original prompt in tgi - try: - generated_query = r.json()["generated_text"] - except KeyError: - print(r.json()) - generated_query = "" - elif "[SQL]" not in row["prompt"]: - generated_query = ( - r.json()["text"][0] - .split("```sql")[-1] - .split("```")[0] - .split(";")[0] - .strip() - + ";" - ) - else: - generated_query = r.json()["text"][0] - if "[SQL]" in generated_query: - generated_query = generated_query.split("[SQL]", 1)[1].strip() + def _call_llm(self, prompt, model_name, temperature=0.0): + """Call API endpoint.""" + # model_name is unused but kept for BaseRunner compatibility + start_time = time() + json_data = None + + if self.api_type == "tgi": + json_data = self._mk_tgi_json(prompt, self.num_beams) + elif self.api_type == "vllm": + json_data = self._mk_vllm_json( + prompt, self.num_beams, self.logprobs, self.sql_lora_path, self.sql_lora_name + ) else: - generated_query = generated_query.strip() - - # clean up the generated query - generated_query = clean_generated_query(generated_query) - - if "logprobs" in r.json(): - logprobs = r.json()["logprobs"] - - row["generated_query"] = generated_query - logprobs_display = [] - for item in logprobs: - probs = list(item.values()) - probs_to_append = {} - for prob in probs: - rank = prob["rank"] - logprob = prob["logprob"] - token = prob["decoded_token"] - probs_to_append.update( - { - f"rank_{rank}_token": token, - f"rank_{rank}_logprob": logprob, - f"rank_{rank}_prob": 10**logprob, - } + json_data = { + "prompt": prompt, + "n": 1, + "use_beam_search": self.num_beams > 1, + "best_of": self.num_beams, + "temperature": 0, + "stop": [";", "```"], + "max_tokens": 4000, + } + + try: + response = requests.post( + self.api_url, + json=json_data, + timeout=200, ) - - probs_to_append["prob_diff"] = ( - probs_to_append["rank_1_prob"] - probs_to_append["rank_2_prob"] - ) - logprobs_display.append(probs_to_append) - row["logprobs"] = logprobs_display - row["latency_seconds"] = end_time - start_time - row["tokens_used"] = None - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all.get(row["db_type"], {}), - question=question, - query_category=query_category, - 
table_metadata_string=table_metadata_string, - decimal_points=decimal_points, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - - return row - - -def run_api_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - num_questions = args.num_questions - public_data = not args.use_private_data - api_url = args.api_url - api_type = args.api_type - output_file_list = args.output_file - k_shot = args.k_shot - num_beams = args.num_beams - max_workers = args.parallel_threads - db_type = args.db_type - decimal_points = args.decimal_points - logprobs = args.logprobs - cot_table_alias = args.cot_table_alias - sql_lora_path = args.adapter if args.adapter else None - sql_lora_name = args.adapter_name if args.adapter_name else None - run_name = args.run_name if args.run_name else None - if sql_lora_path: - print("Using LoRA adapter at:", sql_lora_path) - if logprobs: - # check that the eval-visualizer/public directory exists - if not os.path.exists("./eval-visualizer"): - # thorow error - raise Exception( - "The eval-visualizer directory does not exist. Please clone it with `git clone https://github.com/defog-ai/eval-visualizer/` before running sql-eval with the --logprobs flag." + response.raise_for_status() + return response.json() + except Exception as e: + raise Exception(f"API ERROR: {str(e)}") + finally: + self.request_time = time() - start_time + + def _extract_query(self, response_json, prompt): + """Extract SQL query from API response.""" + if self.api_type == "tgi": + try: + return response_json["generated_text"] + except KeyError: + print(response_json) + return "" + elif "[SQL]" not in prompt: + return ( + response_json["text"][0] + .split("```sql")[-1] + .split("```")[0] + .split(";")[0] + .strip() + + ";" + ) + else: + generated_text = response_json["text"][0] + if "[SQL]" in generated_text: + generated_text = generated_text.split("[SQL]", 1)[1].strip() + return generated_text.strip() + + def process_row(self, row, model_name, args): + """API-specific row processing.""" + # Set API-specific attributes + self.api_url = args.api_url + self.api_type = args.api_type + self.num_beams = args.num_beams + self.logprobs = args.logprobs + self.sql_lora_path = args.adapter + self.sql_lora_name = args.adapter_name + + try: + prompt = generate_prompt( + prompt_file=args.prompt_file[0], + question=row["question"], + db_name=row["db_name"], + db_type=args.db_type, + instructions=row["instructions"], + k_shot_prompt=row["k_shot_prompt"], + glossary=row["glossary"], + table_metadata_string=row["table_metadata_string"], + prev_invalid_sql=row["prev_invalid_sql"], + prev_error_msg=row["prev_error_msg"], + question_0=row.get("question_0", ""), + query_0=row.get("query_0", ""), + question_1=row.get("question_1", ""), + query_1=row.get("query_1", ""), + cot_instructions=row.get("cot_instructions", ""), + cot_pregen=row.get("cot_pregen", False), + public_data=not args.use_private_data, + columns_to_keep=args.num_columns, + shuffle_metadata=args.shuffle_metadata, + table_aliases=row.get("table_aliases", ""), ) - if not os.path.exists("./eval-visualizer/public"): - os.makedirs("./eval-visualizer/public") - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - print(f"Using prompt file {prompt_file}") - # get questions - 
print("Preparing questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" - ) - df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, cot_table_alias - ) - # create a prompt for each question - df["prompt"] = df.apply( - lambda row: generate_prompt( - prompt_file, - row["question"], - row["db_name"], - row["db_type"], - row["instructions"], - row["k_shot_prompt"], - row["glossary"], - row["table_metadata_string"], - row["prev_invalid_sql"], - row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], - public_data, - args.num_columns, - args.shuffle_metadata, - row["table_aliases"], - ), - axis=1, - ) - - total_tried = 0 - total_correct = 0 - output_rows = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append( - executor.submit( - process_row, - row, - api_url, - api_type, - num_beams, - decimal_points, - logprobs, - sql_lora_path, - sql_lora_name, + response_json = self._call_llm(prompt, None) # model_name unused for API + generated_query = self._extract_query(response_json, prompt) + generated_query = clean_generated_query(generated_query) + + # Handle logprobs if present + logprobs_display = [] + if "logprobs" in response_json: + for item in response_json["logprobs"]: + probs = list(item.values()) + probs_to_append = {} + for prob in probs: + rank = prob["rank"] + logprob = prob["logprob"] + token = prob["decoded_token"] + probs_to_append.update({ + f"rank_{rank}_token": token, + f"rank_{rank}_logprob": logprob, + f"rank_{rank}_prob": 10**logprob, + }) + probs_to_append["prob_diff"] = ( + probs_to_append["rank_1_prob"] - probs_to_append["rank_2_prob"] ) + logprobs_display.append(probs_to_append) + + # Prepare result + result = { + "generated_query": generated_query, + "latency_seconds": self.request_time, + "tokens_used": None, # API doesn't provide token counts + "logprobs": logprobs_display if self.logprobs else None + } + + # Run comparison + try: + exact_match, correct = compare_query_results( + query_gold=row["query"], + query_gen=generated_query, + db_name=row["db_name"], + db_type=row["db_type"], + db_creds=db_creds_all.get(row["db_type"], {}), + question=row["question"], + query_category=row["query_category"], + table_metadata_string=row["table_metadata_string"], + decimal_points=args.decimal_points, ) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) - - output_df = pd.DataFrame(output_rows) - - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - results = output_df.to_dict("records") - - if logprobs: + result["exact_match"] = int(exact_match) + result["correct"] = int(correct) + result["error_msg"] = "" + except Exception as e: + result["error_db_exec"] = 1 + result["error_msg"] = f"QUERY EXECUTION ERROR: {str(e)}" + + return {**row, **result} + + except Exception as e: + return { 
+ **row, + "generated_query": "", + "exact_match": 0, + "correct": 0, + "error_db_exec": 1, + "error_msg": f"API ERROR: {str(e)}", + "tokens_used": None, + "latency_seconds": self.request_time if hasattr(self, 'request_time') else None, + "logprobs": [] if self.logprobs else None + } + + def run_eval(self, args): + """API-specific evaluation logic.""" + # Validate API requirements + if not args.api_url: + raise ValueError("API URL must be provided for API runner") + if not args.api_type or args.api_type not in ["vllm", "tgi"]: + raise ValueError("API type must be one of 'vllm', 'tgi'") + + # Set up logprobs if needed + if args.logprobs: + if not os.path.exists("./eval-visualizer"): + raise Exception( + "The eval-visualizer directory does not exist. Please clone it with " + "`git clone https://github.com/defog-ai/eval-visualizer/` before running " + "sql-eval with the --logprobs flag." + ) + if not os.path.exists("./eval-visualizer/public"): + os.makedirs("./eval-visualizer/public") + + for questions_file, prompt_file, output_file in zip( + args.questions_file, args.prompt_file, args.output_file + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") print( - f"Writing logprobs to JSON file at eval-visualizer/public/{output_file.split('/')[-1].replace('.csv', '.json')}" + f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" + ) + df = prepare_questions_df( + questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias ) - with open( - f"./eval-visualizer/public/{output_file.split('/')[-1].replace('.csv', '.json')}", - "w", - ) as f: - json.dump(results, f) - del output_df["prompt"] - try: - output_df.to_csv(output_file, index=False, float_format="%.2f") - except: - output_df.to_pickle(output_file) + # Process all rows + output_rows = [] + total_tried = total_correct = 0 + + with ThreadPoolExecutor(args.parallel_threads) as executor: + futures = [] + for row in df.to_dict("records"): + futures.append( + executor.submit(self.process_row, row, None, args) + ) - # upload results - # with open(prompt_file, "r") as f: - # prompt = f.read() + with tqdm(as_completed(futures), total=len(futures)) as pbar: + for f in pbar: + row = f.result() + output_rows.append(row) + if row.get("correct", 0): + total_correct += 1 + total_tried += 1 + pbar.set_description( + f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" + ) + + # Save results + output_df = pd.DataFrame(output_rows) + + # Handle logprobs if needed + if args.logprobs: + print( + f"Writing logprobs to JSON file at eval-visualizer/public/{output_file.split('/')[-1].replace('.csv', '.json')}" + ) + with open( + f"./eval-visualizer/public/{output_file.split('/')[-1].replace('.csv', '.json')}", + "w", + ) as f: + json.dump(output_rows, f) + + if "prompt" in output_df.columns: + del output_df["prompt"] + + print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) + output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + + # Save to file + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + try: + output_df.to_csv(output_file, index=False, float_format="%.2f") + except: + output_df.to_pickle(output_file) + print(f"Saved results to {output_file}") + + # Upload results if needed + if args.upload_url is not None: + run_name = args.run_name or output_file.split("/")[-1].replace(".csv", "") + upload_results( + 
results=output_rows, + url=args.upload_url, + runner_type="api_runner", + args=args, + run_name=run_name, + ) - if args.run_name is None: - run_name = output_file.split("/")[-1].replace(".csv", "") - print( - "Run name not provided. Using a output filename for run name:", run_name - ) - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="api_runner", - args=args, - run_name=run_name, - ) +def run_api_eval(args): + runner = APIRunner() + runner.run_eval(args) \ No newline at end of file diff --git a/runners/base_runner.py b/runners/base_runner.py new file mode 100644 index 0000000..6315725 --- /dev/null +++ b/runners/base_runner.py @@ -0,0 +1,173 @@ +import os +from time import time +import pandas as pd +import sqlparse +from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor, as_completed + +from eval.eval import compare_query_results +from utils.creds import db_creds_all +from utils.dialects import convert_postgres_ddl_to_dialect +from utils.gen_prompt import to_prompt_schema +from utils.questions import prepare_questions_df +from utils.reporting import upload_results + + +def generate_prompt( + prompt_file, + question, + db_name, + db_type, + instructions="", + k_shot_prompt="", + glossary="", + table_metadata_string="", + prev_invalid_sql="", + prev_error_msg="", + public_data=True, + shuffle=True, +): + if public_data: + from defog_data.metadata import dbs + import defog_data.supplementary as sup + else: + from defog_data_private.metadata import dbs + import defog_data_private.supplementary as sup + + if table_metadata_string == "": + md = dbs[db_name]["table_metadata"] + pruned_metadata_ddl = to_prompt_schema(md, shuffle) + pruned_metadata_ddl = convert_postgres_ddl_to_dialect( + postgres_ddl=pruned_metadata_ddl, + to_dialect=db_type, + db_name=db_name, + ) + column_join = sup.columns_join.get(db_name, {}) + join_list = [] + for values in column_join.values(): + if isinstance(values[0], tuple): + for col_pair in values: + col_1, col_2 = col_pair + join_str = f"{col_1} can be joined with {col_2}" + if join_str not in join_list: + join_list.append(join_str) + else: + col_1, col_2 = values[0] + join_str = f"{col_1} can be joined with {col_2}" + if join_str not in join_list: + join_list.append(join_str) + if len(join_list) > 0: + join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) + else: + join_str = "" + pruned_metadata_str = pruned_metadata_ddl + join_str + else: + pruned_metadata_str = table_metadata_string + + return pruned_metadata_str + + +def format_sql_query(generated_query): + try: + return sqlparse.format( + generated_query, reindent=True, keyword_case="upper" + ) + except: + return generated_query + + +class BaseRunner: + def __init__(self): + pass + + def _load_prompt(self, prompt_file): + """Load prompt from file. Override in subclass if format differs.""" + with open(prompt_file, "r") as f: + return f.read() + + def _format_prompt(self, prompt, **kwargs): + """Format the prompt with variables. Override in subclass if format differs.""" + return prompt.format(**kwargs) + + def _call_llm(self, messages, model_name, temperature=0.0): + """Call LLM API. Must be implemented in subclass.""" + raise NotImplementedError("Subclass must implement _call_llm") + + def _extract_query(self, response_content): + """Extract SQL query from response. 
Override in subclass if format differs.""" + return response_content.split("```sql", 1)[-1].split("```", 1)[0].strip() + + def process_row(self, row, model_name, args): + start_time = time() + try: + prompt = generate_prompt( + prompt_file=args.prompt_file[0], + question=row["question"], + db_name=row["db_name"], + db_type=args.db_type, + instructions=row["instructions"], + k_shot_prompt=row["k_shot_prompt"], + glossary=row["glossary"], + table_metadata_string=row["table_metadata_string"], + prev_invalid_sql=row["prev_invalid_sql"], + prev_error_msg=row["prev_error_msg"], + public_data=not args.use_private_data, + shuffle_metadata=args.shuffle_metadata, + ) + + response = self._call_llm(prompt, model_name) + generated_query = self._extract_query(response.content) + generated_query = format_sql_query(generated_query) + + return { + "query": generated_query, + "reason": "", + "err": "", + "latency_seconds": time() - start_time, + "tokens_used": response.input_tokens + response.output_tokens, + } + except Exception as e: + return { + "query": "", + "reason": "", + "err": f"GENERATION ERROR: {str(e)}", + "latency_seconds": time() - start_time, + "tokens_used": 0, + } + + def run_eval(self, args): + """Common evaluation logic.""" + questions_file_list = args.questions_file + prompt_file_list = args.prompt_file + output_file_list = args.output_file + model_name = args.model + + for questions_file, prompt_file, output_file in zip( + questions_file_list, prompt_file_list, output_file_list + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + question_query_df = prepare_questions_df( + questions_file, + args.db_type, + args.num_questions, + args.k_shot, + args.cot_table_alias + ) + input_rows = question_query_df.to_dict("records") + + results = [] + with ThreadPoolExecutor(max_workers=args.parallel_threads) as executor: + futures = [ + executor.submit(self.process_row, row, model_name, args) + for row in input_rows + ] + for future in tqdm( + as_completed(futures), total=len(futures), desc="Processing" + ): + result = future.result() + results.append(result) + + results_df = pd.DataFrame(results) + results_df.to_csv(output_file, index=False) + print(f"Results saved to {output_file}") \ No newline at end of file diff --git a/runners/bedrock_runner.py b/runners/bedrock_runner.py index 806402a..0c7a0cf 100644 --- a/runners/bedrock_runner.py +++ b/runners/bedrock_runner.py @@ -1,177 +1,189 @@ -import boto3 import json import os +import pandas as pd +from tqdm import tqdm +from time import time from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Optional +import boto3 -from eval.eval import compare_query_results -import pandas as pd +from runners.base_runner import BaseRunner from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time +from eval.eval import compare_query_results from utils.reporting import upload_results -bedrock = boto3.client(service_name="bedrock-runtime") - -def process_row(row, model_id, decimal_points): - start_time = time() +class BedrockRunner(BaseRunner): + def __init__(self): + super().__init__() + self.client = boto3.client(service_name="bedrock-runtime") - body = json.dumps( - { - "prompt": row["prompt"], + def _call_llm(self, prompt, model_name, temperature=0.0): + """Call AWS Bedrock API.""" + body = json.dumps({ + "prompt": prompt, "max_gen_len": 600, "temperature": 0, "top_p": 1, - } - ) - - 
accept = "application/json" - contentType = "application/json" - response = bedrock.invoke_model( - body=body, modelId=model_id, accept=accept, contentType=contentType - ) - model_response = json.loads(response["body"].read()) - - generated_query = model_response["generation"] - end_time = time() - - generated_query = ( - generated_query.split("```sql")[-1].split("```")[0].split(";")[0].strip() + ";" - ) - - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - row["tokens_used"] = None - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=decimal_points, + }) + + accept = "application/json" + contentType = "application/json" + response = self.client.invoke_model( + body=body, modelId=model_name, accept=accept, contentType=contentType ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - - return row - + model_response = json.loads(response["body"].read()) + + # Create a response object similar to other runners + class BedrockResponse: + def __init__(self, content): + self.content = content + self.input_tokens = 0 # Bedrock doesn't provide token counts this way + self.output_tokens = 0 + + return BedrockResponse(model_response["generation"]) + + def _extract_query(self, response_content): + """Extract SQL query from response.""" + return response_content.split("```sql")[-1].split("```")[0].split(";")[0].strip() + ";" + + def process_row(self, row, model_name, args): + """Override process_row to use simple handling.""" + start_time = time() + try: + prompt = generate_prompt( + prompt_file=args.prompt_file[0], + question=row["question"], + db_name=row["db_name"], + db_type=args.db_type, + instructions=row["instructions"], + k_shot_prompt=row["k_shot_prompt"], + glossary=row["glossary"], + table_metadata_string=row["table_metadata_string"], + prev_invalid_sql=row["prev_invalid_sql"], + prev_error_msg=row["prev_error_msg"], + question_0=row.get("question_0", ""), + query_0=row.get("query_0", ""), + question_1=row.get("question_1", ""), + query_1=row.get("query_1", ""), + cot_instructions=row.get("cot_instructions", ""), + cot_pregen=row.get("cot_pregen", False), + public_data=not args.use_private_data, + columns_to_keep=args.num_columns, + shuffle_metadata=args.shuffle_metadata, + ) -def run_bedrock_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - num_questions = args.num_questions - public_data = not args.use_private_data - output_file_list = args.output_file - k_shot = args.k_shot - max_workers = args.parallel_threads - db_type = args.db_type - decimal_points = args.decimal_points - model_id = args.model - cot_table_alias = args.cot_table_alias - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - print(f"Using prompt file {prompt_file}") - # get questions - print("Preparing 
questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" - ) - df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, cot_table_alias - ) - # create a prompt for each question - df["prompt"] = df.apply( - lambda row: generate_prompt( - prompt_file, - row["question"], - row["db_name"], - row["db_type"], - row["instructions"], - row["k_shot_prompt"], - row["glossary"], - row["table_metadata_string"], - row["prev_invalid_sql"], - row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], - public_data, - args.num_columns, - args.shuffle_metadata, - ), - axis=1, - ) + response = self._call_llm(prompt, model_name) + generated_query = self._extract_query(response.content) + + result = { + "generated_query": generated_query, + "latency_seconds": time() - start_time, + "tokens_used": None # Bedrock doesn't provide token counts + } + + # Run comparison + try: + exact_match, correct = compare_query_results( + query_gold=row["query"], + query_gen=generated_query, + db_name=row["db_name"], + db_type=row["db_type"], + db_creds=db_creds_all[row["db_type"]], + question=row["question"], + query_category=row["query_category"], + table_metadata_string=row["table_metadata_string"], + decimal_points=args.decimal_points, + ) + result["exact_match"] = int(exact_match) + result["correct"] = int(correct) + result["error_msg"] = "" + except Exception as e: + result["error_db_exec"] = 1 + result["error_msg"] = f"QUERY EXECUTION ERROR: {str(e)}" + + return {**row, **result} + + except Exception as e: + return { + **row, + "generated_query": "", + "exact_match": 0, + "correct": 0, + "error_db_exec": 1, + "error_msg": f"PROCESSING ERROR: {str(e)}", + "latency_seconds": time() - start_time, + "tokens_used": None + } + + def run_eval(self, args): + """Bedrock-specific evaluation logic.""" + for questions_file, prompt_file, output_file in zip( + args.questions_file, args.prompt_file, args.output_file + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + print( + f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" + ) + df = prepare_questions_df( + questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + ) - total_tried = 0 - total_correct = 0 - output_rows = [] + # Process all rows + output_rows = [] + total_tried = total_correct = 0 + + with ThreadPoolExecutor(args.parallel_threads) as executor: + futures = [] + for row in df.to_dict("records"): + futures.append( + executor.submit(self.process_row, row, args.model, args) + ) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append( - executor.submit(process_row, row, model_id, decimal_points) + with tqdm(as_completed(futures), total=len(futures)) as pbar: + for f in pbar: + row = f.result() + output_rows.append(row) + if row.get("correct", 0): + total_correct += 1 + total_tried += 1 + pbar.set_description( + f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" + ) + + # Save results + output_df = pd.DataFrame(output_rows) + if "prompt" in output_df.columns: + del output_df["prompt"] + + print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) + output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + + # Save to file + output_dir = 
os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + try: + output_df.to_csv(output_file, index=False, float_format="%.2f") + except: + output_df.to_pickle(output_file) + print(f"Saved results to {output_file}") + + # Upload results if needed + if args.upload_url is not None: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="bedrock_runner", + prompt=prompt, + args=args, ) - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) - output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - try: - output_df.to_csv(output_file, index=False, float_format="%.2f") - except: - output_df.to_pickle(output_file) - - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="api_runner", - prompt=prompt, - args=args, - ) +def run_bedrock_eval(args): + runner = BedrockRunner() + runner.run_eval(args) \ No newline at end of file diff --git a/runners/deepseek_runner.py b/runners/deepseek_runner.py index 323c0c1..e610be5 100644 --- a/runners/deepseek_runner.py +++ b/runners/deepseek_runner.py @@ -1,174 +1,196 @@ import os +import pandas as pd +from tqdm import tqdm +from time import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict +from openai import OpenAI -from eval.eval import compare_query_results -import pandas as pd +from runners.base_runner import BaseRunner from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time -from openai import OpenAI +from eval.eval import compare_query_results from utils.reporting import upload_results -client = OpenAI( - base_url="https://api.deepseek.com", api_key=os.environ.get("DEEPSEEK_API_KEY") -) - - -def process_row(row: Dict, model: str): - start_time = time() - messages = row["prompt"] - if model != "deepseek-reasoner": - response = client.chat.completions.create( - model=model, - messages=messages, - max_tokens=800, - temperature=0.0, +class DeepseekRunner(BaseRunner): + def __init__(self): + super().__init__() + self.client = OpenAI( + base_url="https://api.deepseek.com", + api_key=os.environ.get("DEEPSEEK_API_KEY") ) - else: - response = client.chat.completions.create( - model=model, - messages=messages, - max_tokens=800, - ) - content = response.choices[0].message.content - generated_query = content.replace("```sql", "").replace("```", "").strip() - end_time = time() - - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - row["tokens_used"] = None - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - 
query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - - return row + def _call_llm(self, messages, model_name, temperature=0.0): + """Call Deepseek API.""" + if model_name != "deepseek-reasoner": + response = self.client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=800, + temperature=temperature, + ) + else: + response = self.client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=800, + ) + + # Create a response object similar to other runners + class DeepseekResponse: + def __init__(self, content): + self.content = content + self.input_tokens = 0 # Deepseek doesn't provide token counts this way + self.output_tokens = 0 + + return DeepseekResponse(response.choices[0].message.content) + + def _extract_query(self, response_content): + """Extract SQL query from response with Deepseek-specific handling.""" + return response_content.replace("```sql", "").replace("```", "").strip() + + def process_row(self, row: Dict, model_name, args): + """Process a row using Deepseek.""" + start_time = time() + try: + # Deepseek uses OpenAI chat format + messages = generate_prompt( + prompt_file=args.prompt_file[0], + question=row["question"], + db_name=row["db_name"], + db_type=args.db_type, + instructions=row["instructions"], + k_shot_prompt=row["k_shot_prompt"], + glossary=row["glossary"], + table_metadata_string=row["table_metadata_string"], + prev_invalid_sql=row["prev_invalid_sql"], + prev_error_msg=row["prev_error_msg"], + question_0=row.get("question_0", ""), + query_0=row.get("query_0", ""), + question_1=row.get("question_1", ""), + query_1=row.get("query_1", ""), + cot_instructions=row.get("cot_instructions", ""), + cot_pregen=row.get("cot_pregen", False), + public_data=not args.use_private_data, + columns_to_keep=args.num_columns, + shuffle_metadata=args.shuffle_metadata, + table_aliases=row.get("table_aliases", ""), + ) -def run_deepseek_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - num_questions = args.num_questions - public_data = not args.use_private_data - output_file_list = args.output_file - k_shot = args.k_shot - max_workers = args.parallel_threads - db_type = args.db_type - decimal_points = args.decimal_points - model = args.model - cot_table_alias = args.cot_table_alias - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - if not prompt_file.endswith(".json"): - raise ValueError(f"Prompt file must be a JSON file. 
Got {prompt_file}") - print(f"Using prompt file {prompt_file}") - # get questions - print("Preparing questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" - ) - df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, cot_table_alias - ) - # create a prompt for each question - # note that the prompt for together ai uses the openai chat API - df["prompt"] = df.apply( - lambda row: generate_prompt( - prompt_file, - row["question"], - row["db_name"], - row["db_type"], - row["instructions"], - row["k_shot_prompt"], - row["glossary"], - row["table_metadata_string"], - row["prev_invalid_sql"], - row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], - public_data, - args.num_columns, - args.shuffle_metadata, - row["table_aliases"], - ), - axis=1, - ) + response = self._call_llm(messages, model_name) + generated_query = self._extract_query(response.content) + + result = { + "generated_query": generated_query, + "latency_seconds": time() - start_time, + "tokens_used": None # Deepseek doesn't provide token counts + } + + # Run comparison + try: + exact_match, correct = compare_query_results( + query_gold=row["query"], + query_gen=generated_query, + db_name=row["db_name"], + db_type=row["db_type"], + db_creds=db_creds_all[row["db_type"]], + question=row["question"], + query_category=row["query_category"], + table_metadata_string=row["table_metadata_string"], + ) + result["exact_match"] = int(exact_match) + result["correct"] = int(correct) + result["error_msg"] = "" + except Exception as e: + result["error_db_exec"] = 1 + result["error_msg"] = f"QUERY EXECUTION ERROR: {str(e)}" + + return {**row, **result} + + except Exception as e: + return { + **row, + "generated_query": "", + "exact_match": 0, + "correct": 0, + "error_db_exec": 1, + "error_msg": f"PROCESSING ERROR: {str(e)}", + "latency_seconds": time() - start_time, + "tokens_used": None + } + + def run_eval(self, args): + """Deepseek-specific evaluation logic.""" + if not args.prompt_file[0].endswith(".json"): + raise ValueError(f"Prompt file must be a JSON file. 
Got {args.prompt_file[0]}") + + for questions_file, prompt_file, output_file in zip( + args.questions_file, args.prompt_file, args.output_file + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + print( + f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" + ) + df = prepare_questions_df( + questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append(executor.submit(process_row, row, model)) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" + # Process all rows + output_rows = [] + total_tried = total_correct = 0 + + with ThreadPoolExecutor(args.parallel_threads) as executor: + futures = [] + for row in df.to_dict("records"): + futures.append( + executor.submit(self.process_row, row, args.model, args) ) - output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - try: - output_df.to_csv(output_file, index=False, float_format="%.2f") - except: - output_df.to_pickle(output_file) - - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="api_runner", - prompt=prompt, - args=args, - ) + with tqdm(as_completed(futures), total=len(futures)) as pbar: + for f in pbar: + row = f.result() + output_rows.append(row) + if row.get("correct", 0): + total_correct += 1 + total_tried += 1 + pbar.set_description( + f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" + ) + + # Save results + output_df = pd.DataFrame(output_rows) + if "prompt" in output_df.columns: + del output_df["prompt"] + + print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) + output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + + # Save to file + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + try: + output_df.to_csv(output_file, index=False, float_format="%.2f") + except: + output_df.to_pickle(output_file) + print(f"Saved results to {output_file}") + + # Upload results if needed + if args.upload_url is not None: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="deepseek_runner", + prompt=prompt, + args=args, + ) + + +def run_deepseek_eval(args): + runner = DeepseekRunner() + runner.run_eval(args) \ No newline at end of file diff --git a/runners/gemini_runner.py b/runners/gemini_runner.py index cd292c1..3ba16e9 100644 --- a/runners/gemini_runner.py +++ b/runners/gemini_runner.py @@ -1,230 +1,189 @@ import os -from time import time 
-from concurrent.futures import ThreadPoolExecutor, as_completed - import pandas as pd -import sqlparse from tqdm import tqdm +from time import time +import sqlparse +from concurrent.futures import ThreadPoolExecutor, as_completed -from eval.eval import compare_query_results -from utils.creds import db_creds_all -from utils.dialects import convert_postgres_ddl_to_dialect -from utils.gen_prompt import to_prompt_schema +from utils.llm import chat_gemini from utils.questions import prepare_questions_df from utils.reporting import upload_results -from utils.llm import chat_gemini +from utils.creds import db_creds_all +from eval.eval import compare_query_results +from runners.base_runner import BaseRunner, generate_prompt + +class GeminiRunner(BaseRunner): + def _call_llm(self, prompt, model_name, temperature=0.0): + """Call Gemini API, handling both JSON and text prompts.""" + if isinstance(prompt, list): # JSON prompt format + messages = prompt + else: # Text prompt format + messages = [{"role": "user", "content": prompt}] + return chat_gemini(messages=messages, model=model_name, temperature=temperature) -def generate_prompt( - prompt_file, - question, - db_name, - db_type, - instructions="", - k_shot_prompt="", - glossary="", - table_metadata_string="", - prev_invalid_sql="", - prev_error_msg="", - public_data=True, - shuffle=True, -): - if public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - # raise Exception("Replace this with your private data import") - from defog_data_private.metadata import dbs - import defog_data_private.supplementary as sup - - with open(prompt_file, "r") as f: - prompt = f.read() - - if table_metadata_string == "": - md = dbs[db_name]["table_metadata"] - pruned_metadata_ddl = to_prompt_schema(md, shuffle) - pruned_metadata_ddl = convert_postgres_ddl_to_dialect( - postgres_ddl=pruned_metadata_ddl, - to_dialect=db_type, - db_name=db_name, - ) - column_join = sup.columns_join.get(db_name, {}) - # get join_str from column_join - join_list = [] - for values in column_join.values(): - if isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - else: - col_1, col_2 = values[0] - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) - else: - join_str = "" - pruned_metadata_str = pruned_metadata_ddl + join_str - else: - pruned_metadata_str = table_metadata_string - - prompt = prompt.format( - user_question=question, - db_type=db_type, - instructions=instructions, - table_metadata_string=pruned_metadata_str, - k_shot_prompt=k_shot_prompt, - glossary=glossary, - prev_invalid_sql=prev_invalid_sql, - prev_error_msg=prev_error_msg, - ) - return prompt - - -def process_row(row, model_name, args): - start_time = time() - messages = [{"role": "user", "content": row["prompt"]}] - try: - response = chat_gemini(messages=messages, model=model_name, temperature=0.0) - generated_query = ( - response.content.split("```sql", 1)[-1].split("```", 1)[0].strip() - ) + def _extract_query(self, response_content): + """Extract SQL query from response.""" try: - generated_query = sqlparse.format( - generated_query, - strip_comments=True, - strip_whitespace=True, - keyword_case="upper", - ) + return 
response_content.split("```sql", 1)[-1].split("```", 1)[0].strip() except: - pass - row["generated_query"] = generated_query - row["latency_seconds"] = response.time - row["tokens_used"] = response.input_tokens + response.output_tokens - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"GENERATION ERROR: {e}" - return row - - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - exact_match = correct = 0 - - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[db_type], - question=question, - query_category=query_category, - decimal_points=args.decimal_points, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - - return row - + # Fallback to extract anything that looks like SQL + try: + return sqlparse.format( + response_content.split(";")[0].strip() + ";", + strip_comments=True, + strip_whitespace=True, + keyword_case="upper", + ) + except: + return response_content.split(";")[0].strip() + ";" -def run_gemini_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - num_questions = args.num_questions - public_data = not args.use_private_data - model_name = args.model - output_file_list = args.output_file - k_shot = args.k_shot - max_workers = args.parallel_threads - db_type = args.db_type - cot_table_alias = args.cot_table_alias - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - print(f"Using prompt file {prompt_file}") - print("Preparing questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" - ) - df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, cot_table_alias - ) - - df["prompt"] = df.apply( - lambda row: generate_prompt( - prompt_file, - row["question"], - row["db_name"], - row["db_type"], - row["instructions"], - row["k_shot_prompt"], - row["glossary"], - row["table_metadata_string"], - row["prev_invalid_sql"], - row["prev_error_msg"], - public_data, - args.shuffle_metadata, - ), - axis=1, - ) - - total_tried = 0 - total_correct = 0 - output_rows = [] - - print(f"Running evaluation using {model_name}...") - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append(executor.submit(process_row, row, model_name, args)) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row.get("correct", 0): - total_correct += 1 - total_tried += 1 - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) - - output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) + def process_row(self, row, model_name, args): + """Gemini-specific row processing logic.""" + start_time = time() try: + prompt = 
generate_prompt( + prompt_file=args.prompt_file[0], + question=row["question"], + db_name=row["db_name"], + db_type=args.db_type, + instructions=row["instructions"], + k_shot_prompt=row["k_shot_prompt"], + glossary=row["glossary"], + table_metadata_string=row["table_metadata_string"], + prev_invalid_sql=row["prev_invalid_sql"], + prev_error_msg=row["prev_error_msg"], + public_data=not args.use_private_data, + shuffle_metadata=args.shuffle_metadata, + ) + + response = self._call_llm(prompt, model_name) + generated_query = self._extract_query(response.content) + + row["generated_query"] = generated_query + row["latency_seconds"] = time() - start_time + row["tokens_used"] = response.input_tokens + response.output_tokens + + # Run comparison + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + decimal_points=args.decimal_points, + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row + + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + row["tokens_used"] = 0 + return row + + def run_eval(self, args): + """Gemini-specific evaluation logic.""" + questions_file_list = args.questions_file + prompt_file_list = args.prompt_file + output_file_list = args.output_file + + for questions_file, prompt_file, output_file in zip( + questions_file_list, prompt_file_list, output_file_list + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + print( + f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" + ) + question_query_df = prepare_questions_df( + questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + ) + input_rows = question_query_df.to_dict("records") + output_rows = [] + + total_tried = 0 + total_correct = 0 + + with ThreadPoolExecutor(args.parallel_threads) as executor: + futures = [] + for row in input_rows: + futures.append(executor.submit( + self.process_row, + row=row, + model_name=args.model, + args=args + )) + + with tqdm(as_completed(futures), total=len(futures)) as pbar: + for f in pbar: + row = f.result() + output_rows.append(row) + if row.get("correct", 0): + total_correct += 1 + total_tried += 1 + pbar.set_description( + f"Accuracy: {round(total_correct/total_tried * 100, 2)}% ({total_correct}/{total_tried})" + ) + + # save results to csv + output_df = pd.DataFrame(output_rows) + output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + if "prompt" in output_df.columns: + del output_df["prompt"] + + # get num rows, mean correct, mean error_db_exec for each query_category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # get directory of output_file and create if not exist + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) 
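+            # write this questions file's results to CSV (float metrics rounded to two decimal places)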
output_df.to_csv(output_file, index=False, float_format="%.2f") - except: - output_df.to_pickle(output_file) - results = output_df.to_dict("records") + # get average rate of correct results + avg_subset = output_df["correct"].sum() / len(output_df) + print(f"Average correct rate: {avg_subset:.2f}") - if args.upload_url is not None: + results = output_df.to_dict("records") + + # upload results with open(prompt_file, "r") as f: prompt = f.read() + if args.upload_url is not None: upload_results( results=results, url=args.upload_url, - runner_type="api_runner", + runner_type="gemini", prompt=prompt, args=args, ) + +def run_gemini_eval(args): + runner = GeminiRunner() + runner.run_eval(args) \ No newline at end of file diff --git a/runners/hf_runner.py b/runners/hf_runner.py index 9046a65..1e70b6b 100644 --- a/runners/hf_runner.py +++ b/runners/hf_runner.py @@ -1,7 +1,5 @@ import os from typing import Optional - -from eval.eval import compare_query_results import pandas as pd import torch from transformers import ( @@ -9,252 +7,240 @@ AutoModelForCausalLM, pipeline, ) -from utils.gen_prompt import generate_prompt -from utils.questions import prepare_questions_df -from utils.creds import db_creds_all from tqdm import tqdm from psycopg2.extensions import QueryCanceledError -from time import time import gc + +from runners.base_runner import BaseRunner +from utils.gen_prompt import generate_prompt +from utils.questions import prepare_questions_df +from utils.creds import db_creds_all +from eval.eval import compare_query_results from utils.reporting import upload_results device_map = "mps" if torch.backends.mps.is_available() else "auto" -def get_tokenizer_model(model_name: Optional[str], adapter_path: Optional[str]): - """ - Load a HuggingFace tokenizer and model. - You may supply either a normal huggingface model name, or a peft adapter path. 
- """ - if adapter_path is not None: - from peft import PeftModel, PeftConfig - - print(f"Loading adapter model {adapter_path}") - config = PeftConfig.from_pretrained(adapter_path) - tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path) - model = AutoModelForCausalLM.from_pretrained( - config.base_model_name_or_path, - torch_dtype=torch.float16, - trust_remote_code=True, - use_cache=True, - device_map=device_map, - ) - print(f"Loading adapter {adapter_path}") - model = PeftModel.from_pretrained(model, adapter_path) - model = model.merge_and_unload() - print(f"Merged adapter {adapter_path}") - else: - print(f"Loading model {model_name}") - try: - tokenizer = AutoTokenizer.from_pretrained(model_name) - except: - tokenizer = AutoTokenizer.from_pretrained( - "meta-llama/Meta-Llama-3-8B-Instruct" +class HFRunner(BaseRunner): + def __init__(self): + super().__init__() + self.tokenizer = None + self.model = None + self.pipe = None + + def _initialize_model(self, model_name: Optional[str], adapter_path: Optional[str], batch_size: int): + """Load a HuggingFace tokenizer and model.""" + if adapter_path is not None: + from peft import PeftModel, PeftConfig + + print(f"Loading adapter model {adapter_path}") + config = PeftConfig.from_pretrained(adapter_path) + self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path) + self.model = AutoModelForCausalLM.from_pretrained( + config.base_model_name_or_path, + torch_dtype=torch.float16, + trust_remote_code=True, + use_cache=True, + device_map=device_map, ) + print(f"Loading adapter {adapter_path}") + self.model = PeftModel.from_pretrained(self.model, adapter_path) + self.model = self.model.merge_and_unload() + print(f"Merged adapter {adapter_path}") + else: + print(f"Loading model {model_name}") + try: + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + except: + self.tokenizer = AutoTokenizer.from_pretrained( + "meta-llama/Meta-Llama-3-8B-Instruct" + ) - tokenizer.pad_token_id = tokenizer.eos_token_id - model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.float16, - trust_remote_code=True, - device_map=device_map, + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + if model_name and "8b" in model_name.lower(): + # do this since it doesn't seem to have been done by default + self.tokenizer.padding_side = "left" + + if not self.model: + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float16, + trust_remote_code=True, + device_map=device_map, + ) + + self.model.tie_weights() + self.pipe = pipeline( + "text-generation", model=self.model, tokenizer=self.tokenizer, batch_size=batch_size ) - return tokenizer, model - - -def run_hf_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - num_questions = args.num_questions - public_data = not args.use_private_data - model_name = args.model - adapter_path = args.adapter - output_file_list = args.output_file - k_shot = args.k_shot - db_type = args.db_type - num_beams = args.num_beams - cot_table_alias = args.cot_table_alias - if model_name is None and adapter_path is None: - raise ValueError( - "You must supply either a model name or an adapter path to run an evaluation." 
+ def _extract_query(self, generated_text, prompt): + """Extract SQL query from response based on prompt format.""" + if "[SQL]" not in prompt: + return generated_text.split("```")[0].split(";")[0].strip() + ";" + else: + return generated_text.split("[/SQL]")[0].split(";")[0].strip() + ";" + + def _process_batch(self, batch, args): + """Process a batch of questions using HF pipeline.""" + prompts = batch["prompt"].tolist() + generated_queries = self.pipe( + prompts, + max_new_tokens=600, + do_sample=False, + num_beams=args.num_beams, + num_return_sequences=1, + return_full_text=False, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.eos_token_id, + temperature=None, + top_p=None, ) - - print(f"Questions prepared\nNow loading model...") - # initialize tokenizer and model - tokenizer, model = get_tokenizer_model(model_name, adapter_path) - - if "8b" in model_name.lower(): - # do this since it doesn't seem to have been done by default - tokenizer.padding_side = "left" - - tokenizer.pad_token_id = tokenizer.eos_token_id - model.tie_weights() - - print("model loaded\nnow generating and evaluating predictions...") - - # from here, we generate and evaluate predictions - # eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0] - pipe = pipeline( - "text-generation", model=model, tokenizer=tokenizer, batch_size=args.batch_size - ) - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - print(f"Using prompt file {prompt_file}") - # get questions - print("Preparing questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" - ) - df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, cot_table_alias - ) - # create a prompt for each question - df["prompt"] = df.apply( - lambda row: generate_prompt( - prompt_file, - row["question"], - row["db_name"], - row["db_type"], - row["instructions"], - row["k_shot_prompt"], - row["glossary"], - row["table_metadata_string"], - row["prev_invalid_sql"], - row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], - public_data, - args.num_columns, - args.shuffle_metadata, - ), - axis=1, - ) - - total_tried = 0 - total_correct = 0 - output_rows = [] - - def chunk_dataframe(df, chunk_size): - """Yield successive chunk_size chunks from df.""" - for i in range(0, len(df), chunk_size): - yield df[i : min(i + chunk_size, len(df))] - - df_chunks = list(chunk_dataframe(df, args.batch_size)) - - with tqdm(total=len(df)) as pbar: - for batch in df_chunks: - prompts = batch["prompt"].tolist() - generated_queries = pipe( - prompts, - max_new_tokens=600, - do_sample=False, - num_beams=num_beams, - num_return_sequences=1, - return_full_text=False, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.eos_token_id, - temperature=None, - top_p=None, - ) - gc.collect() + + # Clean up GPU memory + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + results = [] + for row, result in zip(batch.to_dict("records"), generated_queries): + generated_query = self._extract_query( + result[0]["generated_text"], row["prompt"] + ) + + # More GPU cleanup + gc.collect() + if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() - for row, result in zip(batch.to_dict("records"), generated_queries): - total_tried += 1 - # we set return_full_text to 
False so that we don't get the prompt text in the generated text - # this simplifies our postprocessing to deal with just the truncation of the end of the query - - if "[SQL]" not in row["prompt"]: - generated_query = ( - result[0]["generated_text"] - .split("```")[0] - .split(";")[0] - .strip() - + ";" - ) - else: - generated_query = ( - result[0]["generated_text"] - .split("[/SQL]")[0] - .split(";")[0] - .strip() - + ";" - ) - - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() + row["generated_query"] = generated_query + row["latency_seconds"] = None # HF pipeline doesn't provide per-item latency + + try: + exact_match, correct = compare_query_results( + query_gold=row["query"], + query_gen=generated_query, + db_name=row["db_name"], + db_type=row["db_type"], + db_creds=db_creds_all[row["db_type"]], + question=row["question"], + query_category=row["query_category"], + table_metadata_string=row["table_metadata_string"], + decimal_points=args.decimal_points, + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["error_msg"] = "" + except QueryCanceledError as e: + row["timeout"] = 1 + row["error_msg"] = f"QUERY EXECUTION TIMEOUT: {e}" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + results.append(row) + return results + + def run_eval(self, args): + """HF-specific evaluation logic with batching.""" + if args.model is None and args.adapter is None: + raise ValueError( + "You must supply either a model name or an adapter path to run an evaluation." + ) - row["generated_query"] = generated_query - row["latency_seconds"] = None - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - db_creds = db_creds_all[db_type] + self._initialize_model(args.model, args.adapter, args.batch_size) - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds, - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - if correct: - total_correct += 1 - except QueryCanceledError as e: - row["timeout"] = 1 - row["error_msg"] = f"QUERY EXECUTION TIMEOUT: {e}" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + for questions_file, prompt_file, output_file in zip( + args.questions_file, args.prompt_file, args.output_file + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + print( + f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" + ) + df = prepare_questions_df( + questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + ) + # Create prompts for all questions + df["prompt"] = df.apply( + lambda row: generate_prompt( + prompt_file, + row["question"], + row["db_name"], + row["db_type"], + row["instructions"], + row["k_shot_prompt"], + row["glossary"], + row["table_metadata_string"], + row["prev_invalid_sql"], + row["prev_error_msg"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + 
row.get("cot_instructions", ""), + row.get("cot_pregen", False), + not args.use_private_data, + args.num_columns, + args.shuffle_metadata, + ), + axis=1, + ) - output_rows.append(row) - pbar.update(1) + # Process in batches + def chunk_dataframe(df, chunk_size): + for i in range(0, len(df), chunk_size): + yield df[i : min(i + chunk_size, len(df))] + + df_chunks = list(chunk_dataframe(df, args.batch_size)) + all_results = [] + total_tried = total_correct = 0 + + with tqdm(total=len(df)) as pbar: + for batch in df_chunks: + batch_results = self._process_batch(batch, args) + all_results.extend(batch_results) + + # Update progress stats + batch_correct = sum(1 for r in batch_results if r.get("correct", 0)) + total_correct += batch_correct + total_tried += len(batch) + pbar.update(len(batch)) pbar.set_description( f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" ) - output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - output_df.to_csv(output_file, index=False, float_format="%.2f") + # Save results + results_df = pd.DataFrame(all_results) + if "prompt" in results_df.columns: + del results_df["prompt"] + + print(results_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) + results_df = results_df.sort_values(by=["db_name", "query_category", "question"]) + + # Save to file + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + results_df.to_csv(output_file, index=False, float_format="%.2f") + print(f"Saved results to {output_file}") + + # Upload results if needed + if args.upload_url is not None: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=results_df.to_dict("records"), + url=args.upload_url, + runner_type="hf_runner", + prompt=prompt, + args=args, + ) - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="hf_runner", - prompt=prompt, - args=args, - ) + +def run_hf_eval(args): + runner = HFRunner() + runner.run_eval(args) \ No newline at end of file diff --git a/runners/llama_cpp_runner.py b/runners/llama_cpp_runner.py index 0297ca0..f64d7c3 100644 --- a/runners/llama_cpp_runner.py +++ b/runners/llama_cpp_runner.py @@ -1,155 +1,150 @@ import os - -from eval.eval import compare_query_results import pandas as pd +from tqdm import tqdm +from time import time +from llama_cpp import Llama + +from runners.base_runner import BaseRunner from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time +from eval.eval import compare_query_results from utils.reporting import upload_results -from llama_cpp import Llama -def process_row(llm, row, args): - start_time = time() - prompt = row["prompt"] - generated_query = ( - llm( +class LlamaCppRunner(BaseRunner): + def __init__(self): + super().__init__() + self.llm = None + + def _initialize_model(self, model_path: str): + """Initialize the Llama CPP model.""" + self.llm = Llama(model_path=model_path, 
n_gpu_layers=-1, n_ctx=4096) + + def _call_llm(self, prompt, model_name, temperature=0.0): + """Call Llama CPP model.""" + # model_name is unused, as it's already loaded in _initialize_model + response = self.llm( prompt, max_tokens=512, temperature=0, top_p=1, echo=False, repeat_penalty=1.0, - )["choices"][0]["text"] - .split(";")[0] - .split("```")[0] - .strip() - + ";" - ) - end_time = time() - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - return row + # Create a response object similar to other runners + class LlamaResponse: + def __init__(self, content): + self.content = content + self.input_tokens = 0 # Llama CPP doesn't provide token counts + self.output_tokens = 0 + return LlamaResponse(response["choices"][0]["text"]) -def run_llama_cpp_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - num_questions = args.num_questions - public_data = not args.use_private_data - model_path = args.model - output_file_list = args.output_file - k_shot = args.k_shot - db_type = args.db_type - cot_table_alias = args.cot_table_alias - - llm = Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=4096) - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - print(f"Using prompt file {prompt_file}") - # get questions - print("Preparing questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" - ) - df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, cot_table_alias - ) - # create a prompt for each question - df["prompt"] = df.apply( - lambda row: generate_prompt( - prompt_file, - row["question"], - row["db_name"], - row["db_type"], - row["instructions"], - row["k_shot_prompt"], - row["glossary"], - row["table_metadata_string"], - row["prev_invalid_sql"], - row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], - public_data, - args.num_columns, - args.shuffle_metadata, - ), - axis=1, - ) + def _extract_query(self, response_content): + """Extract SQL query from response.""" + return response_content.split(";")[0].split("```")[0].strip() + ";" - total_tried = 0 - total_correct = 0 - output_rows = [] - - with tqdm(total=len(df)) as pbar: - for row in df.to_dict("records"): - row = process_row(llm, row, args) - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) + def run_eval(self, args): + """Llama CPP-specific evaluation logic.""" 
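+        # args.model should point to a local model file (e.g. a GGUF build); fail fast if it is missing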
+ if not args.model: + raise ValueError("Model path must be provided for Llama CPP runner") + + self._initialize_model(args.model) + print(f"Initialized Llama CPP model from {args.model}") - output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - try: - output_df.to_csv(output_file, index=False, float_format="%.2f") - except: - output_df.to_pickle(output_file) - - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="llama_cpp_runner", - prompt=prompt, - args=args, + for questions_file, prompt_file, output_file in zip( + args.questions_file, args.prompt_file, args.output_file + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + print( + f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" ) + df = prepare_questions_df( + questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + ) + + # Create prompts for all questions + df["prompt"] = df.apply( + lambda row: generate_prompt( + prompt_file, + row["question"], + row["db_name"], + row["db_type"], + row["instructions"], + row["k_shot_prompt"], + row["glossary"], + row["table_metadata_string"], + row["prev_invalid_sql"], + row["prev_error_msg"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), + not args.use_private_data, + args.num_columns, + args.shuffle_metadata, + ), + axis=1, + ) + + # Process rows sequentially due to Llama CPP's nature + output_rows = [] + total_tried = total_correct = 0 + + with tqdm(total=len(df)) as pbar: + for row in df.to_dict("records"): + try: + start_time = time() + result = self.process_row(row, None, args) # None as model_name is unused + if result.get("correct", 0): + total_correct += 1 + total_tried += 1 + output_rows.append(result) + except Exception as e: + row["error_msg"] = f"PROCESSING ERROR: {str(e)}" + row["error_db_exec"] = 1 + output_rows.append(row) + finally: + pbar.update(1) + pbar.set_description( + f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" + ) + + # Save results + output_df = pd.DataFrame(output_rows) + if "prompt" in output_df.columns: + del output_df["prompt"] + + print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) + output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + + # Save to file + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + try: + output_df.to_csv(output_file, index=False, float_format="%.2f") + except: + output_df.to_pickle(output_file) + print(f"Saved results to {output_file}") + + # Upload results if needed + if args.upload_url is not None: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="llama_cpp_runner", + prompt=prompt, + args=args, + ) + + +def run_llama_cpp_eval(args): + runner = 
LlamaCppRunner() + runner.run_eval(args) \ No newline at end of file diff --git a/runners/mistral_runner.py b/runners/mistral_runner.py index 4abdf81..01f6b64 100644 --- a/runners/mistral_runner.py +++ b/runners/mistral_runner.py @@ -1,248 +1,233 @@ import os +import pandas as pd +from tqdm import tqdm from time import time from concurrent.futures import ThreadPoolExecutor, as_completed from mistralai.client import MistralClient from mistralai.models.chat_completion import ChatMessage -import pandas as pd -from tqdm import tqdm -from eval.eval import compare_query_results -from utils.creds import db_creds_all -from utils.gen_prompt import to_prompt_schema -from utils.dialects import convert_postgres_ddl_to_dialect +from runners.base_runner import BaseRunner +from utils.gen_prompt import generate_prompt as base_generate_prompt from utils.questions import prepare_questions_df +from utils.creds import db_creds_all +from eval.eval import compare_query_results from utils.reporting import upload_results -api_key = os.environ.get("MISTRAL_API_KEY") -client = MistralClient(api_key=api_key) - - -def generate_prompt( - prompt_file, - question, - db_name, - db_type, - instructions="", - k_shot_prompt="", - glossary="", - table_metadata_string="", - prev_invalid_sql="", - prev_error_msg="", - public_data=True, - shuffle=True, -): - with open(prompt_file, "r") as f: - prompt = f.read() - - # Check that System and User prompts are in the prompt file - if "System:" not in prompt or "User:" not in prompt: - raise ValueError("Invalid prompt file. Please use prompt_mistral.md") - sys_prompt = prompt.split("System:")[1].split("User:")[0].strip() - user_prompt = prompt.split("User:")[1].strip() - - if table_metadata_string == "": - if public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - from defog_data_private.metadata import dbs - import defog_data_private.supplementary as sup - - md = dbs[db_name]["table_metadata"] - metadata_ddl = to_prompt_schema(md, shuffle) - metadata_ddl = convert_postgres_ddl_to_dialect( - postgres_ddl=metadata_ddl, - to_dialect=db_type, - db_name=db_name, - ) - column_join = sup.columns_join.get(db_name, {}) - # get join_str from column_join - join_list = [] - for values in column_join.values(): - if isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - else: - col_1, col_2 = values[0] - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) - else: - join_str = "" - pruned_metadata_str = metadata_ddl + join_str - else: - pruned_metadata_str = table_metadata_string - - user_prompt = user_prompt.format( - user_question=question, - instructions=instructions, - table_metadata_string=pruned_metadata_str, - k_shot_prompt=k_shot_prompt, - glossary=glossary, - prev_invalid_sql=prev_invalid_sql, - prev_error_msg=prev_error_msg, - ) - messages = [ - ChatMessage( - role="system", - content=sys_prompt, - ), - ChatMessage( - role="user", - content=user_prompt, - ), - ] - return messages - - -def process_row(row, model, args): - start_time = time() - chat_response = client.chat( - model=model, - messages=row["prompt"], - temperature=0, - max_tokens=600, - ) - end_time = time() - generated_query = 
chat_response.choices[0].message.content - - try: - # replace all backslashes with empty string - generated_query = generated_query.replace("\\", "") - - generated_query = generated_query.split(";")[0].split("```sql")[-1].strip() - generated_query = [i for i in generated_query.split("```") if i.strip() != ""][ - 0 - ] + ";" - except Exception as e: - print(e) - generated_query = chat_response.choices[0].message.content - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - return row +class MistralRunner(BaseRunner): + def __init__(self): + super().__init__() + self.api_key = os.environ.get("MISTRAL_API_KEY") + self.client = MistralClient(api_key=self.api_key) + def generate_prompt(self, prompt_file, **kwargs): + """Mistral-specific prompt generation.""" + with open(prompt_file, "r") as f: + prompt = f.read() -def run_mistral_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - num_questions = args.num_questions - public_data = not args.use_private_data - model = args.model - output_file_list = args.output_file - k_shot = args.k_shot - max_workers = args.parallel_threads - db_type = args.db_type - cot_table_alias = args.cot_table_alias - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - print(f"Using prompt file {prompt_file}") - # get questions - print("Preparing questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" + # Check that System and User prompts are in the prompt file + if "System:" not in prompt or "User:" not in prompt: + raise ValueError("Invalid prompt file. 
Please use prompt_mistral.md") + + sys_prompt = prompt.split("System:")[1].split("User:")[0].strip() + user_prompt = prompt.split("User:")[1].strip() + + # Get table metadata using base generate_prompt + table_metadata_str = base_generate_prompt( + prompt_file=prompt_file, + question=kwargs.get("question", ""), + db_name=kwargs.get("db_name", ""), + db_type=kwargs.get("db_type", ""), + instructions=kwargs.get("instructions", ""), + k_shot_prompt=kwargs.get("k_shot_prompt", ""), + glossary=kwargs.get("glossary", ""), + table_metadata_string=kwargs.get("table_metadata_string", ""), + prev_invalid_sql=kwargs.get("prev_invalid_sql", ""), + prev_error_msg=kwargs.get("prev_error_msg", ""), + public_data=kwargs.get("public_data", True), + shuffle_metadata=kwargs.get("shuffle_metadata", True), ) - df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, cot_table_alias + + # Format user prompt + user_prompt = user_prompt.format( + user_question=kwargs.get("question", ""), + instructions=kwargs.get("instructions", ""), + table_metadata_string=table_metadata_str, + k_shot_prompt=kwargs.get("k_shot_prompt", ""), + glossary=kwargs.get("glossary", ""), + prev_invalid_sql=kwargs.get("prev_invalid_sql", ""), + prev_error_msg=kwargs.get("prev_error_msg", ""), ) - # create a prompt for each question - df["prompt"] = df.apply( - lambda row: generate_prompt( - prompt_file, - row["question"], - row["db_name"], - row["db_type"], - row["instructions"], - row["k_shot_prompt"], - row["glossary"], - row["table_metadata_string"], - row["prev_invalid_sql"], - row["prev_error_msg"], - public_data, - args.shuffle_metadata, - ), - axis=1, + + return [ + ChatMessage(role="system", content=sys_prompt), + ChatMessage(role="user", content=user_prompt), + ] + + def _call_llm(self, messages, model_name, temperature=0.0): + """Call Mistral API.""" + chat_response = self.client.chat( + model=model_name, + messages=messages, + temperature=temperature, + max_tokens=600, ) + # Create a response object similar to other runners + class MistralResponse: + def __init__(self, content): + self.content = content + self.input_tokens = 0 # Mistral doesn't provide token counts in this way + self.output_tokens = 0 - total_tried = 0 - total_correct = 0 - output_rows = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append(executor.submit(process_row, row, model, args)) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row.get("correct", 0): - total_correct += 1 - total_tried += 1 - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) + return MistralResponse(chat_response.choices[0].message.content) - output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) + def _extract_query(self, response_content): + """Extract SQL query from response with Mistral-specific handling.""" try: - output_df.to_csv(output_file, index=False, float_format="%.2f") - except: - output_df.to_pickle(output_file) + # Replace backslashes + content = response_content.replace("\\", "") + # First try 
to extract from SQL code blocks + query = content.split(";")[0].split("```sql")[-1].strip() + query = [i for i in query.split("```") if i.strip() != ""][0] + ";" + return query + except Exception as e: + # Fallback to raw content + print(f"Query extraction error: {e}") + return response_content.split(";")[0].strip() + ";" + + def process_row(self, row, model_name, args): + """Override process_row to use Mistral prompt generation.""" + start_time = time() + try: + messages = self.generate_prompt( + prompt_file=args.prompt_file[0], + question=row["question"], + db_name=row["db_name"], + db_type=args.db_type, + instructions=row["instructions"], + k_shot_prompt=row["k_shot_prompt"], + glossary=row["glossary"], + table_metadata_string=row["table_metadata_string"], + prev_invalid_sql=row["prev_invalid_sql"], + prev_error_msg=row["prev_error_msg"], + public_data=not args.use_private_data, + shuffle_metadata=args.shuffle_metadata, + ) - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="mistral_runner", - prompt=prompt, - args=args, + response = self._call_llm(messages, model_name) + generated_query = self._extract_query(response.content) + + result = { + "generated_query": generated_query, + "latency_seconds": time() - start_time, + "tokens_used": response.input_tokens + response.output_tokens + } + + # Run comparison + try: + exact_match, correct = compare_query_results( + query_gold=row["query"], + query_gen=generated_query, + db_name=row["db_name"], + db_type=row["db_type"], + db_creds=db_creds_all[row["db_type"]], + question=row["question"], + query_category=row["query_category"], + table_metadata_string=row["table_metadata_string"], + decimal_points=args.decimal_points, + ) + result["exact_match"] = int(exact_match) + result["correct"] = int(correct) + result["error_msg"] = "" + except Exception as e: + result["error_db_exec"] = 1 + result["error_msg"] = f"QUERY EXECUTION ERROR: {str(e)}" + + return {**row, **result} + + except Exception as e: + return { + **row, + "generated_query": "", + "exact_match": 0, + "correct": 0, + "error_db_exec": 1, + "error_msg": f"PROCESSING ERROR: {str(e)}", + "latency_seconds": time() - start_time, + "tokens_used": 0 + } + + def run_eval(self, args): + """Mistral-specific evaluation logic.""" + if not self.api_key: + raise ValueError("MISTRAL_API_KEY environment variable must be set") + + for questions_file, prompt_file, output_file in zip( + args.questions_file, args.prompt_file, args.output_file + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + print( + f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" + ) + df = prepare_questions_df( + questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias ) + + # Process all rows + output_rows = [] + total_tried = total_correct = 0 + + with ThreadPoolExecutor(args.parallel_threads) as executor: + futures = [] + for row in df.to_dict("records"): + futures.append( + executor.submit(self.process_row, row, args.model, args) + ) + + with tqdm(as_completed(futures), total=len(futures)) as pbar: + for f in pbar: + row = f.result() + output_rows.append(row) + if row.get("correct", 0): + total_correct += 1 + total_tried += 1 + pbar.set_description( + f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" + ) + + 
# Save results + output_df = pd.DataFrame(output_rows) + if "prompt" in output_df.columns: + del output_df["prompt"] + + print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) + output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + + # Save to file + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + try: + output_df.to_csv(output_file, index=False, float_format="%.2f") + except: + output_df.to_pickle(output_file) + print(f"Saved results to {output_file}") + + # Upload results if needed + if args.upload_url is not None: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="mistral_runner", + prompt=prompt, + args=args, + ) + + +def run_mistral_eval(args): + runner = MistralRunner() + runner.run_eval(args) \ No newline at end of file diff --git a/runners/mlx_runner.py b/runners/mlx_runner.py index e773008..4fda1ce 100644 --- a/runners/mlx_runner.py +++ b/runners/mlx_runner.py @@ -1,149 +1,150 @@ import os - -from eval.eval import compare_query_results import pandas as pd +from tqdm import tqdm +from time import time + +from runners.base_runner import BaseRunner from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time +from eval.eval import compare_query_results from utils.reporting import upload_results from mlx_lm import load, generate -def process_row(model, tokenizer, row, args): - start_time = time() - prompt = row["prompt"] - - generated_query = ( - generate(model, tokenizer, prompt=prompt, max_tokens=512, temp=0, verbose=True) - .split(";")[0] - .split("```")[0] - .strip() - + ";" - ) - end_time = time() - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - return row +class MLXRunner(BaseRunner): + def __init__(self): + super().__init__() + self.model = None + self.tokenizer = None + def _initialize_model(self, model_path: str): + """Initialize the MLX model.""" + self.model, self.tokenizer = load(model_path) -def run_mlx_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - num_questions = args.num_questions - public_data = not args.use_private_data - model_path = args.model - output_file_list = args.output_file - k_shot = args.k_shot - db_type = args.db_type - cot_table_alias = args.cot_table_alias - - model, tokenizer = load(model_path) - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - print(f"Using prompt file {prompt_file}") - 
# get questions - print("Preparing questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" - ) - df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, cot_table_alias - ) - # create a prompt for each question - df["prompt"] = df.apply( - lambda row: generate_prompt( - prompt_file, - row["question"], - row["db_name"], - row["db_type"], - row["instructions"], - row["k_shot_prompt"], - row["glossary"], - row["table_metadata_string"], - row["prev_invalid_sql"], - row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], - public_data, - args.num_columns, - args.shuffle_metadata, - ), - axis=1, + def _call_llm(self, prompt, model_name, temperature=0.0): + """Call MLX model.""" + # model_name is unused, as it's already loaded in _initialize_model + text = generate( + self.model, + self.tokenizer, + prompt=prompt, + max_tokens=512, + temp=0, + verbose=True ) + # Create a response object similar to other runners + class MLXResponse: + def __init__(self, content): + self.content = content + self.input_tokens = 0 # MLX doesn't provide token counts + self.output_tokens = 0 - total_tried = 0 - total_correct = 0 - output_rows = [] - - with tqdm(total=len(df)) as pbar: - for row in df.to_dict("records"): - row = process_row(model, tokenizer, row, args) - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) + return MLXResponse(text) - output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - try: - output_df.to_csv(output_file, index=False, float_format="%.2f") - except: - output_df.to_pickle(output_file) - - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="mlx_runner", - prompt=prompt, - args=args, + def _extract_query(self, response_content): + """Extract SQL query from response.""" + return response_content.split(";")[0].split("```")[0].strip() + ";" + + def run_eval(self, args): + """MLX-specific evaluation logic.""" + if not args.model: + raise ValueError("Model path must be provided for MLX runner") + + self._initialize_model(args.model) + print(f"Initialized MLX model from {args.model}") + + for questions_file, prompt_file, output_file in zip( + args.questions_file, args.prompt_file, args.output_file + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + print( + f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" + ) + df = prepare_questions_df( + questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias ) + + # Create prompts for all questions + df["prompt"] = df.apply( + lambda row: generate_prompt( + prompt_file, + row["question"], + row["db_name"], + row["db_type"], + row["instructions"], + row["k_shot_prompt"], + row["glossary"], + 
row["table_metadata_string"], + row["prev_invalid_sql"], + row["prev_error_msg"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), + not args.use_private_data, + args.num_columns, + args.shuffle_metadata, + ), + axis=1, + ) + + # Process rows sequentially due to MLX's nature + output_rows = [] + total_tried = total_correct = 0 + + with tqdm(total=len(df)) as pbar: + for row in df.to_dict("records"): + try: + result = self.process_row(row, None, args) # None as model_name is unused + if result.get("correct", 0): + total_correct += 1 + total_tried += 1 + output_rows.append(result) + except Exception as e: + row["error_msg"] = f"PROCESSING ERROR: {str(e)}" + row["error_db_exec"] = 1 + output_rows.append(row) + finally: + pbar.update(1) + pbar.set_description( + f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" + ) + + # Save results + output_df = pd.DataFrame(output_rows) + if "prompt" in output_df.columns: + del output_df["prompt"] + + print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) + output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + + # Save to file + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + try: + output_df.to_csv(output_file, index=False, float_format="%.2f") + except: + output_df.to_pickle(output_file) + print(f"Saved results to {output_file}") + + # Upload results if needed + if args.upload_url is not None: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="mlx_runner", + prompt=prompt, + args=args, + ) + + +def run_mlx_eval(args): + runner = MLXRunner() + runner.run_eval(args) \ No newline at end of file diff --git a/runners/openai_runner.py b/runners/openai_runner.py index 5d207ef..c31213c 100644 --- a/runners/openai_runner.py +++ b/runners/openai_runner.py @@ -1,260 +1,148 @@ -import os -from time import time -from concurrent.futures import ThreadPoolExecutor, as_completed import json +import os -import pandas as pd -import sqlparse -from tqdm import tqdm - -from eval.eval import compare_query_results -from utils.creds import db_creds_all -from utils.dialects import convert_postgres_ddl_to_dialect -from utils.gen_prompt import to_prompt_schema -from utils.questions import prepare_questions_df -from utils.reporting import upload_results from utils.llm import chat_openai +from runners.base_runner import BaseRunner, format_sql_query -def generate_prompt( - prompt_file, - question, - db_name, - db_type, - instructions="", - k_shot_prompt="", - glossary="", - table_metadata_string="", - prev_invalid_sql="", - prev_error_msg="", - public_data=True, - shuffle=True, -): - if public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - from defog_data_private.metadata import dbs - import defog_data_private.supplementary as sup - - with open(prompt_file, "r") as f: - prompt = json.load(f) - - if table_metadata_string == "": - md = dbs[db_name]["table_metadata"] - pruned_metadata_ddl = to_prompt_schema(md, shuffle) - pruned_metadata_ddl = convert_postgres_ddl_to_dialect( - postgres_ddl=pruned_metadata_ddl, - to_dialect=db_type, - db_name=db_name, - ) - column_join = sup.columns_join.get(db_name, {}) - join_list = [] - for values in column_join.values(): - if 
isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - else: - col_1, col_2 = values[0] - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) - else: - join_str = "" - pruned_metadata_str = pruned_metadata_ddl + join_str - else: - pruned_metadata_str = table_metadata_string - - if prompt[0]["role"] == "system": - prompt[0]["content"] = prompt[0]["content"].format( - db_type=db_type, - ) - prompt[1]["content"] = prompt[1]["content"].format( - user_question=question, - instructions=instructions, - table_metadata_string=pruned_metadata_str, - k_shot_prompt=k_shot_prompt, - ) - else: - prompt[0]["content"] = prompt[1]["content"].format( - db_type=db_type, - user_question=question, - instructions=instructions, - table_metadata_string=pruned_metadata_str, - k_shot_prompt=k_shot_prompt, - ) - return prompt +class OpenAIRunner(BaseRunner): + def _call_llm(self, prompt, model_name, temperature=0.0): + """Call OpenAI API, handling both JSON and text prompts.""" + if isinstance(prompt, list): # JSON prompt format + messages = prompt + else: # Text prompt format + messages = [{"role": "user", "content": prompt}] + return chat_openai(messages=messages, model=model_name, temperature=temperature) - -def process_row(row, model_name, args): - start_time = time() - messages = generate_prompt( - prompt_file=args.prompt_file[0], - question=row["question"], - db_name=row["db_name"], - db_type=args.db_type, - instructions=row["instructions"], - k_shot_prompt=row["k_shot_prompt"], - glossary=row["glossary"], - table_metadata_string=row["table_metadata_string"], - prev_invalid_sql=row["prev_invalid_sql"], - prev_error_msg=row["prev_error_msg"], - public_data=not args.use_private_data, - shuffle=args.shuffle_metadata, - ) - try: - response = chat_openai(messages=messages, model=model_name, temperature=0.0) - generated_query = ( - response.content.split("```sql", 1)[-1].split("```", 1)[0].strip() - ) + def _extract_query(self, response_content): + """Extract SQL query from response.""" try: - generated_query = sqlparse.format( - generated_query, reindent=True, keyword_case="upper" - ) + return response_content.split("```sql", 1)[-1].split("```", 1)[0].strip() except: - pass - return { - "query": generated_query, - "reason": "", - "err": "", - "latency_seconds": time() - start_time, - "tokens_used": response.input_tokens + response.output_tokens, - } - except Exception as e: - return { - "query": "", - "reason": "", - "err": f"GENERATION ERROR: {str(e)}", - "latency_seconds": time() - start_time, - "tokens_used": 0, - } - - -def run_openai_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - output_file_list = args.output_file - num_questions = args.num_questions - k_shot = args.k_shot - db_type = args.db_type - cot_table_alias = args.cot_table_alias - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - print(f"Using prompt file {prompt_file}") - print("Preparing questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" - ) - question_query_df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, 
cot_table_alias - ) - input_rows = question_query_df.to_dict("records") - output_rows = [] - with ThreadPoolExecutor(args.parallel_threads) as executor: - futures = [] - for row in input_rows: - generated_query_fut = executor.submit( - process_row, - row=row, - model_name=args.model, - args=args, - ) - futures.append(generated_query_fut) - - total_tried = 0 - total_correct = 0 - for f in (pbar := tqdm(as_completed(futures), total=len(futures))): - total_tried += 1 - i = futures.index(f) - row = input_rows[i] - result_dict = f.result() - query_gen = result_dict["query"] - reason = result_dict["reason"] - err = result_dict["err"] - # save custom metrics - if "latency_seconds" in result_dict: - row["latency_seconds"] = result_dict["latency_seconds"] - if "tokens_used" in result_dict: - row["tokens_used"] = result_dict["tokens_used"] - row["generated_query"] = query_gen - row["reason"] = reason - row["error_msg"] = err - # save failures into relevant columns in the dataframe - if "GENERATION ERROR" in err: - row["error_query_gen"] = 1 - else: - expected_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - try: - is_correct = compare_query_results( - query_gold=expected_query, - query_gen=query_gen, - db_name=db_name, - db_type=db_type, - question=row["question"], - query_category=row["query_category"], - db_creds=db_creds_all[db_type], - ) - if is_correct: - total_correct += 1 - row["is_correct"] = 1 - row["error_msg"] = "" - else: - row["is_correct"] = 0 - row["error_msg"] = "INCORRECT RESULTS" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"EXECUTION ERROR: {str(e)}" - output_rows.append(row) - pbar.set_description( - f"Accuracy: {round(total_correct/total_tried * 100, 2)}% ({total_correct}/{total_tried})" + # Fallback to extract anything that looks like SQL + return response_content.split(";")[0].strip() + ";" + + def run_eval(self, args): + """OpenAI-specific evaluation logic.""" + questions_file_list = args.questions_file + prompt_file_list = args.prompt_file + output_file_list = args.output_file + + for questions_file, prompt_file, output_file in zip( + questions_file_list, prompt_file_list, output_file_list + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + print( + f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" + ) + question_query_df = self.prepare_questions_df( + questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + ) + input_rows = question_query_df.to_dict("records") + output_rows = [] + + with ThreadPoolExecutor(args.parallel_threads) as executor: + futures = [] + for row in input_rows: + generated_query_fut = executor.submit( + self.process_row, + row=row, + model_name=args.model, + args=args, + ) + futures.append(generated_query_fut) + + total_tried = 0 + total_correct = 0 + for f in (pbar := tqdm(as_completed(futures), total=len(futures))): + total_tried += 1 + i = futures.index(f) + row = input_rows[i] + result_dict = f.result() + query_gen = result_dict["query"] + reason = result_dict["reason"] + err = result_dict["err"] + # save custom metrics + if "latency_seconds" in result_dict: + row["latency_seconds"] = result_dict["latency_seconds"] + if "tokens_used" in result_dict: + row["tokens_used"] = result_dict["tokens_used"] + row["generated_query"] = query_gen + row["reason"] = reason + row["error_msg"] = err + # save failures into relevant columns in the dataframe + if "GENERATION ERROR" in err: + 
row["error_query_gen"] = 1 + else: + expected_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + try: + is_correct = self.compare_query_results( + query_gold=expected_query, + query_gen=query_gen, + db_name=db_name, + db_type=db_type, + question=row["question"], + query_category=row["query_category"], + db_creds=self.db_creds_all[db_type], + ) + if is_correct: + total_correct += 1 + row["is_correct"] = 1 + row["error_msg"] = "" + else: + row["is_correct"] = 0 + row["error_msg"] = "INCORRECT RESULTS" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"EXECUTION ERROR: {str(e)}" + output_rows.append(row) + pbar.set_description( + f"Accuracy: {round(total_correct/total_tried * 100, 2)}% ({total_correct}/{total_tried})" + ) + + # save results to csv + output_df = pd.DataFrame(output_rows) + output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + if "prompt" in output_df.columns: + del output_df["prompt"] + # get num rows, mean correct, mean error_db_exec for each query_category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("is_correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), ) - - # save results to csv - output_df = pd.DataFrame(output_rows) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - if "prompt" in output_df.columns: - del output_df["prompt"] - # get num rows, mean correct, mean error_db_exec for each query_category - agg_stats = ( - output_df.groupby("query_category") - .agg( - num_rows=("db_name", "count"), - mean_correct=("is_correct", "mean"), - mean_error_db_exec=("error_db_exec", "mean"), + .reset_index() ) - .reset_index() - ) - print(agg_stats) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - output_df.to_csv(output_file, index=False, float_format="%.2f") - - # get average rate of correct results - avg_subset = output_df["correct"].sum() / len(output_df) - print(f"Average correct rate: {avg_subset:.2f}") - - results = output_df.to_dict("records") + print(agg_stats) + # get directory of output_file and create if not exist + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + output_df.to_csv(output_file, index=False, float_format="%.2f") + + # get average rate of correct results + avg_subset = output_df["correct"].sum() / len(output_df) + print(f"Average correct rate: {avg_subset:.2f}") + + results = output_df.to_dict("records") + + # upload results + with open(prompt_file, "r") as f: + prompt = f.read() + if args.upload_url is not None: + upload_results( + results=results, + url=args.upload_url, + runner_type="openai", + prompt=prompt, + args=args, + ) - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="openai", - prompt=prompt, - args=args, - ) +def run_openai_eval(args): + runner = OpenAIRunner() + runner.run_eval(args) \ No newline at end of file diff --git a/runners/together_runner.py b/runners/together_runner.py index 0414e57..9b726f9 100644 --- a/runners/together_runner.py +++ b/runners/together_runner.py @@ -1,174 +1,197 @@ import os +import pandas as pd +from tqdm import tqdm +from time import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict 
-from eval.eval import compare_query_results -import pandas as pd +from runners.base_runner import BaseRunner from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time -from together import Together +from eval.eval import compare_query_results from utils.reporting import upload_results +from together import Together -client = Together(api_key=os.environ.get("TOGETHER_API_KEY")) - - -def process_row(row: Dict, model: str): - start_time = time() - if model.startswith("meta-llama"): - stop = ["<|eot_id|>", "<|eom_id|>"] - else: - print( - "Undefined stop token(s). Please specify the stop token(s) for the model." - ) - stop = [] - messages = row["prompt"] - response = client.chat.completions.create( - model=model, - messages=messages, - max_tokens=800, - temperature=0.0, - stop=stop, - stream=False, - ) - content = response.choices[0].message.content - generated_query = content.split("```", 1)[0].strip() - end_time = time() - - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - row["tokens_used"] = None - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, +class TogetherRunner(BaseRunner): + def __init__(self): + super().__init__() + self.client = Together(api_key=os.environ.get("TOGETHER_API_KEY")) + + def _get_stop_tokens(self, model_name: str): + """Get stop tokens based on model type.""" + if model_name.startswith("meta-llama"): + return ["<|eot_id|>", "<|eom_id|>"] + else: + print("Undefined stop token(s). 
Please specify the stop token(s) for the model.") + return [] + + def _call_llm(self, messages, model_name, temperature=0.0): + """Call Together API.""" + stop_tokens = self._get_stop_tokens(model_name) + response = self.client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=800, + temperature=temperature, + stop=stop_tokens, + stream=False, ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + # Create a response object similar to other runners + class TogetherResponse: + def __init__(self, content): + self.content = content + self.input_tokens = 0 # Together doesn't provide token counts this way + self.output_tokens = 0 + + return TogetherResponse(response.choices[0].message.content) + + def _extract_query(self, response_content): + """Extract SQL query from response.""" + return response_content.split("```", 1)[0].strip() + + def process_row(self, row: Dict, model_name, args): + """Process a row using Together AI.""" + start_time = time() + try: + # Together uses OpenAI chat format + messages = generate_prompt( + prompt_file=args.prompt_file[0], + question=row["question"], + db_name=row["db_name"], + db_type=args.db_type, + instructions=row["instructions"], + k_shot_prompt=row["k_shot_prompt"], + glossary=row["glossary"], + table_metadata_string=row["table_metadata_string"], + prev_invalid_sql=row["prev_invalid_sql"], + prev_error_msg=row["prev_error_msg"], + question_0=row.get("question_0", ""), + query_0=row.get("query_0", ""), + question_1=row.get("question_1", ""), + query_1=row.get("query_1", ""), + cot_instructions=row.get("cot_instructions", ""), + cot_pregen=row.get("cot_pregen", False), + public_data=not args.use_private_data, + columns_to_keep=args.num_columns, + shuffle_metadata=args.shuffle_metadata, + table_aliases=row.get("table_aliases", ""), + ) - return row + response = self._call_llm(messages, model_name) + generated_query = self._extract_query(response.content) + + result = { + "generated_query": generated_query, + "latency_seconds": time() - start_time, + "tokens_used": None # Together doesn't provide token counts + } + + # Run comparison + try: + exact_match, correct = compare_query_results( + query_gold=row["query"], + query_gen=generated_query, + db_name=row["db_name"], + db_type=row["db_type"], + db_creds=db_creds_all[row["db_type"]], + question=row["question"], + query_category=row["query_category"], + table_metadata_string=row["table_metadata_string"], + ) + result["exact_match"] = int(exact_match) + result["correct"] = int(correct) + result["error_msg"] = "" + except Exception as e: + result["error_db_exec"] = 1 + result["error_msg"] = f"QUERY EXECUTION ERROR: {str(e)}" + + return {**row, **result} + + except Exception as e: + return { + **row, + "generated_query": "", + "exact_match": 0, + "correct": 0, + "error_db_exec": 1, + "error_msg": f"PROCESSING ERROR: {str(e)}", + "latency_seconds": time() - start_time, + "tokens_used": None + } + + def run_eval(self, args): + """Together-specific evaluation logic.""" + if not args.prompt_file[0].endswith(".json"): + raise ValueError(f"Prompt file must be a JSON file. 
Got {args.prompt_file[0]}") + + for questions_file, prompt_file, output_file in zip( + args.questions_file, args.prompt_file, args.output_file + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + print( + f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" + ) + df = prepare_questions_df( + questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + ) + # Process all rows + output_rows = [] + total_tried = total_correct = 0 + + with ThreadPoolExecutor(args.parallel_threads) as executor: + futures = [] + for row in df.to_dict("records"): + futures.append( + executor.submit(self.process_row, row, args.model, args) + ) -def run_together_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - num_questions = args.num_questions - public_data = not args.use_private_data - output_file_list = args.output_file - k_shot = args.k_shot - max_workers = args.parallel_threads - db_type = args.db_type - decimal_points = args.decimal_points - model = args.model - cot_table_alias = args.cot_table_alias - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - if not prompt_file.endswith(".json"): - raise ValueError(f"Prompt file must be a JSON file. Got {prompt_file}") - print(f"Using prompt file {prompt_file}") - # get questions - print("Preparing questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" - ) - df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, cot_table_alias - ) - # create a prompt for each question - # note that the prompt for together ai uses the openai chat API - df["prompt"] = df.apply( - lambda row: generate_prompt( - prompt_file, - row["question"], - row["db_name"], - row["db_type"], - row["instructions"], - row["k_shot_prompt"], - row["glossary"], - row["table_metadata_string"], - row["prev_invalid_sql"], - row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], - public_data, - args.num_columns, - args.shuffle_metadata, - row["table_aliases"], - ), - axis=1, - ) + with tqdm(as_completed(futures), total=len(futures)) as pbar: + for f in pbar: + row = f.result() + output_rows.append(row) + if row.get("correct", 0): + total_correct += 1 + total_tried += 1 + pbar.set_description( + f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" + ) + + # Save results + output_df = pd.DataFrame(output_rows) + if "prompt" in output_df.columns: + del output_df["prompt"] + + print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) + output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + + # Save to file + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + try: + output_df.to_csv(output_file, index=False, float_format="%.2f") + except: + output_df.to_pickle(output_file) + print(f"Saved results to {output_file}") + + # Upload results if needed + if args.upload_url is not None: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="together_runner", + prompt=prompt, + args=args, + ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - with 
ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append(executor.submit(process_row, row, model)) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) - output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - try: - output_df.to_csv(output_file, index=False, float_format="%.2f") - except: - output_df.to_pickle(output_file) - - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="api_runner", - prompt=prompt, - args=args, - ) +def run_together_eval(args): + runner = TogetherRunner() + runner.run_eval(args) \ No newline at end of file diff --git a/runners/vllm_runner.py b/runners/vllm_runner.py index 59ed962..95527d2 100644 --- a/runners/vllm_runner.py +++ b/runners/vllm_runner.py @@ -1,219 +1,235 @@ -import json import os from typing import List import sqlparse from vllm import LLM, SamplingParams from vllm.lora.request import LoRARequest -from eval.eval import compare_query_results import pandas as pd -from utils.gen_prompt import generate_prompt -from utils.questions import prepare_questions_df -from utils.creds import db_creds_all import time import torch from transformers import AutoTokenizer from tqdm import tqdm + +from runners.base_runner import BaseRunner +from utils.gen_prompt import generate_prompt +from utils.questions import prepare_questions_df +from utils.creds import db_creds_all +from eval.eval import compare_query_results from utils.reporting import upload_results -def run_vllm_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - num_questions = args.num_questions - public_data = not args.use_private_data - model_name = args.model - output_file_list = args.output_file - num_beams = args.num_beams - k_shot = args.k_shot - db_type = args.db_type - cot_table_alias = args.cot_table_alias - enable_lora = True if args.adapter else False - lora_request = LoRARequest("sql_adapter", 1, args.adapter) if args.adapter else None - - # initialize model only once as it takes a while - print(f"Preparing {model_name}") - tokenizer = AutoTokenizer.from_pretrained(model_name) - tokenizer.pad_token_id = tokenizer.eos_token_id - if not args.quantized: - llm = LLM( - model=model_name, - tensor_parallel_size=1, - enable_lora=enable_lora, - max_model_len=4096, - max_lora_rank=64, - ) - else: - llm = LLM( - model=model_name, - tensor_parallel_size=1, - quantization="AWQ", - enable_lora=enable_lora, - max_model_len=4096, - max_lora_rank=64, - ) +class VLLMRunner(BaseRunner): + def __init__(self): + super().__init__() + self.llm = None + self.tokenizer = None + self.sampling_params = None - sampling_params = SamplingParams( - n=1, - best_of=num_beams, - use_beam_search=num_beams != 1, - 
stop_token_ids=[tokenizer.eos_token_id], - max_tokens=1000, - temperature=0, - ) - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - print(f"Using prompt file {prompt_file}") - # get questions - print("Preparing questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" - ) - df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, cot_table_alias + def _initialize_model(self, args): + """Initialize VLLM model with appropriate parameters.""" + model_name = args.model + enable_lora = True if args.adapter else False + lora_request = LoRARequest("sql_adapter", 1, args.adapter) if args.adapter else None + self.lora_request = lora_request + + print(f"Preparing {model_name}") + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + if not args.quantized: + self.llm = LLM( + model=model_name, + tensor_parallel_size=1, + enable_lora=enable_lora, + max_model_len=4096, + max_lora_rank=64, + ) + else: + self.llm = LLM( + model=model_name, + tensor_parallel_size=1, + quantization="AWQ", + enable_lora=enable_lora, + max_model_len=4096, + max_lora_rank=64, + ) + + self.sampling_params = SamplingParams( + n=1, + best_of=args.num_beams, + use_beam_search=args.num_beams != 1, + stop_token_ids=[self.tokenizer.eos_token_id], + max_tokens=1000, + temperature=0, ) - # create a prompt for each question - df["prompt"] = df.apply( - lambda row: generate_prompt( - prompt_file, - row["question"], - row["db_name"], - row["db_type"], - row["instructions"], - row["k_shot_prompt"], - row["glossary"], - row["table_metadata_string"], - row["prev_invalid_sql"], - row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], - public_data, - args.num_columns, - args.shuffle_metadata, - ), - axis=1, + + def _extract_query(self, response_content): + """Extract SQL query from response.""" + try: + # Try to extract query from code blocks + query = response_content.split("```sql", 1)[-1].split("```")[0].strip() + return query.split(";")[0].strip() + ";" + except: + # Fallback to extract anything that looks like SQL + return response_content.split(";")[0].strip() + ";" + + def _process_batch(self, batch, args): + """Process a batch of questions using VLLM.""" + prompts = batch["prompt"].tolist() + print(f"Generating completions for {len(prompts)} prompts") + + # Tokenize prompts + prompt_tokens = [] + prompt_token_sizes = [] + for prompt in prompts: + token_ids = self.tokenizer.encode(prompt, add_special_tokens=False) + if token_ids[0] != self.tokenizer.bos_token_id: + token_ids = [self.tokenizer.bos_token_id] + token_ids + prompt_tokens.append(token_ids) + prompt_token_sizes.append(len(token_ids)) + + print(f"Average prompt size: {sum(prompt_token_sizes)/len(prompt_token_sizes):.0f}") + + # Generate completions + start_time = time.time() + outputs = self.llm.generate( + sampling_params=self.sampling_params, + prompt_token_ids=prompt_tokens, + use_tqdm=False, + lora_request=self.lora_request, ) - print(f"Prepared {len(df)} question(s) from {questions_file}") - - def chunk_dataframe(df, chunk_size): - """Returns successive chunk_size chunks from df as a list of dfs""" - df_chunks = [] - for i in range(0, len(df), chunk_size): - df_i = df.iloc[i : min(i + chunk_size, len(df))] - print( - f"Chunk 
{i//chunk_size+1}/{len(df)//chunk_size+1} with {len(df_i)} questions" - ) - df_chunks.append(df_i) - return df_chunks - - df_chunks = chunk_dataframe(df, args.batch_size) - - total_tried = 0 - total_correct = 0 - output_rows = [] - - print(f"Generating completions") - - for batch in (pbar := tqdm(df_chunks, total=len(df))): - prompts = batch["prompt"].tolist() - print(f"Generating completions for {len(prompts)} prompts") - prompt_tokens = [] - prompt_token_sizes = [] - for prompt in prompts: - token_ids = tokenizer.encode(prompt, add_special_tokens=False) - # add bos token if not already present in prompt - if token_ids[0] != tokenizer.bos_token_id: - token_ids = [tokenizer.bos_token_id] + token_ids - prompt_tokens.append(token_ids) - prompt_token_sizes.append(len(token_ids)) - print( - f"Average prompt size: {sum(prompt_token_sizes)/len(prompt_token_sizes):.0f}" + time_taken = time.time() - start_time + print(f"Generated {len(outputs)} completions in {time_taken:.2f} seconds") + + # Process results + results = [] + for row, output in zip(batch.to_dict("records"), outputs): + generated_query = self._extract_query(output.outputs[0].text) + generated_query = sqlparse.format( + generated_query, keyword_case="upper", strip_whitespace=True ) - start_time = time.time() - # outputs = llm.generate(prompts, sampling_params) # if you prefer to use prompts instead of token_ids - outputs = llm.generate( - sampling_params=sampling_params, - prompt_token_ids=prompt_tokens, - use_tqdm=False, - lora_request=lora_request, + + row["generated_query"] = generated_query + row["tokens_used"] = len(output.outputs[0].token_ids) + row["latency_seconds"] = time_taken / len(batch) + + try: + exact_match, correct = compare_query_results( + query_gold=row["query"], + query_gen=generated_query, + db_name=row["db_name"], + db_type=row["db_type"], + db_creds=db_creds_all[row["db_type"]], + question=row["question"], + query_category=row["query_category"], + table_metadata_string=row["table_metadata_string"], + decimal_points=args.decimal_points, + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + results.append(row) + return results + + def run_eval(self, args): + """VLLM-specific evaluation logic with batching.""" + self._initialize_model(args) + + for questions_file, prompt_file, output_file in zip( + args.questions_file, args.prompt_file, args.output_file + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + print(f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}") + df = prepare_questions_df( + questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias ) - print( - f"Generated {len(outputs)} completions in {time.time() - start_time:.2f} seconds" + + # Create prompts for all questions + df["prompt"] = df.apply( + lambda row: generate_prompt( + prompt_file, + row["question"], + row["db_name"], + row["db_type"], + row["instructions"], + row["k_shot_prompt"], + row["glossary"], + row["table_metadata_string"], + row["prev_invalid_sql"], + row["prev_error_msg"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", ""), + not args.use_private_data, + args.num_columns, + args.shuffle_metadata, + ), + axis=1, ) - time_taken = time.time() - start_time - 
for row, output in zip(batch.to_dict("records"), outputs): - generated_query = ( - output.outputs[0].text.split(";")[0].split("```")[0].strip() + ";" + print(f"Prepared {len(df)} question(s) from {questions_file}") + + # Process in batches + def chunk_dataframe(df, chunk_size): + df_chunks = [] + for i in range(0, len(df), chunk_size): + df_i = df.iloc[i : min(i + chunk_size, len(df))] + print(f"Chunk {i//chunk_size+1}/{len(df)//chunk_size+1} with {len(df_i)} questions") + df_chunks.append(df_i) + return df_chunks + + df_chunks = chunk_dataframe(df, args.batch_size) + all_results = [] + total_tried = total_correct = 0 + + print("Generating completions") + for batch in (pbar := tqdm(df_chunks, total=len(df_chunks))): + batch_results = self._process_batch(batch, args) + all_results.extend(batch_results) + + # Update progress stats + batch_correct = sum(1 for r in batch_results if r.get("correct", 0)) + total_correct += batch_correct + total_tried += len(batch) + pbar.set_description( + f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" ) - normalized_query = sqlparse.format( - generated_query, keyword_case="upper", strip_whitespace=True + + # Save results + results_df = pd.DataFrame(all_results) + if "prompt" in results_df.columns: + del results_df["prompt"] + + print(results_df.groupby("query_category")[["exact_match", "correct"]].mean()) + results_df = results_df.sort_values(by=["db_name", "query_category", "question"]) + print(f"Average tokens generated: {results_df['tokens_used'].mean():.1f}") + + # Save to file + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + results_df.to_csv(output_file, index=False, float_format="%.2f") + print(f"Saved results to {output_file}") + + # Upload results if needed + if args.upload_url is not None: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=results_df.to_dict("records"), + url=args.upload_url, + runner_type="vllm_runner", + prompt=prompt, + args=args, ) - row["generated_query"] = normalized_query - row["tokens_used"] = len(output.outputs[0].token_ids) - row["latency_seconds"] = time_taken / len(batch) - - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - db_creds = db_creds_all[db_type] - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds, - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - if correct: - total_correct += 1 - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - - total_tried += 1 - output_rows.append(row) - pbar.update(len(batch)) - pbar.set_description( - f"Correct so far: {total_correct}/{(total_tried)} ({100*total_correct/(total_tried):.2f}%)" - ) - df = pd.DataFrame(output_rows) - del df["prompt"] - print(df.groupby("query_category")[["exact_match", "correct"]].mean()) - df = df.sort_values(by=["db_name", "query_category", "question"]) - print(f"Average tokens generated: {df['tokens_used'].mean():.1f}") - # get directory of output_file and create if not exist - 
output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - df.to_csv(output_file, index=False, float_format="%.2f") - print(f"Saved results to {output_file}") - - results = df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="vllm_runner", - prompt=prompt, - args=args, - ) + + +def run_vllm_eval(args): + runner = VLLMRunner() + runner.run_eval(args) \ No newline at end of file From f49a3ff7d7fe9236fe61593f03652d63804cd76f Mon Sep 17 00:00:00 2001 From: Rishabh Srivastava Date: Thu, 30 Jan 2025 03:04:53 +0800 Subject: [PATCH 2/2] linting --- runners/anthropic_runner.py | 21 ++++++++---- runners/api_runner.py | 68 +++++++++++++++++++++++-------------- runners/base_runner.py | 16 ++++----- runners/bedrock_runner.py | 49 ++++++++++++++++---------- runners/deepseek_runner.py | 38 +++++++++++++-------- runners/gemini_runner.py | 36 ++++++++++++-------- runners/hf_runner.py | 49 ++++++++++++++++++-------- runners/llama_cpp_runner.py | 27 ++++++++++----- runners/mistral_runner.py | 35 ++++++++++++------- runners/mlx_runner.py | 39 +++++++++++++-------- runners/openai_runner.py | 17 +++++++--- runners/together_runner.py | 38 ++++++++++++++------- runners/vllm_runner.py | 48 +++++++++++++++++--------- 13 files changed, 311 insertions(+), 170 deletions(-) diff --git a/runners/anthropic_runner.py b/runners/anthropic_runner.py index 34aa791..546b667 100644 --- a/runners/anthropic_runner.py +++ b/runners/anthropic_runner.py @@ -18,7 +18,9 @@ def _call_llm(self, prompt, model_name, temperature=0.0): messages = prompt else: # Text prompt format messages = [{"role": "user", "content": prompt}] - return chat_anthropic(messages=messages, model=model_name, temperature=temperature) + return chat_anthropic( + messages=messages, model=model_name, temperature=temperature + ) def _extract_query(self, response_content): """Extract SQL query from response.""" @@ -33,7 +35,7 @@ def run_eval(self, args): questions_file_list = args.questions_file prompt_file_list = args.prompt_file output_file_list = args.output_file - + for questions_file, prompt_file, output_file in zip( questions_file_list, prompt_file_list, output_file_list ): @@ -43,11 +45,15 @@ def run_eval(self, args): f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" ) question_query_df = prepare_questions_df( - questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + questions_file, + args.db_type, + args.num_questions, + args.k_shot, + args.cot_table_alias, ) input_rows = question_query_df.to_dict("records") output_rows = [] - + with ThreadPoolExecutor(args.parallel_threads) as executor: futures = [] for row in input_rows: @@ -111,7 +117,9 @@ def run_eval(self, args): # save results to csv output_df = pd.DataFrame(output_rows) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + output_df = output_df.sort_values( + by=["db_name", "query_category", "question"] + ) if "prompt" in output_df.columns: del output_df["prompt"] # get num rows, mean correct, mean error_db_exec for each query_category @@ -149,6 +157,7 @@ def run_eval(self, args): args=args, ) + def run_anthropic_eval(args): runner = AnthropicRunner() - runner.run_eval(args) \ No newline at end of file + runner.run_eval(args) diff --git a/runners/api_runner.py b/runners/api_runner.py index 
7cf0507..8cc6a73 100644 --- a/runners/api_runner.py +++ b/runners/api_runner.py @@ -34,7 +34,9 @@ def __init__(self): self.api_url = None self.api_type = None - def _mk_vllm_json(self, prompt, num_beams, logprobs=False, sql_lora_path=None, sql_lora_name=None): + def _mk_vllm_json( + self, prompt, num_beams, logprobs=False, sql_lora_path=None, sql_lora_name=None + ): payload = { "prompt": prompt, "n": 1, @@ -67,12 +69,16 @@ def _call_llm(self, prompt, model_name, temperature=0.0): # model_name is unused but kept for BaseRunner compatibility start_time = time() json_data = None - + if self.api_type == "tgi": json_data = self._mk_tgi_json(prompt, self.num_beams) elif self.api_type == "vllm": json_data = self._mk_vllm_json( - prompt, self.num_beams, self.logprobs, self.sql_lora_path, self.sql_lora_name + prompt, + self.num_beams, + self.logprobs, + self.sql_lora_path, + self.sql_lora_name, ) else: json_data = { @@ -84,7 +90,7 @@ def _call_llm(self, prompt, model_name, temperature=0.0): "stop": [";", "```"], "max_tokens": 4000, } - + try: response = requests.post( self.api_url, @@ -130,7 +136,7 @@ def process_row(self, row, model_name, args): self.logprobs = args.logprobs self.sql_lora_path = args.adapter self.sql_lora_name = args.adapter_name - + try: prompt = generate_prompt( prompt_file=args.prompt_file[0], @@ -169,11 +175,13 @@ def process_row(self, row, model_name, args): rank = prob["rank"] logprob = prob["logprob"] token = prob["decoded_token"] - probs_to_append.update({ - f"rank_{rank}_token": token, - f"rank_{rank}_logprob": logprob, - f"rank_{rank}_prob": 10**logprob, - }) + probs_to_append.update( + { + f"rank_{rank}_token": token, + f"rank_{rank}_logprob": logprob, + f"rank_{rank}_prob": 10**logprob, + } + ) probs_to_append["prob_diff"] = ( probs_to_append["rank_1_prob"] - probs_to_append["rank_2_prob"] ) @@ -184,7 +192,7 @@ def process_row(self, row, model_name, args): "generated_query": generated_query, "latency_seconds": self.request_time, "tokens_used": None, # API doesn't provide token counts - "logprobs": logprobs_display if self.logprobs else None + "logprobs": logprobs_display if self.logprobs else None, } # Run comparison @@ -218,8 +226,10 @@ def process_row(self, row, model_name, args): "error_db_exec": 1, "error_msg": f"API ERROR: {str(e)}", "tokens_used": None, - "latency_seconds": self.request_time if hasattr(self, 'request_time') else None, - "logprobs": [] if self.logprobs else None + "latency_seconds": ( + self.request_time if hasattr(self, "request_time") else None + ), + "logprobs": [] if self.logprobs else None, } def run_eval(self, args): @@ -250,19 +260,21 @@ def run_eval(self, args): f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" ) df = prepare_questions_df( - questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + questions_file, + args.db_type, + args.num_questions, + args.k_shot, + args.cot_table_alias, ) # Process all rows output_rows = [] total_tried = total_correct = 0 - + with ThreadPoolExecutor(args.parallel_threads) as executor: futures = [] for row in df.to_dict("records"): - futures.append( - executor.submit(self.process_row, row, None, args) - ) + futures.append(executor.submit(self.process_row, row, None, args)) with tqdm(as_completed(futures), total=len(futures)) as pbar: for f in pbar: @@ -277,7 +289,7 @@ def run_eval(self, args): # Save results output_df = pd.DataFrame(output_rows) - + # Handle logprobs if needed if args.logprobs: print( @@ -291,10 +303,14 @@ 
def run_eval(self, args): if "prompt" in output_df.columns: del output_df["prompt"] - - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - + + print( + output_df.groupby("query_category")[["correct", "error_db_exec"]].mean() + ) + output_df = output_df.sort_values( + by=["db_name", "query_category", "question"] + ) + # Save to file output_dir = os.path.dirname(output_file) if not os.path.exists(output_dir): @@ -307,7 +323,9 @@ def run_eval(self, args): # Upload results if needed if args.upload_url is not None: - run_name = args.run_name or output_file.split("/")[-1].replace(".csv", "") + run_name = args.run_name or output_file.split("/")[-1].replace( + ".csv", "" + ) upload_results( results=output_rows, url=args.upload_url, @@ -319,4 +337,4 @@ def run_eval(self, args): def run_api_eval(args): runner = APIRunner() - runner.run_eval(args) \ No newline at end of file + runner.run_eval(args) diff --git a/runners/base_runner.py b/runners/base_runner.py index 6315725..b9eadf1 100644 --- a/runners/base_runner.py +++ b/runners/base_runner.py @@ -69,9 +69,7 @@ def generate_prompt( def format_sql_query(generated_query): try: - return sqlparse.format( - generated_query, reindent=True, keyword_case="upper" - ) + return sqlparse.format(generated_query, reindent=True, keyword_case="upper") except: return generated_query @@ -148,11 +146,11 @@ def run_eval(self, args): print(f"Using prompt file {prompt_file}") print("Preparing questions...") question_query_df = prepare_questions_df( - questions_file, - args.db_type, - args.num_questions, - args.k_shot, - args.cot_table_alias + questions_file, + args.db_type, + args.num_questions, + args.k_shot, + args.cot_table_alias, ) input_rows = question_query_df.to_dict("records") @@ -170,4 +168,4 @@ def run_eval(self, args): results_df = pd.DataFrame(results) results_df.to_csv(output_file, index=False) - print(f"Results saved to {output_file}") \ No newline at end of file + print(f"Results saved to {output_file}") diff --git a/runners/bedrock_runner.py b/runners/bedrock_runner.py index 0c7a0cf..4aa0524 100644 --- a/runners/bedrock_runner.py +++ b/runners/bedrock_runner.py @@ -21,20 +21,22 @@ def __init__(self): def _call_llm(self, prompt, model_name, temperature=0.0): """Call AWS Bedrock API.""" - body = json.dumps({ - "prompt": prompt, - "max_gen_len": 600, - "temperature": 0, - "top_p": 1, - }) - + body = json.dumps( + { + "prompt": prompt, + "max_gen_len": 600, + "temperature": 0, + "top_p": 1, + } + ) + accept = "application/json" contentType = "application/json" response = self.client.invoke_model( body=body, modelId=model_name, accept=accept, contentType=contentType ) model_response = json.loads(response["body"].read()) - + # Create a response object similar to other runners class BedrockResponse: def __init__(self, content): @@ -46,7 +48,10 @@ def __init__(self, content): def _extract_query(self, response_content): """Extract SQL query from response.""" - return response_content.split("```sql")[-1].split("```")[0].split(";")[0].strip() + ";" + return ( + response_content.split("```sql")[-1].split("```")[0].split(";")[0].strip() + + ";" + ) def process_row(self, row, model_name, args): """Override process_row to use simple handling.""" @@ -80,7 +85,7 @@ def process_row(self, row, model_name, args): result = { "generated_query": generated_query, "latency_seconds": time() - start_time, - "tokens_used": None # Bedrock doesn't provide token counts + 
"tokens_used": None, # Bedrock doesn't provide token counts } # Run comparison @@ -114,7 +119,7 @@ def process_row(self, row, model_name, args): "error_db_exec": 1, "error_msg": f"PROCESSING ERROR: {str(e)}", "latency_seconds": time() - start_time, - "tokens_used": None + "tokens_used": None, } def run_eval(self, args): @@ -128,13 +133,17 @@ def run_eval(self, args): f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" ) df = prepare_questions_df( - questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + questions_file, + args.db_type, + args.num_questions, + args.k_shot, + args.cot_table_alias, ) # Process all rows output_rows = [] total_tried = total_correct = 0 - + with ThreadPoolExecutor(args.parallel_threads) as executor: futures = [] for row in df.to_dict("records"): @@ -157,10 +166,14 @@ def run_eval(self, args): output_df = pd.DataFrame(output_rows) if "prompt" in output_df.columns: del output_df["prompt"] - - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - + + print( + output_df.groupby("query_category")[["correct", "error_db_exec"]].mean() + ) + output_df = output_df.sort_values( + by=["db_name", "query_category", "question"] + ) + # Save to file output_dir = os.path.dirname(output_file) if not os.path.exists(output_dir): @@ -186,4 +199,4 @@ def run_eval(self, args): def run_bedrock_eval(args): runner = BedrockRunner() - runner.run_eval(args) \ No newline at end of file + runner.run_eval(args) diff --git a/runners/deepseek_runner.py b/runners/deepseek_runner.py index e610be5..92ba069 100644 --- a/runners/deepseek_runner.py +++ b/runners/deepseek_runner.py @@ -18,8 +18,8 @@ class DeepseekRunner(BaseRunner): def __init__(self): super().__init__() self.client = OpenAI( - base_url="https://api.deepseek.com", - api_key=os.environ.get("DEEPSEEK_API_KEY") + base_url="https://api.deepseek.com", + api_key=os.environ.get("DEEPSEEK_API_KEY"), ) def _call_llm(self, messages, model_name, temperature=0.0): @@ -37,7 +37,7 @@ def _call_llm(self, messages, model_name, temperature=0.0): messages=messages, max_tokens=800, ) - + # Create a response object similar to other runners class DeepseekResponse: def __init__(self, content): @@ -85,7 +85,7 @@ def process_row(self, row: Dict, model_name, args): result = { "generated_query": generated_query, "latency_seconds": time() - start_time, - "tokens_used": None # Deepseek doesn't provide token counts + "tokens_used": None, # Deepseek doesn't provide token counts } # Run comparison @@ -118,14 +118,16 @@ def process_row(self, row: Dict, model_name, args): "error_db_exec": 1, "error_msg": f"PROCESSING ERROR: {str(e)}", "latency_seconds": time() - start_time, - "tokens_used": None + "tokens_used": None, } def run_eval(self, args): """Deepseek-specific evaluation logic.""" if not args.prompt_file[0].endswith(".json"): - raise ValueError(f"Prompt file must be a JSON file. Got {args.prompt_file[0]}") - + raise ValueError( + f"Prompt file must be a JSON file. 
Got {args.prompt_file[0]}" + ) + for questions_file, prompt_file, output_file in zip( args.questions_file, args.prompt_file, args.output_file ): @@ -135,13 +137,17 @@ def run_eval(self, args): f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" ) df = prepare_questions_df( - questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + questions_file, + args.db_type, + args.num_questions, + args.k_shot, + args.cot_table_alias, ) # Process all rows output_rows = [] total_tried = total_correct = 0 - + with ThreadPoolExecutor(args.parallel_threads) as executor: futures = [] for row in df.to_dict("records"): @@ -164,10 +170,14 @@ def run_eval(self, args): output_df = pd.DataFrame(output_rows) if "prompt" in output_df.columns: del output_df["prompt"] - - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - + + print( + output_df.groupby("query_category")[["correct", "error_db_exec"]].mean() + ) + output_df = output_df.sort_values( + by=["db_name", "query_category", "question"] + ) + # Save to file output_dir = os.path.dirname(output_file) if not os.path.exists(output_dir): @@ -193,4 +203,4 @@ def run_eval(self, args): def run_deepseek_eval(args): runner = DeepseekRunner() - runner.run_eval(args) \ No newline at end of file + runner.run_eval(args) diff --git a/runners/gemini_runner.py b/runners/gemini_runner.py index 3ba16e9..31457a1 100644 --- a/runners/gemini_runner.py +++ b/runners/gemini_runner.py @@ -59,7 +59,7 @@ def process_row(self, row, model_name, args): response = self._call_llm(prompt, model_name) generated_query = self._extract_query(response.content) - + row["generated_query"] = generated_query row["latency_seconds"] = time() - start_time row["tokens_used"] = response.input_tokens + response.output_tokens @@ -103,7 +103,7 @@ def run_eval(self, args): questions_file_list = args.questions_file prompt_file_list = args.prompt_file output_file_list = args.output_file - + for questions_file, prompt_file, output_file in zip( questions_file_list, prompt_file_list, output_file_list ): @@ -113,23 +113,26 @@ def run_eval(self, args): f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" ) question_query_df = prepare_questions_df( - questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + questions_file, + args.db_type, + args.num_questions, + args.k_shot, + args.cot_table_alias, ) input_rows = question_query_df.to_dict("records") output_rows = [] - + total_tried = 0 total_correct = 0 - + with ThreadPoolExecutor(args.parallel_threads) as executor: futures = [] for row in input_rows: - futures.append(executor.submit( - self.process_row, - row=row, - model_name=args.model, - args=args - )) + futures.append( + executor.submit( + self.process_row, row=row, model_name=args.model, args=args + ) + ) with tqdm(as_completed(futures), total=len(futures)) as pbar: for f in pbar: @@ -144,10 +147,12 @@ def run_eval(self, args): # save results to csv output_df = pd.DataFrame(output_rows) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + output_df = output_df.sort_values( + by=["db_name", "query_category", "question"] + ) if "prompt" in output_df.columns: del output_df["prompt"] - + # get num rows, mean correct, mean error_db_exec for each query_category agg_stats = ( output_df.groupby("query_category") @@ 
-159,7 +164,7 @@ def run_eval(self, args): .reset_index() ) print(agg_stats) - + # get directory of output_file and create if not exist output_dir = os.path.dirname(output_file) if not os.path.exists(output_dir): @@ -184,6 +189,7 @@ def run_eval(self, args): args=args, ) + def run_gemini_eval(args): runner = GeminiRunner() - runner.run_eval(args) \ No newline at end of file + runner.run_eval(args) diff --git a/runners/hf_runner.py b/runners/hf_runner.py index 1e70b6b..4dc9784 100644 --- a/runners/hf_runner.py +++ b/runners/hf_runner.py @@ -28,14 +28,18 @@ def __init__(self): self.model = None self.pipe = None - def _initialize_model(self, model_name: Optional[str], adapter_path: Optional[str], batch_size: int): + def _initialize_model( + self, model_name: Optional[str], adapter_path: Optional[str], batch_size: int + ): """Load a HuggingFace tokenizer and model.""" if adapter_path is not None: from peft import PeftModel, PeftConfig print(f"Loading adapter model {adapter_path}") config = PeftConfig.from_pretrained(adapter_path) - self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path) + self.tokenizer = AutoTokenizer.from_pretrained( + config.base_model_name_or_path + ) self.model = AutoModelForCausalLM.from_pretrained( config.base_model_name_or_path, torch_dtype=torch.float16, @@ -57,7 +61,7 @@ def _initialize_model(self, model_name: Optional[str], adapter_path: Optional[st ) self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - + if model_name and "8b" in model_name.lower(): # do this since it doesn't seem to have been done by default self.tokenizer.padding_side = "left" @@ -69,10 +73,13 @@ def _initialize_model(self, model_name: Optional[str], adapter_path: Optional[st trust_remote_code=True, device_map=device_map, ) - + self.model.tie_weights() self.pipe = pipeline( - "text-generation", model=self.model, tokenizer=self.tokenizer, batch_size=batch_size + "text-generation", + model=self.model, + tokenizer=self.tokenizer, + batch_size=batch_size, ) def _extract_query(self, generated_text, prompt): @@ -97,7 +104,7 @@ def _process_batch(self, batch, args): temperature=None, top_p=None, ) - + # Clean up GPU memory gc.collect() if torch.cuda.is_available(): @@ -109,7 +116,7 @@ def _process_batch(self, batch, args): generated_query = self._extract_query( result[0]["generated_text"], row["prompt"] ) - + # More GPU cleanup gc.collect() if torch.cuda.is_available(): @@ -117,7 +124,9 @@ def _process_batch(self, batch, args): torch.cuda.synchronize() row["generated_query"] = generated_query - row["latency_seconds"] = None # HF pipeline doesn't provide per-item latency + row["latency_seconds"] = ( + None # HF pipeline doesn't provide per-item latency + ) try: exact_match, correct = compare_query_results( @@ -162,7 +171,11 @@ def run_eval(self, args): f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}" ) df = prepare_questions_df( - questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias + questions_file, + args.db_type, + args.num_questions, + args.k_shot, + args.cot_table_alias, ) # Create prompts for all questions df["prompt"] = df.apply( @@ -203,7 +216,7 @@ def chunk_dataframe(df, chunk_size): for batch in df_chunks: batch_results = self._process_batch(batch, args) all_results.extend(batch_results) - + # Update progress stats batch_correct = sum(1 for r in batch_results if r.get("correct", 0)) total_correct += batch_correct @@ -217,10 +230,16 @@ def chunk_dataframe(df, chunk_size): 
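# A minimal, standalone sketch of the chunked-batch pattern the HF and vLLM runners
# share: split the questions dataframe into batch_size-row slices and process each
# slice in turn. Only pandas is assumed; the tiny demo dataframe is illustrative.
import pandas as pd


def chunk_dataframe(df, chunk_size):
    # Successive chunk_size-row slices of df; iloc simply truncates the final slice.
    return [df.iloc[i : i + chunk_size] for i in range(0, len(df), chunk_size)]


if __name__ == "__main__":
    questions = pd.DataFrame({"question": [f"q{i}" for i in range(5)]})
    for batch in chunk_dataframe(questions, 2):
        print(len(batch), list(batch["question"]))  # prints slices of 2, 2, then 1 rows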
             results_df = pd.DataFrame(all_results)
             if "prompt" in results_df.columns:
                 del results_df["prompt"]
-            
-            print(results_df.groupby("query_category")[["correct", "error_db_exec"]].mean())
-            results_df = results_df.sort_values(by=["db_name", "query_category", "question"])
-            
+
+            print(
+                results_df.groupby("query_category")[
+                    ["correct", "error_db_exec"]
+                ].mean()
+            )
+            results_df = results_df.sort_values(
+                by=["db_name", "query_category", "question"]
+            )
+
             # Save to file
             output_dir = os.path.dirname(output_file)
             if not os.path.exists(output_dir):
@@ -243,4 +262,4 @@ def chunk_dataframe(df, chunk_size):
 
 def run_hf_eval(args):
     runner = HFRunner()
-    runner.run_eval(args)
\ No newline at end of file
+    runner.run_eval(args)
diff --git a/runners/llama_cpp_runner.py b/runners/llama_cpp_runner.py
index f64d7c3..bdbb55b 100644
--- a/runners/llama_cpp_runner.py
+++ b/runners/llama_cpp_runner.py
@@ -32,6 +32,7 @@ def _call_llm(self, prompt, model_name, temperature=0.0):
             echo=False,
             repeat_penalty=1.0,
         )
+
         # Create a response object similar to other runners
         class LlamaResponse:
             def __init__(self, content):
@@ -62,9 +63,13 @@ def run_eval(self, args):
                 f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}"
             )
             df = prepare_questions_df(
-                questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias
+                questions_file,
+                args.db_type,
+                args.num_questions,
+                args.k_shot,
+                args.cot_table_alias,
             )
-            
+
             # Create prompts for all questions
             df["prompt"] = df.apply(
                 lambda row: generate_prompt(
@@ -99,7 +104,9 @@ def run_eval(self, args):
             for row in df.to_dict("records"):
                 try:
                     start_time = time()
-                    result = self.process_row(row, None, args)  # None as model_name is unused
+                    result = self.process_row(
+                        row, None, args
+                    )  # None as model_name is unused
                     if result.get("correct", 0):
                         total_correct += 1
                     total_tried += 1
@@ -118,10 +125,14 @@ def run_eval(self, args):
             output_df = pd.DataFrame(output_rows)
             if "prompt" in output_df.columns:
                 del output_df["prompt"]
-            
-            print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean())
-            output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
-            
+
+            print(
+                output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()
+            )
+            output_df = output_df.sort_values(
+                by=["db_name", "query_category", "question"]
+            )
+
             # Save to file
             output_dir = os.path.dirname(output_file)
             if not os.path.exists(output_dir):
@@ -147,4 +158,4 @@ def run_eval(self, args):
 
 def run_llama_cpp_eval(args):
     runner = LlamaCppRunner()
-    runner.run_eval(args)
\ No newline at end of file
+    runner.run_eval(args)
diff --git a/runners/mistral_runner.py b/runners/mistral_runner.py
index 01f6b64..fc159c2 100644
--- a/runners/mistral_runner.py
+++ b/runners/mistral_runner.py
@@ -29,7 +29,7 @@ def generate_prompt(self, prompt_file, **kwargs):
         # Check that System and User prompts are in the prompt file
         if "System:" not in prompt or "User:" not in prompt:
             raise ValueError("Invalid prompt file. Please use prompt_mistral.md")
-        
+
         sys_prompt = prompt.split("System:")[1].split("User:")[0].strip()
         user_prompt = prompt.split("User:")[1].strip()
 
@@ -73,11 +73,14 @@ def _call_llm(self, messages, model_name, temperature=0.0):
             temperature=temperature,
             max_tokens=600,
         )
+
         # Create a response object similar to other runners
         class MistralResponse:
             def __init__(self, content):
                 self.content = content
-                self.input_tokens = 0  # Mistral doesn't provide token counts in this way
+                self.input_tokens = (
+                    0  # Mistral doesn't provide token counts in this way
+                )
                 self.output_tokens = 0
 
         return MistralResponse(chat_response.choices[0].message.content)
@@ -85,7 +88,7 @@ def __init__(self, content):
     def _extract_query(self, response_content):
         """Extract SQL query from response with Mistral-specific handling."""
         try:
-            # Replace backslashes 
+            # Replace backslashes
             content = response_content.replace("\\", "")
             # First try to extract from SQL code blocks
             query = content.split(";")[0].split("```sql")[-1].strip()
@@ -121,7 +124,7 @@ def process_row(self, row, model_name, args):
             result = {
                 "generated_query": generated_query,
                 "latency_seconds": time() - start_time,
-                "tokens_used": response.input_tokens + response.output_tokens
+                "tokens_used": response.input_tokens + response.output_tokens,
             }
 
             # Run comparison
@@ -155,7 +158,7 @@ def process_row(self, row, model_name, args):
                 "error_db_exec": 1,
                 "error_msg": f"PROCESSING ERROR: {str(e)}",
                 "latency_seconds": time() - start_time,
-                "tokens_used": 0
+                "tokens_used": 0,
             }
 
     def run_eval(self, args):
@@ -172,13 +175,17 @@ def run_eval(self, args):
                 f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}"
             )
             df = prepare_questions_df(
-                questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias
+                questions_file,
+                args.db_type,
+                args.num_questions,
+                args.k_shot,
+                args.cot_table_alias,
             )
 
             # Process all rows
             output_rows = []
             total_tried = total_correct = 0
-            
+
             with ThreadPoolExecutor(args.parallel_threads) as executor:
                 futures = []
                 for row in df.to_dict("records"):
@@ -201,10 +208,14 @@ def run_eval(self, args):
             output_df = pd.DataFrame(output_rows)
             if "prompt" in output_df.columns:
                 del output_df["prompt"]
-            
-            print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean())
-            output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
-            
+
+            print(
+                output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()
+            )
+            output_df = output_df.sort_values(
+                by=["db_name", "query_category", "question"]
+            )
+
             # Save to file
             output_dir = os.path.dirname(output_file)
             if not os.path.exists(output_dir):
@@ -230,4 +241,4 @@ def run_eval(self, args):
 
 def run_mistral_eval(args):
     runner = MistralRunner()
-    runner.run_eval(args)
\ No newline at end of file
+    runner.run_eval(args)
diff --git a/runners/mlx_runner.py b/runners/mlx_runner.py
index 4fda1ce..2e1d066 100644
--- a/runners/mlx_runner.py
+++ b/runners/mlx_runner.py
@@ -26,13 +26,14 @@ def _call_llm(self, prompt, model_name, temperature=0.0):
         """Call MLX model."""
         # model_name is unused, as it's already loaded in _initialize_model
         text = generate(
-            self.model, 
-            self.tokenizer, 
-            prompt=prompt, 
-            max_tokens=512, 
-            temp=0, 
-            verbose=True
+            self.model,
+            self.tokenizer,
+            prompt=prompt,
+            max_tokens=512,
+            temp=0,
+            verbose=True,
         )
+
         # Create a response object similar to other runners
         class MLXResponse:
             def __init__(self, content):
@@ -63,9 +64,13 @@ def run_eval(self, args):
                 f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}"
             )
             df = prepare_questions_df(
-                questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias
+                questions_file,
+                args.db_type,
+                args.num_questions,
+                args.k_shot,
+                args.cot_table_alias,
             )
-            
+
             # Create prompts for all questions
             df["prompt"] = df.apply(
                 lambda row: generate_prompt(
@@ -99,7 +104,9 @@ def run_eval(self, args):
             with tqdm(total=len(df)) as pbar:
                 for row in df.to_dict("records"):
                     try:
-                        result = self.process_row(row, None, args)  # None as model_name is unused
+                        result = self.process_row(
+                            row, None, args
+                        )  # None as model_name is unused
                         if result.get("correct", 0):
                             total_correct += 1
                         total_tried += 1
@@ -118,10 +125,14 @@ def run_eval(self, args):
             output_df = pd.DataFrame(output_rows)
             if "prompt" in output_df.columns:
                 del output_df["prompt"]
-            
-            print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean())
-            output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
-            
+
+            print(
+                output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()
+            )
+            output_df = output_df.sort_values(
+                by=["db_name", "query_category", "question"]
+            )
+
             # Save to file
             output_dir = os.path.dirname(output_file)
             if not os.path.exists(output_dir):
@@ -147,4 +158,4 @@ def run_eval(self, args):
 
 def run_mlx_eval(args):
     runner = MLXRunner()
-    runner.run_eval(args)
\ No newline at end of file
+    runner.run_eval(args)
diff --git a/runners/openai_runner.py b/runners/openai_runner.py
index c31213c..af324d0 100644
--- a/runners/openai_runner.py
+++ b/runners/openai_runner.py
@@ -27,7 +27,7 @@ def run_eval(self, args):
         questions_file_list = args.questions_file
         prompt_file_list = args.prompt_file
         output_file_list = args.output_file
-        
+
         for questions_file, prompt_file, output_file in zip(
             questions_file_list, prompt_file_list, output_file_list
         ):
@@ -37,11 +37,15 @@ def run_eval(self, args):
                 f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}"
            )
             question_query_df = self.prepare_questions_df(
-                questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias
+                questions_file,
+                args.db_type,
+                args.num_questions,
+                args.k_shot,
+                args.cot_table_alias,
             )
             input_rows = question_query_df.to_dict("records")
             output_rows = []
-            
+
             with ThreadPoolExecutor(args.parallel_threads) as executor:
                 futures = []
                 for row in input_rows:
@@ -105,7 +109,9 @@ def run_eval(self, args):
 
             # save results to csv
             output_df = pd.DataFrame(output_rows)
-            output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
+            output_df = output_df.sort_values(
+                by=["db_name", "query_category", "question"]
+            )
             if "prompt" in output_df.columns:
                 del output_df["prompt"]
             # get num rows, mean correct, mean error_db_exec for each query_category
@@ -143,6 +149,7 @@ def run_eval(self, args):
                     args=args,
                 )
 
+
 def run_openai_eval(args):
     runner = OpenAIRunner()
-    runner.run_eval(args)
\ No newline at end of file
+    runner.run_eval(args)
diff --git a/runners/together_runner.py b/runners/together_runner.py
index 9b726f9..481dd28 100644
--- a/runners/together_runner.py
+++ b/runners/together_runner.py
@@ -24,7 +24,9 @@ def _get_stop_tokens(self, model_name: str):
         if model_name.startswith("meta-llama"):
             return ["<|eot_id|>", "<|eom_id|>"]
         else:
-            print("Undefined stop token(s). Please specify the stop token(s) for the model.")
+            print(
+                "Undefined stop token(s). Please specify the stop token(s) for the model."
+            )
             return []
 
     def _call_llm(self, messages, model_name, temperature=0.0):
@@ -38,7 +40,7 @@ def _call_llm(self, messages, model_name, temperature=0.0):
             stop=stop_tokens,
             stream=False,
         )
-        
+
         # Create a response object similar to other runners
         class TogetherResponse:
             def __init__(self, content):
@@ -86,7 +88,7 @@ def process_row(self, row: Dict, model_name, args):
             result = {
                 "generated_query": generated_query,
                 "latency_seconds": time() - start_time,
-                "tokens_used": None  # Together doesn't provide token counts
+                "tokens_used": None,  # Together doesn't provide token counts
             }
 
             # Run comparison
@@ -119,14 +121,16 @@ def process_row(self, row: Dict, model_name, args):
                 "error_db_exec": 1,
                 "error_msg": f"PROCESSING ERROR: {str(e)}",
                 "latency_seconds": time() - start_time,
-                "tokens_used": None
+                "tokens_used": None,
             }
 
     def run_eval(self, args):
         """Together-specific evaluation logic."""
         if not args.prompt_file[0].endswith(".json"):
-            raise ValueError(f"Prompt file must be a JSON file. Got {args.prompt_file[0]}")
-        
+            raise ValueError(
+                f"Prompt file must be a JSON file. Got {args.prompt_file[0]}"
+            )
+
         for questions_file, prompt_file, output_file in zip(
             args.questions_file, args.prompt_file, args.output_file
         ):
@@ -136,13 +140,17 @@ def run_eval(self, args):
                 f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}"
             )
             df = prepare_questions_df(
-                questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias
+                questions_file,
+                args.db_type,
+                args.num_questions,
+                args.k_shot,
+                args.cot_table_alias,
             )
 
             # Process all rows
             output_rows = []
             total_tried = total_correct = 0
-            
+
             with ThreadPoolExecutor(args.parallel_threads) as executor:
                 futures = []
                 for row in df.to_dict("records"):
@@ -165,10 +173,14 @@ def run_eval(self, args):
             output_df = pd.DataFrame(output_rows)
             if "prompt" in output_df.columns:
                 del output_df["prompt"]
-            
-            print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean())
-            output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
-            
+
+            print(
+                output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()
+            )
+            output_df = output_df.sort_values(
+                by=["db_name", "query_category", "question"]
+            )
+
             # Save to file
             output_dir = os.path.dirname(output_file)
             if not os.path.exists(output_dir):
@@ -194,4 +206,4 @@ def run_eval(self, args):
 
 def run_together_eval(args):
     runner = TogetherRunner()
-    runner.run_eval(args)
\ No newline at end of file
+    runner.run_eval(args)
diff --git a/runners/vllm_runner.py b/runners/vllm_runner.py
index 95527d2..f58af50 100644
--- a/runners/vllm_runner.py
+++ b/runners/vllm_runner.py
@@ -28,13 +28,15 @@ def _initialize_model(self, args):
         """Initialize VLLM model with appropriate parameters."""
         model_name = args.model
         enable_lora = True if args.adapter else False
-        lora_request = LoRARequest("sql_adapter", 1, args.adapter) if args.adapter else None
+        lora_request = (
+            LoRARequest("sql_adapter", 1, args.adapter) if args.adapter else None
+        )
         self.lora_request = lora_request
 
         print(f"Preparing {model_name}")
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
-        
+
         if not args.quantized:
             self.llm = LLM(
                 model=model_name,
@@ -76,7 +78,7 @@ def _process_batch(self, batch, args):
         """Process a batch of questions using VLLM."""
         prompts = batch["prompt"].tolist()
         print(f"Generating completions for {len(prompts)} prompts")
-        
+
         # Tokenize prompts
         prompt_tokens = []
         prompt_token_sizes = []
@@ -86,9 +88,11 @@ def _process_batch(self, batch, args):
                 token_ids = [self.tokenizer.bos_token_id] + token_ids
             prompt_tokens.append(token_ids)
             prompt_token_sizes.append(len(token_ids))
-        
-        print(f"Average prompt size: {sum(prompt_token_sizes)/len(prompt_token_sizes):.0f}")
-        
+
+        print(
+            f"Average prompt size: {sum(prompt_token_sizes)/len(prompt_token_sizes):.0f}"
+        )
+
         # Generate completions
         start_time = time.time()
         outputs = self.llm.generate(
@@ -107,7 +111,7 @@ def _process_batch(self, batch, args):
             generated_query = sqlparse.format(
                 generated_query, keyword_case="upper", strip_whitespace=True
             )
-            
+
             row["generated_query"] = generated_query
             row["tokens_used"] = len(output.outputs[0].token_ids)
             row["latency_seconds"] = time_taken / len(batch)
@@ -143,9 +147,15 @@ def run_eval(self, args):
         ):
             print(f"Using prompt file {prompt_file}")
             print("Preparing questions...")
-            print(f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}")
+            print(
+                f"Using {'all' if args.num_questions is None else args.num_questions} question(s) from {questions_file}"
+            )
             df = prepare_questions_df(
-                questions_file, args.db_type, args.num_questions, args.k_shot, args.cot_table_alias
+                questions_file,
+                args.db_type,
+                args.num_questions,
+                args.k_shot,
+                args.cot_table_alias,
             )
 
             # Create prompts for all questions
@@ -180,7 +190,9 @@ def chunk_dataframe(df, chunk_size):
                 df_chunks = []
                 for i in range(0, len(df), chunk_size):
                     df_i = df.iloc[i : min(i + chunk_size, len(df))]
-                    print(f"Chunk {i//chunk_size+1}/{len(df)//chunk_size+1} with {len(df_i)} questions")
+                    print(
+                        f"Chunk {i//chunk_size+1}/{len(df)//chunk_size+1} with {len(df_i)} questions"
+                    )
                     df_chunks.append(df_i)
                 return df_chunks
 
@@ -192,7 +204,7 @@ def chunk_dataframe(df, chunk_size):
             for batch in (pbar := tqdm(df_chunks, total=len(df_chunks))):
                 batch_results = self._process_batch(batch, args)
                 all_results.extend(batch_results)
-                
+
                 # Update progress stats
                 batch_correct = sum(1 for r in batch_results if r.get("correct", 0))
                 total_correct += batch_correct
@@ -205,11 +217,15 @@ def chunk_dataframe(df, chunk_size):
             results_df = pd.DataFrame(all_results)
             if "prompt" in results_df.columns:
                 del results_df["prompt"]
-            
-            print(results_df.groupby("query_category")[["exact_match", "correct"]].mean())
-            results_df = results_df.sort_values(by=["db_name", "query_category", "question"])
+
+            print(
+                results_df.groupby("query_category")[["exact_match", "correct"]].mean()
+            )
+            results_df = results_df.sort_values(
+                by=["db_name", "query_category", "question"]
+            )
             print(f"Average tokens generated: {results_df['tokens_used'].mean():.1f}")
-            
+
             # Save to file
             output_dir = os.path.dirname(output_file)
             if not os.path.exists(output_dir):
@@ -232,4 +248,4 @@ def chunk_dataframe(df, chunk_size):
 
 def run_vllm_eval(args):
     runner = VLLMRunner()
-    runner.run_eval(args)
\ No newline at end of file
+    runner.run_eval(args)