diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py index da4bd5e4..cf5513c1 100644 --- a/lumen/ai/agents.py +++ b/lumen/ai/agents.py @@ -459,16 +459,26 @@ class SQLAgent(LumenBaseAgent): _output_type = SQLOutput - async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, BaseSQLSource]: + async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, BaseSQLSource, bool]: """Select the most relevant table based on the user query.""" + join_required = None sources = self._memory["sources"] tables_to_source, tables_schema_str = await gather_table_sources(sources) tables = tuple(tables_to_source) - if messages and messages[-1]["content"].startswith("Show the table: '"): + + user_message = "" + for message in messages[::-1]: + if message["role"] == "user": + user_message = message["content"] + break + + if messages and "Show the table: " in user_message: # Handle the case where explicitly requested a table - table = re.search(r"Show the table: '([^']+)'", messages[-1]["content"]).group(1) + table = re.search(r"Show the table: '([^']+)'", user_message).group(1) + join_required = False elif len(tables) == 1: table = tables[0] + join_required = False else: with self.interface.add_step(title="Choosing the most relevant table...", steps_layout=self._steps_layout) as step: if len(tables) > 1: @@ -499,7 +509,7 @@ async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, Ba sources = [src for src in sources if table in src] source = sources[0] if sources else self._memory["source"] - return table, source + return table, source, join_required @retry_llm_output() async def _create_valid_sql( @@ -690,19 +700,21 @@ async def respond( 8. If a join is required, remove source/table prefixes from the last message. 9. Construct the SQL query with `_create_valid_sql`. """ - table, source = await self._select_relevant_table(messages) + table, source, join_required = await self._select_relevant_table(messages) if not hasattr(source, "get_sql_expr"): return None # include min max for more context for data cleaning schema = await get_schema(source, table, include_min_max=True) - join_required = await self._check_requires_joins(messages, schema, table) + + tables_to_source = {table: source} if join_required is None: - return None - if join_required: - tables_to_source = await self.find_join_tables(messages) - else: - tables_to_source = {table: source} + join_required = await self._check_requires_joins(messages, schema, table) + if join_required is None: + # Bail if query was cancelled or errored out + return None + if join_required: + tables_to_source = await self.find_join_tables(messages) tables_sql_schemas = {} for source_table, source in tables_to_source.items():