Skip to content

Commit

Permalink
i_am_nitpicky.exe
Browse files Browse the repository at this point in the history
  • Loading branch information
anthony2261 committed Aug 28, 2024
1 parent e68d5d6 commit 15fd324
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions backend/dataline/services/llm_flow/toolkit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import json
import operator
from typing import Annotated, Any, List, Optional, Sequence, Type, TypedDict, cast
from typing import Annotated, Any, Iterable, List, Optional, Sequence, Type, TypedDict, cast

from dataline.models.llm_flow.schema import (
ChartGenerationResult,
Expand Down Expand Up @@ -40,6 +40,9 @@ def __init__(self, message: str):
class ChartValidationRunException(RunException): ...


class TableNotFoundException(RunException): ...


def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str: # type: ignore[misc]
"""
Truncate a string to a certain number of words, based on the max string
Expand Down Expand Up @@ -166,7 +169,7 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, StateUpdaterTool):
# Pydantic model to validate input to the tool
args_schema: Type[BaseModelV1] = _InfoSQLDatabaseToolInput

def _validate_table_names(self, table_names: str, available_names: list[str]) -> tuple[set[str], list[str]]:
def _validate_sanitize_table_names(self, table_names: str, available_names: Iterable[str]) -> set[str]:
"""Validate table names and return valid and invalid tables."""
cleaned_names = [table_name.strip() for table_name in table_names.split(",")]
available_names_tables_only = {name.split(".")[-1]: name for name in available_names}
Expand All @@ -182,7 +185,13 @@ def _validate_table_names(self, table_names: str, available_names: list[str]) ->
else:
wrong_tables.append(name)

return valid_tables, wrong_tables
if wrong_tables:
raise TableNotFoundException(
f"""ERROR: Tables {wrong_tables} that you selected do not exist in the database.
Available tables are the following, please select from them ONLY: "{'", "'.join(available_names)}"."""
)

return valid_tables

def _run(
self,
Expand All @@ -193,11 +202,7 @@ def _run(
self.table_names = None # Reset internal state in case it remains from tool calls

available_names = self.db.get_usable_table_names()
valid_tables, wrong_tables = self._validate_table_names(table_names, available_names)

if wrong_tables:
return f"""ERROR: Tables {wrong_tables} that you selected do not exist in the database.
Available tables are the following, please select from them ONLY: "{'", "'.join(available_names)}"."""
valid_tables = self._validate_sanitize_table_names(table_names, available_names)

self.table_names = list(valid_tables)
return self.db.get_table_info_no_throw(self.table_names)
Expand All @@ -211,8 +216,14 @@ def get_response( # type: ignore[misc]
messages: list[BaseMessage] = []
results: list[QueryResultSchema] = []

# We call the tool_executor and get back a response
response = self.run(args)
try:
# We call the tool_executor and get back a response
response = self.run(args)
except TableNotFoundException as e:
tool_message = ToolMessage(content=str(e.message), name=self.name, tool_call_id=call_id)
messages.append(tool_message)
return state_update(messages=messages)

# We use the response to create a ToolMessage
tool_message = ToolMessage(content=str(response), name=self.name, tool_call_id=call_id)
messages.append(tool_message)
Expand All @@ -221,10 +232,7 @@ def get_response( # type: ignore[misc]
if self.table_names:
results.append(SelectedTablesResult(tables=self.table_names))

return {
"messages": messages,
"results": results,
}
return state_update(messages=messages, results=results)


class _QuerySQLDataBaseToolInput(BaseModelV1):
Expand Down

0 comments on commit 15fd324

Please sign in to comment.