Skip to content

Commit

Permalink
Acknowledge reception of data in TrinoResult
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed Sep 1, 2022
1 parent cffd2b2 commit 9d898a8
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 16 deletions.
22 changes: 10 additions & 12 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,28 +594,26 @@ class TrinoResult(object):

def __init__(self, query, rows=None):
self._query = query
self._rows = rows or []
# Initial rows from the first POST request
self._rows = rows
self._rownumber = 0

@property
def rownumber(self) -> int:
return self._rownumber

def __iter__(self):
# Initial fetch from the first POST request
for row in self._rows:
self._rownumber += 1
yield row
self._rows = None

# Subsequent fetches from GET requests until next_uri is empty.
while not self._query.finished:
rows = self._query.fetch()
for row in rows:
# A query only transitions to a FINISHED state when the results are fully consumed:
# The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
while not self._query.finished or self._rows is not None:
next_rows = self._query.fetch() if not self._query.finished else None
for row in self._rows:
self._rownumber += 1
logger.debug("row %s", row)
yield row

self._rows = next_rows

@property
def response_headers(self):
return self._query.response_headers
Expand All @@ -641,7 +639,7 @@ def __init__(
self._request = request
self._update_type = None
self._sql = sql
self._result = TrinoResult(self)
self._result: Optional[TrinoResult] = None
self._response_headers = None
self._experimental_python_types = experimental_python_types
self._row_mapper: Optional[RowMapper] = None
Expand Down
2 changes: 1 addition & 1 deletion trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def _prepare_statement(self, operation, statement_name):
operation=operation
)

# Send prepare statement. Copy the _request object to avoid poluting the
# Send prepare statement. Copy the _request object to avoid polluting the
# one that is going to be used to execute the actual operation.
query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql,
experimental_python_types=self._experimental_pyton_types)
Expand Down
6 changes: 3 additions & 3 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st
"""
).strip()
res = connection.execute(sql.text(query), schema=schema, view=view_name)
return res.scalar()
return res.scalar_one_or_none()

def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
if not self.has_table(connection, table_name, schema):
Expand Down Expand Up @@ -284,7 +284,7 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str
sql.text(query),
catalog_name=catalog_name, schema_name=schema_name, table_name=table_name
)
return dict(text=res.scalar())
return dict(text=res.scalar_one_or_none())
except error.TrinoQueryError as e:
if e.error_name in (
error.PERMISSION_DENIED,
Expand Down Expand Up @@ -326,7 +326,7 @@ def _get_server_version_info(self, connection: Connection) -> Any:
query = "SELECT version()"
try:
res = connection.execute(sql.text(query))
version = res.scalar()
version = res.scalar_one()
return tuple([version])
except exc.ProgrammingError as e:
logger.debug(f"Failed to get server version: {e.orig.message}")
Expand Down

0 comments on commit 9d898a8

Please sign in to comment.