From ea306c9462615d6b215d43f7f17d68f4c62951b1 Mon Sep 17 00:00:00 2001
From: Jarek Potiuk
Date: Thu, 24 Nov 2022 10:38:29 +0100
Subject: [PATCH] Fix errors in Databricks SQL operator introduced when
 refactoring (#27854)

When SQLExecuteQueryOperator was introduced in #25717, it brought several
errors into the Databricks SQL operator:

* The schema (description) parameter was passed to _process_output as part
  of the Hook's output.
* The run() method of DatabricksHook did not conform to the run() methods
  of the other Hooks - it returned a tuple of description and results.
* The signature of _process_output was not specified in DBApiHook - when
  scalar results were requested, the method received a different output
  shape than otherwise.

This PR fixes that:

* The Databricks Hook now conforms to the other DBApiHooks in the value
  returned by run() (backwards incompatible, so the major version of the
  provider must be bumped).
* The DBApiHook now has a "last_description" field. This makes the hook
  stateful, but the state only reflects the description from the last run()
  call and is not a problem to keep. Since this is a new feature, it implies
  version 1.4 of the common-sql provider.
* The DBApiHook now has a "scalar_return_last" field that indicates whether
  scalar output was requested.
* Python DB-API's "description" is now properly named - previously it was
  called "schema", which clashed with the "schema" parameter passed to hook
  initialisation (the actual database schema).
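
The reworked hook contract, sketched below for illustration only (the
connection id and query are made-up examples, not part of this change):

    from airflow.providers.common.sql.hooks.sql import fetch_all_handler
    from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

    hook = DatabricksSqlHook(databricks_conn_id="databricks_default")
    # run() now returns plain results, like every other DbApiHook...
    rows = hook.run(sql="select * from test.test", handler=fetch_all_handler)
    # ...while the cursor description is kept on the hook instead of being
    # interleaved with the results:
    if hook.last_description is not None:
        column_names = [field[0] for field in hook.last_description]
    # scalar_return_last records whether sql was passed as a single string
    # with return_last=True, i.e. whether run() returned a scalar result.
    assert hook.scalar_return_last is True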
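
A subclass that needs custom output handling now overrides _process_output()
with the signature introduced here (MyCustomSqlOperator is a hypothetical
example, not part of this change):

    from __future__ import annotations

    from typing import Any, Sequence

    from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

    class MyCustomSqlOperator(SQLExecuteQueryOperator):
        def _process_output(
            self, results: Any | list[Any], description: Sequence[Sequence] | None, scalar_results: bool
        ) -> Any:
            # With scalar_results=True, `results` is the single result of the
            # one statement that ran; otherwise a list of per-statement results.
            list_results = [results] if scalar_results else results
            if description is not None:
                self.log.info("Result columns: %s", [field[0] for field in description])
            # Whatever is returned here becomes the execute() return value
            # (and the XCom value when do_xcom_push=True).
            return list_results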
""" - scalar_return_last = isinstance(sql, str) and return_last + self.scalar_return_last = isinstance(sql, str) and return_last if isinstance(sql, str): if split_statements: sql = self.split_sql_string(sql) @@ -268,6 +270,7 @@ def run( if handler is not None: result = handler(cur) results.append(result) + self.last_description = cur.description # If autocommit was set to False or db does not support autocommit, we do a manual commit. if not self.get_autocommit(conn): @@ -275,7 +278,7 @@ def run( if handler is None: return None - elif scalar_return_last: + elif self.scalar_return_last: return results[-1] else: return results diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py index ffd3a6fcf9cbe..314af43003488 100644 --- a/airflow/providers/common/sql/operators/sql.py +++ b/airflow/providers/common/sql/operators/sql.py @@ -19,13 +19,14 @@ import ast import re -from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, NoReturn, Sequence, SupportsAbs +from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, NoReturn, Sequence, SupportsAbs, overload from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException, AirflowFailException from airflow.hooks.base import BaseHook from airflow.models import BaseOperator, SkipMixin from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler +from airflow.typing_compat import Literal if TYPE_CHECKING: from airflow.utils.context import Context @@ -224,6 +225,33 @@ def __init__( self.split_statements = split_statements self.return_last = return_last + @overload + def _process_output( + self, results: Any, description: Sequence[Sequence] | None, scalar_results: Literal[True] + ) -> Any: + pass + + @overload + def _process_output( + self, results: list[Any], description: Sequence[Sequence] | None, scalar_results: Literal[False] + ) -> Any: + pass + + def _process_output( + self, results: Any | list[Any], description: Sequence[Sequence] | None, scalar_results: bool + ) -> Any: + """ + Can be overridden by the subclass in case some extra processing is needed. + The "process_output" method can override the returned output - augmenting or processing the + output as needed - the output returned will be returned as execute return value and if + do_xcom_push is set to True, it will be set as XCom returned + + :param results: results in the form of list of rows. 
+        :param description: as returned by ``cur.description`` in the Python DBAPI
+        :param scalar_results: True if the result is a single scalar value rather than a list of rows
+        """
+        return results
+
     def execute(self, context):
         self.log.info("Executing: %s", self.sql)
         hook = self.get_db_hook()
@@ -244,11 +272,7 @@ def execute(self, context):
             split_statements=self.split_statements,
         )
 
-        if hasattr(self, "_process_output"):
-            for out in output:
-                self._process_output(*out)
-
-        return output
+        return self._process_output(output, hook.last_description, hook.scalar_return_last)
 
     def prepare_template(self) -> None:
         """Parse template file for attribute parameters."""
diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py
index 5e456a9ca5f3f..f042435943889 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -147,7 +147,7 @@ def run(
         handler: Callable | None = None,
         split_statements: bool = True,
         return_last: bool = True,
-    ) -> tuple[str, Any] | list[tuple[str, Any]] | None:
+    ) -> Any | list[Any] | None:
         """
         Runs a command or a list of commands. Pass a list of sql
         statements to the sql parameter to get them to execute
@@ -163,7 +163,7 @@ def run(
         :param return_last: Whether to return result for only last statement or for all after split
         :return: return only result of the LAST SQL expression if handler was provided.
         """
-        scalar_return_last = isinstance(sql, str) and return_last
+        self.scalar_return_last = isinstance(sql, str) and return_last
         if isinstance(sql, str):
             if split_statements:
                 sql = self.split_sql_string(sql)
@@ -186,14 +186,14 @@ def run(
 
                     if handler is not None:
                         result = handler(cur)
-                        schema = cur.description
-                        results.append((schema, result))
+                        results.append(result)
+                        self.last_description = cur.description
 
         self._sql_conn = None
 
         if handler is None:
             return None
-        elif scalar_return_last:
+        elif self.scalar_return_last:
             return results[-1]
         else:
             return results
diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py
index d808377cb1ae2..379b0fd2c9bbb 100644
--- a/airflow/providers/databricks/operators/databricks_sql.py
+++ b/airflow/providers/databricks/operators/databricks_sql.py
@@ -120,12 +120,21 @@ def get_db_hook(self) -> DatabricksSqlHook:
         }
         return DatabricksSqlHook(self.databricks_conn_id, **hook_params)
 
-    def _process_output(self, schema, results):
+    def _process_output(
+        self, results: Any | list[Any], description: Sequence[Sequence] | None, scalar_results: bool
+    ) -> Any:
         if not self._output_path:
-            return
+            return description, results
         if not self._output_format:
             raise AirflowException("Output format should be specified!")
-        field_names = [field[0] for field in schema]
+        if description is None:
+            self.log.warning("Description of the cursor is missing. Will not process the output")
+            return description, results
+        field_names = [field[0] for field in description]
+        if scalar_results:
+            list_results: list[Any] = [results]
+        else:
+            list_results = results
         if self._output_format.lower() == "csv":
             with open(self._output_path, "w", newline="") as file:
                 if self._csv_params:
@@ -138,18 +147,19 @@ def _process_output(self, schema, results):
                     writer = csv.DictWriter(file, fieldnames=field_names, **csv_params)
                     if write_header:
                         writer.writeheader()
-                    for row in results:
+                    for row in list_results:
                         writer.writerow(row.asDict())
         elif self._output_format.lower() == "json":
             with open(self._output_path, "w") as file:
-                file.write(json.dumps([row.asDict() for row in results]))
+                file.write(json.dumps([row.asDict() for row in list_results]))
         elif self._output_format.lower() == "jsonl":
             with open(self._output_path, "w") as file:
-                for row in results:
+                for row in list_results:
                     file.write(json.dumps(row.asDict()))
                     file.write("\n")
         else:
             raise AirflowException(f"Unsupported output format: '{self._output_format}'")
+        return description, results
 
 
 COPY_INTO_APPROVED_FORMATS = ["CSV", "JSON", "AVRO", "ORC", "PARQUET", "TEXT", "BINARYFILE"]
diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py
index 3b45e8f2f2eec..49289df37a314 100644
--- a/airflow/providers/exasol/hooks/exasol.py
+++ b/airflow/providers/exasol/hooks/exasol.py
@@ -157,7 +157,7 @@ def run(
         :param return_last: Whether to return result for only last statement or for all after split
         :return: return only result of the LAST SQL expression if handler was provided.
         """
-        scalar_return_last = isinstance(sql, str) and return_last
+        self.scalar_return_last = isinstance(sql, str) and return_last
         if isinstance(sql, str):
             if split_statements:
                 sql = self.split_sql_string(sql)
@@ -187,7 +187,7 @@ def run(
 
         if handler is None:
             return None
-        elif scalar_return_last:
+        elif self.scalar_return_last:
             return results[-1]
         else:
             return results
diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py
index 138025a455652..e525efe763439 100644
--- a/airflow/providers/snowflake/hooks/snowflake.py
+++ b/airflow/providers/snowflake/hooks/snowflake.py
@@ -350,7 +350,7 @@ def run(
         """
         self.query_ids = []
 
-        scalar_return_last = isinstance(sql, str) and return_last
+        self.scalar_return_last = isinstance(sql, str) and return_last
         if isinstance(sql, str):
             if split_statements:
                 split_statements_tuple = util_text.split_statements(StringIO(sql))
@@ -387,7 +387,7 @@ def run(
 
         if handler is None:
             return None
-        elif scalar_return_last:
+        elif self.scalar_return_last:
             return results[-1]
         else:
             return results
diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py b/tests/providers/databricks/hooks/test_databricks_sql.py
index bd52a64c98e40..7f1aaf493a4f0 100644
--- a/tests/providers/databricks/hooks/test_databricks_sql.py
+++ b/tests/providers/databricks/hooks/test_databricks_sql.py
@@ -1,4 +1,3 @@
-#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements. See the NOTICE file
 # distributed with this work for additional information
@@ -15,6 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+#
 from __future__ import annotations
 
 import unittest
@@ -70,17 +71,17 @@ def test_query(self, mock_requests, mock_conn):
         type(mock_requests.get.return_value).status_code = status_code_mock
 
         test_fields = ["id", "value"]
-        test_schema = [(field,) for field in test_fields]
+        test_description = [(field,) for field in test_fields]
 
         conn = mock_conn.return_value
-        cur = mock.MagicMock(rowcount=0, description=test_schema)
+        cur = mock.MagicMock(rowcount=0, description=test_description)
         cur.fetchall.return_value = []
         conn.cursor.return_value = cur
 
         query = "select * from test.test;"
-        schema, results = self.hook.run(sql=query, handler=fetch_all_handler)
+        results = self.hook.run(sql=query, handler=fetch_all_handler)
 
-        assert schema == test_schema
+        assert self.hook.last_description == test_description
         assert results == []
         cur.execute.assert_has_calls([mock.call(q) for q in [query]])
diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py
index 0064c0f7f6d5f..9a989dfae368a 100644
--- a/tests/providers/databricks/operators/test_databricks_sql.py
+++ b/tests/providers/databricks/operators/test_databricks_sql.py
@@ -47,13 +47,15 @@ def test_exec_success(self, db_mock_class):
         sql = "select * from dummy"
         op = DatabricksSqlOperator(task_id=TASK_ID, sql=sql, do_xcom_push=True)
         db_mock = db_mock_class.return_value
-        mock_schema = [("id",), ("value",)]
+        mock_description = [("id",), ("value",)]
         mock_results = [Row(id=1, value="value1")]
-        db_mock.run.return_value = [(mock_schema, mock_results)]
+        db_mock.run.return_value = mock_results
+        db_mock.last_description = mock_description
+        db_mock.scalar_return_last = False
 
-        results = op.execute(None)
+        execute_results = op.execute(None)
 
-        assert results[0][1] == mock_results
+        assert execute_results == (mock_description, mock_results)
         db_mock_class.assert_called_once_with(
             DEFAULT_CONN_ID,
             http_path=None,
@@ -82,9 +84,11 @@ def test_exec_write_file(self, db_mock_class):
         tempfile_path = tempfile.mkstemp()[1]
         op = DatabricksSqlOperator(task_id=TASK_ID, sql=sql, output_path=tempfile_path)
         db_mock = db_mock_class.return_value
-        mock_schema = [("id",), ("value",)]
+        mock_description = [("id",), ("value",)]
         mock_results = [Row(id=1, value="value1")]
-        db_mock.run.return_value = [(mock_schema, mock_results)]
+        db_mock.run.return_value = mock_results
+        db_mock.last_description = mock_description
+        db_mock.scalar_return_last = False
 
         try:
             op.execute(None)