Skip to content

Commit

Permalink
Merge pull request #308 from joehaddad2000/main
Browse files Browse the repository at this point in the history
Improve table name validation in InfoSQLDatabaseTool
  • Loading branch information
anthony2261 authored Aug 28, 2024
2 parents 25e700e + 15fd324 commit ae27d79
Showing 1 changed file with 40 additions and 19 deletions.
59 changes: 40 additions & 19 deletions backend/dataline/services/llm_flow/toolkit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import json
import operator
from typing import Annotated, Any, List, Optional, Sequence, Type, TypedDict, cast
from typing import Annotated, Any, Iterable, List, Optional, Sequence, Type, TypedDict, cast

from dataline.models.llm_flow.schema import (
ChartGenerationResult,
Expand Down Expand Up @@ -40,6 +40,9 @@ def __init__(self, message: str):
class ChartValidationRunException(RunException):
    """RunException subtype; presumably signals a chart validation failure.

    NOTE(review): the raise/catch sites are not visible in this excerpt -- confirm.
    """


class TableNotFoundException(RunException):
    """Raised when a requested table name matches none of the available tables.

    InfoSQLDatabaseTool raises it while validating table names, and its
    get_response turns the exception's message into a ToolMessage.
    """


def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str: # type: ignore[misc]
"""
Truncate a string to a certain number of words, based on the max string
Expand Down Expand Up @@ -166,27 +169,42 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, StateUpdaterTool):
# Pydantic model to validate input to the tool
args_schema: Type[BaseModelV1] = _InfoSQLDatabaseToolInput

def _validate_sanitize_table_names(self, table_names: str, available_names: Iterable[str]) -> set[str]:
    """Resolve a comma-separated list of table names against the available tables.

    Each requested name is stripped of surrounding whitespace and accepted when
    it either equals an available name or equals the last dot-separated segment
    of one (e.g. ``orders`` for ``public.orders``). In the second case the full
    available name is what ends up in the result.

    Args:
        table_names: Comma-separated table names as passed to the tool.
        available_names: Names of the tables that may be selected.

    Returns:
        The set of matched names, spelled as they appear in ``available_names``.

    Raises:
        TableNotFoundException: If at least one requested name matches nothing.
            The message lists the unknown names and every available name.
    """
    # Materialise once: the names are walked several times below (lookup map,
    # membership tests, error message), which a one-shot iterable would not survive.
    available = list(available_names)
    known = set(available)
    # Last dot-separated segment -> full name. When two available names share
    # the same last segment, the later one wins.
    by_short_name = {name.split(".")[-1]: name for name in available}

    valid_tables: set[str] = set()
    wrong_tables: list[str] = []
    for raw_name in table_names.split(","):
        name = raw_name.strip()
        if name in known:
            valid_tables.add(name)
        elif name in by_short_name:
            valid_tables.add(by_short_name[name])
        else:
            wrong_tables.append(name)

    if wrong_tables:
        listing = '", "'.join(available)
        raise TableNotFoundException(
            f"ERROR: Tables {wrong_tables} that you selected do not exist in the database.\n"
            f'Available tables are the following, please select from them ONLY: "{listing}".'
        )

    return valid_tables

def _run(
    self,
    table_names: str,
    run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
    """Get the schema for tables in a comma-separated list.

    Args:
        table_names: Comma-separated table names requested by the caller.
        run_manager: Accepted for interface compatibility; not used here.

    Returns:
        Whatever ``self.db.get_table_info_no_throw`` returns for the validated
        table names.

    Raises:
        TableNotFoundException: Propagated from
            ``_validate_sanitize_table_names`` when a requested table is
            unknown; ``self.table_names`` is left as ``None`` in that case.
    """
    self.table_names = None  # Reset internal state in case it remains from tool calls

    available_names = self.db.get_usable_table_names()
    valid_tables = self._validate_sanitize_table_names(table_names, available_names)

    # Only validated names are stored. Sorting gives a stable order, because
    # iterating a set of strings can differ from one process to the next.
    self.table_names = sorted(valid_tables)
    return self.db.get_table_info_no_throw(self.table_names)

def get_response( # type: ignore[misc]
Expand All @@ -198,8 +216,14 @@ def get_response( # type: ignore[misc]
messages: list[BaseMessage] = []
results: list[QueryResultSchema] = []

# We call the tool_executor and get back a response
response = self.run(args)
try:
# We call the tool_executor and get back a response
response = self.run(args)
except TableNotFoundException as e:
tool_message = ToolMessage(content=str(e.message), name=self.name, tool_call_id=call_id)
messages.append(tool_message)
return state_update(messages=messages)

# We use the response to create a ToolMessage
tool_message = ToolMessage(content=str(response), name=self.name, tool_call_id=call_id)
messages.append(tool_message)
Expand All @@ -208,10 +232,7 @@ def get_response( # type: ignore[misc]
if self.table_names:
results.append(SelectedTablesResult(tables=self.table_names))

return {
"messages": messages,
"results": results,
}
return state_update(messages=messages, results=results)


class _QuerySQLDataBaseToolInput(BaseModelV1):
Expand Down

0 comments on commit ae27d79

Please sign in to comment.