Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streamline show table #907

Merged
merged 2 commits into from
Jan 2, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,16 +459,26 @@ class SQLAgent(LumenBaseAgent):

_output_type = SQLOutput

async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, BaseSQLSource]:
async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, BaseSQLSource, bool]:
"""Select the most relevant table based on the user query."""
join_required = None
sources = self._memory["sources"]
tables_to_source, tables_schema_str = await gather_table_sources(sources)
tables = tuple(tables_to_source)
if messages and messages[-1]["content"].startswith("Show the table: '"):

user_message = ""
for message in messages[::-1]:
if message["role"] == "user":
user_message = message["content"]
break

if messages and "Show the table: " in user_message:
# Handle the case where explicitly requested a table
table = re.search(r"Show the table: '([^']+)'", messages[-1]["content"]).group(1)
table = re.search(r"Show the table: '([^']+)'", user_message).group(1)
join_required = False
elif len(tables) == 1:
table = tables[0]
join_required = False
else:
with self.interface.add_step(title="Choosing the most relevant table...", steps_layout=self._steps_layout) as step:
if len(tables) > 1:
Expand Down Expand Up @@ -499,7 +509,7 @@ async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, Ba
sources = [src for src in sources if table in src]
source = sources[0] if sources else self._memory["source"]

return table, source
return table, source, join_required

@retry_llm_output()
async def _create_valid_sql(
Expand Down Expand Up @@ -690,19 +700,20 @@ async def respond(
8. If a join is required, remove source/table prefixes from the last message.
9. Construct the SQL query with `_create_valid_sql`.
"""
table, source = await self._select_relevant_table(messages)
table, source, join_required = await self._select_relevant_table(messages)
if not hasattr(source, "get_sql_expr"):
return None

# include min max for more context for data cleaning
schema = await get_schema(source, table, include_min_max=True)
join_required = await self._check_requires_joins(messages, schema, table)

tables_to_source = {table: source}
if join_required is None:
return None
if join_required:
tables_to_source = await self.find_join_tables(messages)
else:
tables_to_source = {table: source}
join_required = await self._check_requires_joins(messages, schema, table)
if join_required is None: # TODO: what is this for? can we remove?
philippjfr marked this conversation as resolved.
Show resolved Hide resolved
return None
if join_required:
tables_to_source = await self.find_join_tables(messages)

tables_sql_schemas = {}
for source_table, source in tables_to_source.items():
Expand Down
Loading