Skip to content

Commit

Permalink
Add support for SQL bind parameters, closes #564
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Nov 7, 2023
1 parent a4541ad commit 7b25468
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 31 deletions.
10 changes: 6 additions & 4 deletions src/python/txtai/database/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def count(self):

raise NotImplementedError

def search(self, query, similarity=None, limit=None):
def search(self, query, similarity=None, limit=None, parameters=None):
"""
Runs a search against the database. Supports the following methods:
Expand All @@ -133,6 +133,7 @@ def search(self, query, similarity=None, limit=None):
query: input query
similarity: similarity results as [(indexid, score)]
limit: maximum number of results to return
parameters: dict of named parameters to bind to placeholders
Returns:
query results as a list of dicts
Expand All @@ -159,7 +160,7 @@ def search(self, query, similarity=None, limit=None):
query["where"] = where

# Run query
return self.query(query, limit)
return self.query(query, limit, parameters)

def parse(self, query):
"""
Expand Down Expand Up @@ -200,13 +201,14 @@ def embed(self, similarity, batch):

raise NotImplementedError

def query(self, query, limit):
def query(self, query, limit, parameters):
"""
Executes query against database.
Args:
query: input query
limit: maximum number of results to return
parameters: dict of named parameters to bind to placeholders
Returns:
query results
Expand Down Expand Up @@ -312,7 +314,7 @@ def execute(self, function, *args):

try:
# Debug log SQL
logger.debug(*args)
logger.debug(" ".join(["%s"] * len(args)), *args)

return function(*args)
except Exception as ex:
Expand Down
9 changes: 5 additions & 4 deletions src/python/txtai/database/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def jsoncolumn(self, name):
d = aliased(Document, name="d")

# Build JSON column expression for column
return str(cast(d.data[name], Text).compile(dialect=self.connection.bind.dialect, compile_kwargs={"literal_binds": True}))
return str(cast(d.data[name].as_string(), Text).compile(dialect=self.connection.bind.dialect, compile_kwargs={"literal_binds": True}))

def createtables(self):
# Create tables
Expand Down Expand Up @@ -111,7 +111,7 @@ def connect(self, path=None):
content = os.environ.get("CLIENT_URL") if content == "client" else content

# Create engine using database URL
engine = create_engine(content, poolclass=StaticPool, echo=False)
engine = create_engine(content, poolclass=StaticPool, echo=False, json_serializer=lambda x: x)

# Create database session
return Session(engine)
Expand All @@ -138,18 +138,19 @@ def __init__(self, connection):
def __iter__(self):
return self.result

def execute(self, statement):
def execute(self, statement, parameters=None):
"""
Executes statement.
Args:
statement: statement to execute
parameters: optional dictionary with bind parameters
"""

if isinstance(statement, str):
statement = textsql(statement)

self.result = self.connection.execute(statement)
self.result = self.connection.execute(statement, parameters)

def fetchall(self):
"""
Expand Down
49 changes: 47 additions & 2 deletions src/python/txtai/database/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import os
import re

from tempfile import TemporaryDirectory

Expand Down Expand Up @@ -33,18 +34,22 @@ def __init__(self, config):
if not DUCKDB:
raise ImportError('DuckDB is not available - install "database" extra to enable')

def execute(self, function, *args):
    """
    Runs function through the parent execute method, after passing the arguments through formatargs.

    Args:
        function: function to run
        args: function arguments

    Returns:
        result of the parent execute method
    """

    # Convert arguments into the form produced by formatargs
    converted = self.formatargs(args)

    # Delegate to parent method
    return super().execute(function, *converted)

def insertdocument(self, uid, data, tags, entry):
# Delete existing document
self.cursor.execute(DuckDB.DELETE_DOCUMENT, [uid])

# Call parent logic
# Call parent method
super().insertdocument(uid, data, tags, entry)

def insertobject(self, uid, data, tags, entry):
# Delete existing object
self.cursor.execute(DuckDB.DELETE_OBJECT, [uid])

# Call parent logic
# Call parent method
super().insertobject(uid, data, tags, entry)

def connect(self, path=":memory:"):
Expand All @@ -58,6 +63,14 @@ def connect(self, path=":memory:"):
def getcursor(self):
return self.connection

def jsonprefix(self):
    """
    Returns the leading text of a json column expression.

    Returns:
        json column prefix
    """

    prefix = "json_extract_string(data"
    return prefix

def jsoncolumn(self, name):
    """
    Builds a json_extract_string expression that reads name from the data column.

    Args:
        name: json field name

    Returns:
        json column expression
    """

    # JSON path for field, quoted as a SQL string literal
    path = f"'$.{name}'"

    return f"json_extract_string(data, {path})"

def rows(self):
# Iteratively retrieve and yield rows
batch = 256
Expand Down Expand Up @@ -103,3 +116,35 @@ def copy(self, path):
connection.begin()

return connection

def formatargs(self, args):
    """
    Converts a (query, named parameters) argument pair into a (query, positional parameters) pair.
    Each :name placeholder followed by whitespace or the end of the query is replaced with a question
    mark and the matching value is added to a list, in the order the placeholders appear in the query.

    A placeholder that is used more than once is replaced at every occurrence. Parameters that do not
    appear in the query are dropped. Arguments that are not a (str query, dict parameters) pair are
    returned unchanged.

    Args:
        args: input arguments

    Returns:
        DuckDB compatible args
    """

    # Only a (query, named parameters) pair needs to be rewritten
    if not args or len(args) != 2 or not isinstance(args[0], str) or not isinstance(args[1], dict):
        return args

    # Unpack query args
    query, parameters = args

    # Placeholder name -> value, names are formatted the same way they are written in the query
    lookup = {str(key): value for key, value in parameters.items()}

    values = []
    if lookup:
        # Single pattern for all names. Names are escaped so they always match literally.
        # The lookahead keeps :ab from being treated as :a.
        names = "|".join(re.escape(name) for name in lookup)
        pattern = re.compile(rf":({names})(?=\s|$)")

        def bind(match):
            # Matches are visited left to right, which keeps values aligned with the ?'s
            values.append(lookup[match.group(1)])
            return "?"

        # Replace all placeholders in one pass. Substituting one name at a time shifts the
        # offsets of the remaining placeholders, which can scramble the parameter order.
        query = pattern.sub(bind, query)

    # Repack query and parameter list
    return (query, values)
5 changes: 3 additions & 2 deletions src/python/txtai/database/rdbms.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def embed(self, similarity, batch):
return Statement.IDS_CLAUSE % batch

# pylint: disable=R0912
def query(self, query, limit):
def query(self, query, limit, parameters):
# Extract query components
select = query.get("select", self.defaults())
where = query.get("where")
Expand Down Expand Up @@ -214,7 +214,8 @@ def query(self, query, limit):
self.scores(None)

# Runs a user query through execute method, which has common user query handling logic
self.execute(self.cursor.execute, query)
args = (query, parameters) if parameters else (query,)
self.execute(self.cursor.execute, *args)

# Retrieve column list from query
columns = [c[0] for c in self.cursor.description]
Expand Down
5 changes: 3 additions & 2 deletions src/python/txtai/database/sql/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def compound(self, iterator, tokens, x, aliases, similar):

def resolve(self, token, aliases):
"""
Resolves this token's value if it is not an alias.
Resolves this token's value if it is not an alias or a bind parameter.
Args:
token: token to resolve
Expand All @@ -397,7 +397,8 @@ def resolve(self, token, aliases):
resolved token value
"""

if aliases and Token.normalize(token) in aliases:
# Check for alias or bind parameter
if (aliases and Token.normalize(token) in aliases) or (token.startswith(":")):
return token

return self.resolver(token)
10 changes: 6 additions & 4 deletions src/python/txtai/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def count(self):
# Default to 0 when no suitable method found
return 0

def search(self, query, limit=None, weights=None, index=None):
def search(self, query, limit=None, weights=None, index=None, parameters=None):
"""
Finds documents most similar to the input query. This method will run either an index search
or an index + database search depending on if a database is available.
Expand All @@ -354,15 +354,16 @@ def search(self, query, limit=None, weights=None, index=None):
limit: maximum results
weights: hybrid score weights, if applicable
index: index name, if applicable
parameters: dict of named parameters to bind to placeholders
Returns:
list of (id, score) for index search, list of dict for an index + database search
"""

results = self.batchsearch([query], limit, weights, index)
results = self.batchsearch([query], limit, weights, index, [parameters])
return results[0] if results else results

def batchsearch(self, queries, limit=None, weights=None, index=None):
def batchsearch(self, queries, limit=None, weights=None, index=None, parameters=None):
"""
Finds documents most similar to the input queries. This method will run either an index search
or an index + database search depending on if a database is available.
Expand All @@ -372,12 +373,13 @@ def batchsearch(self, queries, limit=None, weights=None, index=None):
limit: maximum results
weights: hybrid score weights, if applicable
index: index name, if applicable
parameters: list of dicts of named parameters to bind to placeholders
Returns:
list of (id, score) per query for index search, list of dict per query for an index + database search
"""

return Search(self)(queries, limit, weights, index)
return Search(self)(queries, limit, weights, index, parameters)

def similarity(self, query, data):
"""
Expand Down
12 changes: 7 additions & 5 deletions src/python/txtai/embeddings/search/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, embeddings, indexids=False):
self.query = embeddings.query
self.scoring = embeddings.scoring if embeddings.issparse() else None

def __call__(self, queries, limit=None, weights=None, index=None):
def __call__(self, queries, limit=None, weights=None, index=None, parameters=None):
"""
Executes a batch search for queries. This method will run either an index search or an index + database search
depending on if a database is available.
Expand All @@ -47,6 +47,7 @@ def __call__(self, queries, limit=None, weights=None, index=None):
limit: maximum results
weights: hybrid score weights
index: index name
parameters: list of dicts of named parameters to bind to placeholders
Returns:
list of (id, score) per query for index search, list of dict per query for an index + database search
Expand All @@ -66,7 +67,7 @@ def __call__(self, queries, limit=None, weights=None, index=None):

# Database search
if not self.indexids and self.database:
return self.dbsearch(queries, limit, weights, index)
return self.dbsearch(queries, limit, weights, index, parameters)

# Default vector index query (sparse, dense or hybrid)
return self.search(queries, limit, weights, index)
Expand Down Expand Up @@ -209,7 +210,7 @@ def resolve(self, results):

return results

def dbsearch(self, queries, limit, weights, index):
def dbsearch(self, queries, limit, weights, index, parameters):
"""
Executes an index + database search.
Expand All @@ -218,6 +219,7 @@ def dbsearch(self, queries, limit, weights, index):
limit: maximum results
weights: default hybrid score weights
index: default index name
parameters: list of dicts of named parameters to bind to placeholders
Returns:
list of dict per query
Expand All @@ -230,13 +232,13 @@ def dbsearch(self, queries, limit, weights, index):
limit = max(limit, self.limit(queries))

# Bulk index scan
scan = Scan(self.search, limit, weights, index)(queries)
scan = Scan(self.search, limit, weights, index)(queries, parameters)

# Combine index search results with database search results
results = []
for x, query in enumerate(queries):
# Run the database query, get matching bulk searches for current query
result = self.database.search(query, [r for y, r in scan if x == y], limit)
result = self.database.search(query, [r for y, r in scan if x == y], limit, parameters[x] if parameters and parameters[x] else None)
results.append(result)

return results
Expand Down
Loading

0 comments on commit 7b25468

Please sign in to comment.