From d437b6f402127e3d2c1e125b79ea7c98b806375d Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Thu, 25 Aug 2022 17:31:54 -0700
Subject: [PATCH 1/2] Explicitly close cursors.

---
 dbt/adapters/databricks/connections.py | 127 ++++++++++++------
 dbt/adapters/databricks/impl.py        |  35 ++++-
 .../macros/materializations/seed.sql   |   2 +-
 tests/integration/base.py              |   2 +-
 4 files changed, 122 insertions(+), 44 deletions(-)

diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py
index 82661b9f9..fea88ca89 100644
--- a/dbt/adapters/databricks/connections.py
+++ b/dbt/adapters/databricks/connections.py
@@ -1,3 +1,4 @@
+import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
 import itertools
@@ -22,7 +23,8 @@
 import dbt.exceptions
 from dbt.adapters.base import Credentials
 from dbt.adapters.databricks import __version__
-from dbt.contracts.connection import Connection, ConnectionState
+from dbt.clients import agate_helper
+from dbt.contracts.connection import AdapterResponse, Connection, ConnectionState
 from dbt.events import AdapterLogger
 from dbt.events.functions import fire_event
 from dbt.events.types import ConnectionUsed, SQLQuery, SQLQueryStatus
@@ -141,46 +143,53 @@ def _connection_keys(self, *, with_aliases: bool = False) -> Tuple[str, ...]:
         return tuple(connection_keys)
 
 
-class DatabricksSQLConnectionWrapper(object):
+class DatabricksSQLConnectionWrapper:
     """Wrap a Databricks SQL connector in a way that no-ops transactions"""
 
     _conn: DatabricksSQLConnection
-    _cursor: Optional[DatabricksSQLCursor]
 
     def __init__(self, conn: DatabricksSQLConnection):
         self._conn = conn
-        self._cursor = None
 
-    def cursor(self) -> "DatabricksSQLConnectionWrapper":
-        self._cursor = self._conn.cursor()
-        return self
-
-    def cancel(self) -> None:
-        if self._cursor:
-            try:
-                self._cursor.cancel()
-            except DBSQLError as exc:
-                logger.debug("Exception while cancelling query: {}".format(exc))
-                _log_dbsql_errors(exc)
+    def cursor(self) -> "DatabricksSQLCursorWrapper":
+        return DatabricksSQLCursorWrapper(self._conn.cursor())
 
     def close(self) -> None:
-        if self._cursor:
-            try:
-                self._cursor.close()
-            except DBSQLError as exc:
-                logger.debug("Exception while closing cursor: {}".format(exc))
-                _log_dbsql_errors(exc)
         self._conn.close()
 
     def rollback(self, *args: Any, **kwargs: Any) -> None:
         logger.debug("NotImplemented: rollback")
 
+
+class DatabricksSQLCursorWrapper:
+    """Wrap a Databricks SQL cursor in a way that no-ops transactions"""
+
+    _cursor: DatabricksSQLCursor
+
+    def __init__(self, cursor: DatabricksSQLCursor):
+        self._cursor = cursor
+
+    def cancel(self) -> None:
+        try:
+            self._cursor.cancel()
+        except DBSQLError as exc:
+            logger.debug("Exception while cancelling query: {}".format(exc))
+            _log_dbsql_errors(exc)
+
+    def close(self) -> None:
+        try:
+            self._cursor.close()
+        except DBSQLError as exc:
+            logger.debug("Exception while closing cursor: {}".format(exc))
+            _log_dbsql_errors(exc)
+
     def fetchall(self) -> Sequence[Tuple]:
-        assert self._cursor is not None
         return self._cursor.fetchall()
 
+    def fetchone(self) -> Optional[Tuple]:
+        return self._cursor.fetchone()
+
     def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None:
-        assert self._cursor is not None
         if sql.strip().endswith(";"):
             sql = sql.strip()[:-1]
         if bindings is not None:
@@ -210,21 +219,23 @@ def description(
             Optional[bool],
         ]
     ]:
-        assert self._cursor is not None
         return self._cursor.description
 
     def schemas(self, catalog_name: str, schema_name: Optional[str] = None) -> None:
-        assert self._cursor is not None
         self._cursor.schemas(catalog_name=catalog_name, schema_name=schema_name)
 
+    def __del__(self) -> None:
+        if self._cursor.open:
+            # This should not happen. The cursor should explicitly be closed.
+            self._cursor.close()
+            with warnings.catch_warnings():
+                warnings.simplefilter("always")
+                warnings.warn("The cursor was closed by destructor.")
+
 
 class DatabricksConnectionManager(SparkConnectionManager):
     TYPE: ClassVar[str] = "databricks"
 
-    DROP_JAVA_STACKTRACE_REGEX: ClassVar["re.Pattern[str]"] = re.compile(
-        r"(?<=Caused by: )(.+?)(?=^\t?at )", re.DOTALL | re.MULTILINE
-    )
-
     @contextmanager
     def exception_handler(self, sql: str) -> Iterator[None]:
         try:
@@ -248,28 +259,62 @@ def exception_handler(self, sql: str) -> Iterator[None]:
         else:
             raise dbt.exceptions.RuntimeException(str(exc)) from exc
 
+    def add_query(
+        self,
+        sql: str,
+        auto_begin: bool = True,
+        bindings: Optional[Any] = None,
+        abridge_sql_log: bool = False,
+        *,
+        close_cursor: bool = False,
+    ) -> Tuple[Connection, Any]:
+        conn, cursor = super().add_query(sql, auto_begin, bindings, abridge_sql_log)
+        if close_cursor and hasattr(cursor, "close"):
+            cursor.close()
+        return conn, cursor
+
+    def execute(
+        self, sql: str, auto_begin: bool = False, fetch: bool = False
+    ) -> Tuple[AdapterResponse, Table]:
+        sql = self._add_query_comment(sql)
+        _, cursor = self.add_query(sql, auto_begin)
+        try:
+            response = self.get_response(cursor)
+            if fetch:
+                table = self.get_result_from_cursor(cursor)
+            else:
+                table = agate_helper.empty_table()
+            return response, table
+        finally:
+            cursor.close()
+
     def _execute_cursor(
-        self, log_sql: str, f: Callable[[DatabricksSQLConnectionWrapper], None]
+        self, log_sql: str, f: Callable[[DatabricksSQLCursorWrapper], None]
     ) -> Table:
         connection = self.get_thread_connection()
 
         fire_event(ConnectionUsed(conn_type=self.TYPE, conn_name=connection.name))
 
-        with self.exception_handler(log_sql):
-            fire_event(SQLQuery(conn_name=connection.name, sql=log_sql))
-            pre = time.time()
+        cursor: Optional[DatabricksSQLCursorWrapper] = None
+        try:
+            with self.exception_handler(log_sql):
+                fire_event(SQLQuery(conn_name=connection.name, sql=log_sql))
+                pre = time.time()
 
-            handle: DatabricksSQLConnectionWrapper = connection.handle
-            cursor = handle.cursor()
-            f(cursor)
+                handle: DatabricksSQLConnectionWrapper = connection.handle
+                cursor = handle.cursor()
+                f(cursor)
 
-            fire_event(
-                SQLQueryStatus(
-                    status=str(self.get_response(cursor)), elapsed=round((time.time() - pre), 2)
+                fire_event(
+                    SQLQueryStatus(
+                        status=str(self.get_response(cursor)), elapsed=round((time.time() - pre), 2)
+                    )
                 )
-            )
 
-            return self.get_result_from_cursor(cursor)
+                return self.get_result_from_cursor(cursor)
+        finally:
+            if cursor is not None:
+                cursor.close()
 
     def list_schemas(self, database: str, schema: Optional[str] = None) -> Table:
         return self._execute_cursor(
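The connections.py changes above replace connection-level cursor tracking with a one-cursor-per-wrapper model: `cursor()` now hands out a fresh `DatabricksSQLCursorWrapper`, every call site closes it explicitly, and `__del__` acts only as a warning fallback. A minimal sketch of that caller-side contract, assuming only the wrapper classes defined in this diff (`fetch_all_rows` is an illustrative helper, not part of the patch):

```python
from typing import Optional, Sequence, Tuple


def fetch_all_rows(handle: "DatabricksSQLConnectionWrapper", sql: str) -> Sequence[Tuple]:
    # The caller owns the cursor: acquire it, use it, and close it in a
    # finally block so it is released even when execute() raises.
    cursor: Optional["DatabricksSQLCursorWrapper"] = None
    try:
        cursor = handle.cursor()  # a fresh wrapper per cursor, not per connection
        cursor.execute(sql)
        return cursor.fetchall()
    finally:
        if cursor is not None:
            cursor.close()  # explicit close; __del__ merely warns if this is skipped
```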
diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py
index 490de8838..ae34235b9 100644
--- a/dbt/adapters/databricks/impl.py
+++ b/dbt/adapters/databricks/impl.py
@@ -16,7 +16,7 @@
     LIST_RELATIONS_MACRO_NAME,
     LIST_SCHEMAS_MACRO_NAME,
 )
-from dbt.contracts.connection import AdapterResponse
+from dbt.contracts.connection import AdapterResponse, Connection
 from dbt.contracts.graph.manifest import Manifest
 from dbt.contracts.relation import RelationType
 import dbt.exceptions
@@ -211,6 +211,39 @@ def _get_columns_for_catalog(self, relation: DatabricksRelation) -> Iterable[Dic
         as_dict["column_type"] = as_dict.pop("dtype")
         yield as_dict
 
+    def add_query(
+        self,
+        sql: str,
+        auto_begin: bool = True,
+        bindings: Optional[Any] = None,
+        abridge_sql_log: bool = False,
+        *,
+        close_cursor: bool = False,
+    ) -> Tuple[Connection, Any]:
+        return self.connections.add_query(
+            sql, auto_begin, bindings, abridge_sql_log, close_cursor=close_cursor
+        )
+
+    def run_sql_for_tests(
+        self, sql: str, fetch: str, conn: Connection
+    ) -> Optional[Union[Optional[Tuple], List[Tuple]]]:
+        cursor = conn.handle.cursor()
+        try:
+            cursor.execute(sql)
+            if fetch == "one":
+                return cursor.fetchone()
+            elif fetch == "all":
+                return cursor.fetchall()
+            else:
+                return None
+        except BaseException as e:
+            print(sql)
+            print(e)
+            raise
+        finally:
+            cursor.close()
+            conn.transaction_open = False
+
     def valid_incremental_strategies(self) -> List[str]:
         return ["append", "merge", "insert_overwrite"]
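The impl.py additions surface the same discipline at the adapter level: `add_query` grows a keyword-only `close_cursor` flag for fire-and-forget statements, and `run_sql_for_tests` closes its cursor in a `finally` block. A hedged usage sketch — `adapter` stands in for a configured `DatabricksAdapter` instance, which is an assumption of this example, not something the patch constructs:

```python
from contextlib import closing
from typing import Any, Optional, Tuple


def run_statement(adapter: Any, sql: str) -> None:
    # Fire-and-forget: with close_cursor=True the adapter disposes of the
    # cursor as soon as the statement has executed.
    adapter.add_query(sql, close_cursor=True)


def fetch_first_row(adapter: Any, sql: str) -> Optional[Tuple]:
    # When results are needed, contextlib.closing gives the same guarantee
    # as the explicit try/finally blocks used throughout the patch.
    conn = adapter.connections.get_thread_connection()
    with closing(conn.handle.cursor()) as cursor:
        cursor.execute(sql)
        return cursor.fetchone()
```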
diff --git a/dbt/include/databricks/macros/materializations/seed.sql b/dbt/include/databricks/macros/materializations/seed.sql
index 8a8310c33..f554b2a55 100644
--- a/dbt/include/databricks/macros/materializations/seed.sql
+++ b/dbt/include/databricks/macros/materializations/seed.sql
@@ -29,7 +29,7 @@
     {%- endfor %}
     {% endset %}
 
-    {% do adapter.add_query(sql, bindings=bindings, abridge_sql_log=True) %}
+    {% do adapter.add_query(sql, bindings=bindings, abridge_sql_log=True, close_cursor=True) %}
 
     {% if loop.index0 == 0 %}
         {% do statements.append(sql) %}
diff --git a/tests/integration/base.py b/tests/integration/base.py
index 507ea443d..7e1bafff4 100644
--- a/tests/integration/base.py
+++ b/tests/integration/base.py
@@ -478,7 +478,7 @@ def run_sql(self, query, fetch="None", kwargs=None, connection_name=None):
         try:
             cursor.execute(sql)
             if fetch == "one":
-                return cursor.fetchall()[0]
+                return cursor.fetchone()
             elif fetch == "all":
                 return cursor.fetchall()
             else:

From 02947ad475cf5ac8031c8d9b698ea0ba39b0f5fb Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Thu, 25 Aug 2022 17:35:45 -0700
Subject: [PATCH 2/2] changelog.

---
 CHANGELOG.md | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b6485742b..5bbf74fb6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,11 @@
 - Apply "Initial refactoring of incremental materialization" ([#148](https://github.com/databricks/dbt-databricks/pull/148))
   - Now dbt-databricks uses `adapter.get_incremental_strategy_macro` instead of `dbt_spark_get_incremental_sql` macro to dispatch the incremental strategy macro. The overwritten `dbt_spark_get_incremental_sql` macro will not work anymore.
 
+## dbt-databricks 1.2.2 (Release TBD)
+
+### Under the hood
+- Explicitly close cursors ([#163](https://github.com/databricks/dbt-databricks/pull/163))
+
 ## dbt-databricks 1.2.1 (August 24, 2022)
 
 ### Features
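For context on the `__del__` safety net added in patch 1/2: if a wrapper is garbage-collected while its cursor is still open, the destructor closes it and emits a warning instead of silently leaking the server-side resource. A self-contained sketch of that control flow, using stand-in classes rather than the real connector:

```python
import warnings


class FakeCursor:
    """Stand-in for the Databricks SQL cursor; only `open` and `close` matter here."""

    def __init__(self) -> None:
        self.open = True

    def close(self) -> None:
        self.open = False


class CursorWrapper:
    """Mirrors the destructor logic of DatabricksSQLCursorWrapper."""

    def __init__(self, cursor: FakeCursor) -> None:
        self._cursor = cursor

    def __del__(self) -> None:
        if self._cursor.open:
            # Close late, but warn loudly so the leak is visible in logs.
            self._cursor.close()
            with warnings.catch_warnings():
                warnings.simplefilter("always")
                warnings.warn("The cursor was closed by destructor.")


wrapper = CursorWrapper(FakeCursor())
del wrapper  # in CPython the refcount drops to zero, so __del__ fires and warns
```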