Skip to content

Commit

Permalink
Fix prepared statement handling
Browse files Browse the repository at this point in the history
The prepared statement handling code assumed that for each query we'll
always receive some non-empty response even after the initial response
which is not a valid assumption.

This assumption worked because earlier Trino used to send empty fake
results even for queries which don't return results (like PREPARE and
DEALLOCATE) but is now invalid with
trinodb/trino@bc794cd.

The other problem with the code was that it leaked HTTP protocol details
into dbapi.py and worked around it by keeping a deep copy of the request
object from the PREPARE execution and re-using it for the actual query
execution.

The new code fixes both issues by processing the prepared statement
headers as they are received and storing the resulting set of active
prepared statements on the ClientSession object. The ClientSession's set
of prepared statements is then rendered into the prepared statement
request header in TrinoRequest. Since the ClientSession is created and
reused for the entire Connection this also means that we can now
actually implement re-use of prepared statements within a single
Connection.
  • Loading branch information
hashhar committed Oct 3, 2022
1 parent efb6680 commit d5f779b
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 110 deletions.
35 changes: 35 additions & 0 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,3 +1067,38 @@ def test_set_role_trino_351(run_trino):
cur.execute("SET ROLE ALL")
cur.fetchall()
assert cur._request._client_session.role == "tpch=ALL"


def test_prepared_statements(run_trino):
_, host, port = run_trino

trino_connection = trino.dbapi.Connection(
host=host, port=port, user="test", catalog="tpch",
)
cur = trino_connection.cursor()

# Implicit prepared statements must work and deallocate statements on finish
assert cur._request._client_session.prepared_statements == {}
cur.execute('SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?', (1,))
result = cur.fetchall()
assert result[0][0] == 1
assert cur._request._client_session.prepared_statements == {}

# Explicit prepared statements must also work
cur.execute('PREPARE test_prepared_statements FROM SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?')
cur.fetchall()
assert 'test_prepared_statements' in cur._request._client_session.prepared_statements
cur.execute('EXECUTE test_prepared_statements USING 1')
cur.fetchall()
assert result[0][0] == 1

# An implicit prepared statement must not deallocate explicit statements
cur.execute('SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?', (1,))
result = cur.fetchall()
assert result[0][0] == 1
assert 'test_prepared_statements' in cur._request._client_session.prepared_statements

assert 'test_prepared_statements' in cur._request._client_session.prepared_statements
cur.execute('DEALLOCATE PREPARE test_prepared_statements')
cur.fetchall()
assert cur._request._client_session.prepared_statements == {}
15 changes: 0 additions & 15 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,21 +881,6 @@ def __call__(self, *args, **kwargs):
return http_response


def test_trino_result_response_headers():
"""
Validates that the `TrinoResult.response_headers` property returns the
headers associated to the TrinoQuery instance provided to the `TrinoResult`
class.
"""
mock_trino_query = mock.Mock(respone_headers={
'X-Trino-Fake-1': 'one',
'X-Trino-Fake-2': 'two',
})

result = TrinoResult(query=mock_trino_query, rows=[])
assert result.response_headers == mock_trino_query.response_headers


def test_trino_query_response_headers(sample_get_response_data):
"""
Validates that the `TrinoQuery.execute` function can take addtional headers
Expand Down
49 changes: 38 additions & 11 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def __init__(
self._extra_credential = extra_credential
self._client_tags = client_tags
self._role = role
self._prepared_statements: Dict[str, str] = {}
self._object_lock = threading.Lock()

@property
Expand Down Expand Up @@ -206,6 +207,15 @@ def role(self, role):
with self._object_lock:
self._role = role

@property
def prepared_statements(self):
return self._prepared_statements

@prepared_statements.setter
def prepared_statements(self, prepared_statements):
with self._object_lock:
self._prepared_statements = prepared_statements


def get_header_values(headers, header):
return [val.strip() for val in headers[header].split(",")]
Expand All @@ -219,6 +229,14 @@ def get_session_property_values(headers, header):
]


def get_prepared_statement_values(headers, header):
kvs = get_header_values(headers, header)
return [
(k.strip(), urllib.parse.unquote_plus(v.strip()))
for k, v in (kv.split("=", 1) for kv in kvs)
]


class TrinoStatus(object):
def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None):
self.id = id
Expand Down Expand Up @@ -392,6 +410,13 @@ def http_headers(self) -> Dict[str, str]:
for name, value in self._client_session.properties.items()
)

if len(self._client_session.prepared_statements) != 0:
# ``name`` must not contain ``=``
headers[constants.HEADER_PREPARED_STATEMENT] = ",".join(
"{}={}".format(name, urllib.parse.quote_plus(statement))
for name, statement in self._client_session.prepared_statements.items()
)

# merge custom http headers
for key in self._client_session.headers:
if key in headers.keys():
Expand Down Expand Up @@ -556,6 +581,18 @@ def process(self, http_response) -> TrinoStatus:
if constants.HEADER_SET_ROLE in http_response.headers:
self._client_session.role = http_response.headers[constants.HEADER_SET_ROLE]

if constants.HEADER_ADDED_PREPARE in http_response.headers:
for name, statement in get_prepared_statement_values(
http_response.headers, constants.HEADER_ADDED_PREPARE
):
self._client_session.prepared_statements[name] = statement

if constants.HEADER_DEALLOCATED_PREPARE in http_response.headers:
for name in get_header_values(
http_response.headers, constants.HEADER_DEALLOCATED_PREPARE
):
self._client_session.prepared_statements.pop(name)

self._next_uri = response.get("nextUri")

return TrinoStatus(
Expand Down Expand Up @@ -622,10 +659,6 @@ def __iter__(self):

self._rows = next_rows

@property
def response_headers(self):
return self._query.response_headers


class TrinoQuery(object):
"""Represent the execution of a SQL statement by Trino."""
Expand All @@ -648,7 +681,6 @@ def __init__(
self._update_type = None
self._sql = sql
self._result: Optional[TrinoResult] = None
self._response_headers = None
self._experimental_python_types = experimental_python_types
self._row_mapper: Optional[RowMapper] = None

Expand Down Expand Up @@ -705,7 +737,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
self._result = TrinoResult(self, rows)

# Execute should block until at least one row is received
# Execute should block until at least one row is received or query is finished or cancelled
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
self._result.rows += self.fetch()
return self._result
Expand All @@ -725,7 +757,6 @@ def fetch(self) -> List[List[Any]]:
status = self._request.process(response)
self._update_state(status)
logger.debug(status)
self._response_headers = response.headers
if status.next_uri is None:
self._finished = True

Expand Down Expand Up @@ -763,10 +794,6 @@ def finished(self) -> bool:
def cancelled(self) -> bool:
return self._cancelled

@property
def response_headers(self):
return self._response_headers


def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts):
def wrapper(func):
Expand Down
85 changes: 17 additions & 68 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from decimal import Decimal
from typing import Any, List, Optional # NOQA for mypy types

import copy
import uuid
import datetime
import math
Expand Down Expand Up @@ -301,52 +300,25 @@ def setinputsizes(self, sizes):
def setoutputsize(self, size, column):
raise trino.exceptions.NotSupportedError

def _prepare_statement(self, operation, statement_name):
def _prepare_statement(self, statement: str, name: str) -> None:
"""
Prepends the given `operation` with "PREPARE <statement_name> FROM" and
executes as a prepare statement.
Registers a prepared statement for the provided `operation` with the
`name` assigned to it.
:param operation: sql to be executed.
:param statement_name: name that will be assigned to the prepare
statement.
:raises trino.exceptions.FailedToObtainAddedPrepareHeader: Error raised
when unable to find the 'X-Trino-Added-Prepare' for the PREPARE
statement request.
:return: string representing the value of the 'X-Trino-Added-Prepare'
header.
:param statement: sql to be executed.
:param name: name that will be assigned to the prepared statement.
"""
sql = 'PREPARE {statement_name} FROM {operation}'.format(
statement_name=statement_name,
operation=operation
)

# Send prepare statement. Copy the _request object to avoid polluting the
# one that is going to be used to execute the actual operation.
query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql,
sql = f"PREPARE {name} FROM {statement}"
query = trino.client.TrinoQuery(self.connection._create_request(), sql=sql,
experimental_python_types=self._experimental_pyton_types)
result = query.execute()
query.execute()

# Iterate until the 'X-Trino-Added-Prepare' header is found or
# until there are no more results
for _ in result:
response_headers = result.response_headers

if constants.HEADER_ADDED_PREPARE in response_headers:
return response_headers[constants.HEADER_ADDED_PREPARE]

raise trino.exceptions.FailedToObtainAddedPrepareHeader

def _get_added_prepare_statement_trino_query(
def _execute_prepared_statement(
self,
statement_name,
params
):
sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params))

# No need to deepcopy _request here because this is the actual request
# operation
return trino.client.TrinoQuery(self._request, sql=sql, experimental_python_types=self._experimental_pyton_types)

def _format_prepared_param(self, param):
Expand Down Expand Up @@ -422,28 +394,11 @@ def _format_prepared_param(self, param):

raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param))

def _deallocate_prepare_statement(self, added_prepare_header, statement_name):
def _deallocate_prepared_statement(self, statement_name: str) -> None:
sql = 'DEALLOCATE PREPARE ' + statement_name

# Send deallocate statement. Copy the _request object to avoid poluting the
# one that is going to be used to execute the actual operation.
query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql,
query = trino.client.TrinoQuery(self.connection._create_request(), sql=sql,
experimental_python_types=self._experimental_pyton_types)
result = query.execute(
additional_http_headers={
constants.HEADER_PREPARED_STATEMENT: added_prepare_header
}
)

# Iterate until the 'X-Trino-Deallocated-Prepare' header is found or
# until there are no more results
for _ in result:
response_headers = result.response_headers

if constants.HEADER_DEALLOCATED_PREPARE in response_headers:
return response_headers[constants.HEADER_DEALLOCATED_PREPARE]

raise trino.exceptions.FailedToObtainDeallocatedPrepareHeader
query.execute()

def _generate_unique_statement_name(self):
return 'st_' + uuid.uuid4().hex.replace('-', '')
Expand All @@ -456,27 +411,21 @@ def execute(self, operation, params=None):
)

statement_name = self._generate_unique_statement_name()
# Send prepare statement
added_prepare_header = self._prepare_statement(
operation, statement_name
)
self._prepare_statement(operation, statement_name)

try:
# Send execute statement and assign the return value to `results`
# as it will be returned by the function
self._query = self._get_added_prepare_statement_trino_query(
self._query = self._execute_prepared_statement(
statement_name, params
)
result = self._query.execute(
additional_http_headers={
constants.HEADER_PREPARED_STATEMENT: added_prepare_header
}
)
result = self._query.execute()
finally:
# Send deallocate statement
# At this point the query can be deallocated since it has already
# been executed
self._deallocate_prepare_statement(added_prepare_header, statement_name)
# TODO: Consider caching prepared statements if requested by caller
self._deallocate_prepared_statement(statement_name)

else:
self._query = trino.client.TrinoQuery(self._request, sql=operation,
Expand Down
16 changes: 0 additions & 16 deletions trino/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,22 +134,6 @@ class TrinoUserError(TrinoQueryError, ProgrammingError):
pass


class FailedToObtainAddedPrepareHeader(Error):
"""
Raise this exception when unable to find the 'X-Trino-Added-Prepare'
header in the response of a PREPARE statement request.
"""
pass


class FailedToObtainDeallocatedPrepareHeader(Error):
"""
Raise this exception when unable to find the 'X-Trino-Deallocated-Prepare'
header in the response of a DEALLOCATED statement request.
"""
pass


# client module errors
class HttpError(Exception):
pass
Expand Down

0 comments on commit d5f779b

Please sign in to comment.