From 03d7b25ffc90fd5ce474bbc95a2320eca6db684f Mon Sep 17 00:00:00 2001 From: Andrew Huang Date: Fri, 27 Dec 2024 12:08:40 -0800 Subject: [PATCH 1/2] streamline show --- lumen/ai/agents.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py index da4bd5e4..a7431545 100644 --- a/lumen/ai/agents.py +++ b/lumen/ai/agents.py @@ -459,16 +459,26 @@ class SQLAgent(LumenBaseAgent): _output_type = SQLOutput - async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, BaseSQLSource]: + async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, BaseSQLSource, bool]: """Select the most relevant table based on the user query.""" + join_required = None sources = self._memory["sources"] tables_to_source, tables_schema_str = await gather_table_sources(sources) tables = tuple(tables_to_source) - if messages and messages[-1]["content"].startswith("Show the table: '"): + + user_message = "" + for message in messages[::-1]: + if message["role"] == "user": + user_message = message["content"] + break + + if messages and "Show the table: " in user_message: # Handle the case where explicitly requested a table - table = re.search(r"Show the table: '([^']+)'", messages[-1]["content"]).group(1) + table = re.search(r"Show the table: '([^']+)'", user_message).group(1) + join_required = False elif len(tables) == 1: table = tables[0] + join_required = False else: with self.interface.add_step(title="Choosing the most relevant table...", steps_layout=self._steps_layout) as step: if len(tables) > 1: @@ -499,7 +509,7 @@ async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, Ba sources = [src for src in sources if table in src] source = sources[0] if sources else self._memory["source"] - return table, source + return table, source, join_required @retry_llm_output() async def _create_valid_sql( @@ -690,19 +700,20 @@ async def respond( 8. If a join is required, remove source/table prefixes from the last message. 9. Construct the SQL query with `_create_valid_sql`. """ - table, source = await self._select_relevant_table(messages) + table, source, join_required = await self._select_relevant_table(messages) if not hasattr(source, "get_sql_expr"): return None # include min max for more context for data cleaning schema = await get_schema(source, table, include_min_max=True) - join_required = await self._check_requires_joins(messages, schema, table) + + tables_to_source = {table: source} if join_required is None: - return None - if join_required: - tables_to_source = await self.find_join_tables(messages) - else: - tables_to_source = {table: source} + join_required = await self._check_requires_joins(messages, schema, table) + if join_required is None: # TODO: what is this for? can we remove? + return None + if join_required: + tables_to_source = await self.find_join_tables(messages) tables_sql_schemas = {} for source_table, source in tables_to_source.items(): From f45e0dd4c706586f5e0fbba6352853e28f9377a6 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Mon, 30 Dec 2024 12:30:39 +0100 Subject: [PATCH 2/2] Adjust comment --- lumen/ai/agents.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py index a7431545..cf5513c1 100644 --- a/lumen/ai/agents.py +++ b/lumen/ai/agents.py @@ -710,7 +710,8 @@ async def respond( tables_to_source = {table: source} if join_required is None: join_required = await self._check_requires_joins(messages, schema, table) - if join_required is None: # TODO: what is this for? can we remove? + if join_required is None: + # Bail if query was cancelled or errored out return None if join_required: tables_to_source = await self.find_join_tables(messages)