Skip to content

Commit

Permalink
Add lazy evaluation of server_version_info
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed May 8, 2023
1 parent 743f24a commit 19ad6fd
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
19 changes: 19 additions & 0 deletions tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,3 +497,22 @@ def test_get_view_names_raises(trino_connection):

with pytest.raises(sqla.exc.NoSuchTableError):
sqla.inspect(engine).get_view_names(None)


@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
def test_version_is_lazy(trino_connection):
_, conn = trino_connection
result = conn.execute("SELECT 1")
result.fetchall()
num_queries = num_queries_containing_string(conn, "SELECT version()")
assert num_queries == 0
version_info = conn.dialect.server_version_info
assert isinstance(version_info, tuple)
num_queries = num_queries_containing_string(conn, "SELECT version()")
assert num_queries == 1


def num_queries_containing_string(connection, query_string, num=3):
result = connection.execute("select query from system.runtime.queries order by query_id desc limit ?", num)
rows = result.fetchall()
return len(list(filter(lambda rec: query_string in rec[0], rows)))
23 changes: 14 additions & 9 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,15 +336,20 @@ def has_sequence(self, connection: Connection, sequence_name: str, schema: str =
"""Trino has no support for sequence. Returns False indicate that given sequence does not exists."""
return False

def _get_server_version_info(self, connection: Connection) -> Any:
query = "SELECT version()"
try:
res = connection.execute(sql.text(query))
version = res.scalar()
return tuple([version])
except exc.ProgrammingError as e:
logger.debug(f"Failed to get server version: {e.orig.message}")
return None
@classmethod
def _get_server_version_info(cls, connection: Connection) -> Any:
def get_server_version_info(_):
query = "SELECT version()"
try:
res = connection.execute(sql.text(query))
version = res.scalar()
return tuple([version])
except exc.ProgrammingError as e:
logger.debug(f"Failed to get server version: {e.orig.message}")
return None

# We make the server_version_info be evaluated lazily
cls.server_version_info = property(get_server_version_info, lambda instance, value: None)

def _raw_connection(self, connection: Union[Engine, Connection]) -> trino_dbapi.Connection:
if isinstance(connection, Engine):
Expand Down

0 comments on commit 19ad6fd

Please sign in to comment.