|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +import pandas as pd |
| 5 | +import json |
| 6 | + |
| 7 | +from data_formulator.agents.agent_utils import extract_code_from_gpt_response, extract_json_objects |
| 8 | +import re |
| 9 | +import logging |
| 10 | + |
| 11 | + |
| 12 | +logger = logging.getLogger(__name__) |
| 13 | + |
| 14 | + |
| 15 | +SYSTEM_PROMPT = '''You are a data scientist to help with data queries. |
| 16 | +The user will provide you with a description of the data source and tables available in the [DATA SOURCE] section and a query in the [USER INPUTS] section. |
| 17 | +You will need to help the user complete the query and provide reasoning for the query you generated in the [OUTPUT] section. |
| 18 | +
|
| 19 | +Input format: |
| 20 | +* The data source description is a json object with the following fields: |
| 21 | + * `data_source`: the name of the data source |
| 22 | + * `tables`: a list of tables in the data source, which maps the table name to the list of columns available in the table. |
| 23 | +* The user input is a natural language description of the query or a partial query you need to complete. |
| 24 | +
|
| 25 | +Steps: |
| 26 | +* Based on data source description and user input, you should first decide on what language should be used to query the data. |
| 27 | +* Then, describe the logic for the query you generated in a json object in a block ```json``` with the following fields: |
| 28 | + * `language`: the language of the query you generated |
| 29 | + * `tables`: the names of the tables you will use in the query |
| 30 | + * `logic`: the reasoning behind why you chose the tables and the logic for the query you generated |
| 31 | +* Finally, generate the complete query in the language specified in a code block ```{language}```. |
| 32 | +
|
| 33 | +Output format: |
| 34 | +* The output should be in the following format, no other text should be included: |
| 35 | +
|
| 36 | +[REASONING] |
| 37 | +```json |
| 38 | +{ |
| 39 | + "language": {language}, |
| 40 | + "tables": {tables}, |
| 41 | + "logic": {logic} |
| 42 | +} |
| 43 | +``` |
| 44 | +
|
| 45 | +[QUERY] |
| 46 | +```{language} |
| 47 | +{query} |
| 48 | +``` |
| 49 | +''' |
| 50 | + |
| 51 | +class QueryCompletionAgent(object): |
| 52 | + |
| 53 | + def __init__(self, client): |
| 54 | + self.client = client |
| 55 | + |
| 56 | + def run(self, data_source_metadata, query): |
| 57 | + |
| 58 | + user_query = f"[DATA SOURCE]\n\n{json.dumps(data_source_metadata, indent=2)}\n\n[USER INPUTS]\n\n{query}\n\n[REASONING]\n" |
| 59 | + |
| 60 | + logger.info(user_query) |
| 61 | + |
| 62 | + messages = [{"role":"system", "content": SYSTEM_PROMPT}, |
| 63 | + {"role":"user","content": user_query}] |
| 64 | + |
| 65 | + ###### the part that calls open_ai |
| 66 | + response = self.client.get_completion(messages = messages) |
| 67 | + response_content = '[REASONING]\n' + response.choices[0].message.content |
| 68 | + |
| 69 | + logger.info(f"=== query completion output ===>\n{response_content}\n") |
| 70 | + |
| 71 | + reasoning = extract_json_objects(response_content.split("[REASONING]")[1].split("[QUERY]")[0].strip())[0] |
| 72 | + output_query = response_content.split("[QUERY]")[1].strip() |
| 73 | + |
| 74 | + # Extract the query by removing the language markers |
| 75 | + language_pattern = r"```(\w+)\s+(.*?)```" |
| 76 | + match = re.search(language_pattern, output_query, re.DOTALL) |
| 77 | + if match: |
| 78 | + output_query = match.group(2).strip() |
| 79 | + |
| 80 | + return reasoning, output_query |
0 commit comments