From f2ab384fab98c53b09cb643a3fb82525b3c50dc9 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Thu, 5 Jan 2023 13:34:26 +0100 Subject: [PATCH] Handle non-compliant behaviour of Exasol cursor Exasol is not DBAPI-2 compliant, as described in: https://github.com/exasol/pyexasol/blob/master/docs/DBAPI_COMPAT.md This means that SQLOperator cannot be used for it and we should undeprecate the original operator and use specific Exasol handler as default for the returned values. We also add explicit exception in case non-compliant DB2-API database (with missing description) is used with default Operator and handlers. Fixes: #28731 --- airflow/providers/common/sql/hooks/sql.py | 12 ++- airflow/providers/exasol/hooks/exasol.py | 54 ++++++++++-- airflow/providers/exasol/operators/exasol.py | 19 ++-- tests/providers/exasol/hooks/test_sql.py | 88 ++++++++++++++++--- .../providers/exasol/operators/test_exasol.py | 6 +- .../exasol/operators/test_exasol_sql.py | 4 +- 6 files changed, 145 insertions(+), 38 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 76260612a435a..b195d440c7944 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -60,6 +60,11 @@ def return_single_query_results(sql: str | Iterable[str], return_last: bool, spl def fetch_all_handler(cursor) -> list[tuple] | None: """Handler for DbApiHook.run() to return results""" + if not hasattr(cursor, "description"): + raise RuntimeError( + "The database we interact with does not support DBAPI 2.0. Use operator and " + "handlers that are specifically designed for your database." + ) if cursor.description is not None: return cursor.fetchall() else: @@ -67,7 +72,12 @@ def fetch_all_handler(cursor) -> list[tuple] | None: def fetch_one_handler(cursor) -> list[tuple] | None: - """Handler for DbApiHook.run() to return results""" + """Handler for DbApiHook.run() to return first result""" + if not hasattr(cursor, "description"): + raise RuntimeError( + "The database we interact with does not support DBAPI 2.0. Use operator and " + "handlers that are specifically designed for your database." + ) if cursor.description is not None: return cursor.fetchone() else: diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py index 9a89cf5fa9cb3..9530f88dab2ac 100644 --- a/airflow/providers/exasol/hooks/exasol.py +++ b/airflow/providers/exasol/hooks/exasol.py @@ -18,11 +18,11 @@ from __future__ import annotations from contextlib import closing -from typing import Any, Callable, Iterable, Mapping +from typing import Any, Callable, Iterable, Mapping, Sequence import pandas as pd import pyexasol -from pyexasol import ExaConnection +from pyexasol import ExaConnection, ExaStatement from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results @@ -133,6 +133,29 @@ def export_to_file( ) self.log.info("Data saved to %s", filename) + @staticmethod + def get_description(statement: ExaStatement) -> Sequence[Sequence]: + """ + Copied implementation from DB2-API wrapper. + More info https://github.com/exasol/pyexasol/blob/master/docs/DBAPI_COMPAT.md#db-api-20-wrapper + :param statement: Exasol statement + :return: description sequence of t + """ + cols = [] + for k, v in statement.columns().items(): + cols.append( + ( + k, + v.get("type", None), + v.get("size", None), + v.get("size", None), + v.get("precision", None), + v.get("scale", None), + True, + ) + ) + return cols + def run( self, sql: str | Iterable[str], @@ -176,18 +199,17 @@ def run( self.set_autocommit(conn, autocommit) results = [] for sql_statement in sql_list: - with closing(conn.execute(sql_statement, parameters)) as cur: + with closing(conn.execute(sql_statement, parameters)) as exa_statement: self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) if handler is not None: - result = handler(cur) + result = handler(exa_statement) if return_single_query_results(sql, return_last, split_statements): _last_result = result - _last_description = cur.description + _last_columns = self.get_description(exa_statement) else: results.append(result) - self.descriptions.append(cur.description) - - self.log.info("Rows affected: %s", cur.rowcount) + self.descriptions.append(self.get_description(exa_statement)) + self.log.info("Rows affected: %s", exa_statement.rowcount) # If autocommit was set to False or db does not support autocommit, we do a manual commit. if not self.get_autocommit(conn): @@ -196,7 +218,7 @@ def run( if handler is None: return None if return_single_query_results(sql, return_last, split_statements): - self.descriptions = [_last_description] + self.descriptions = [_last_columns] return _last_result else: return results @@ -241,3 +263,17 @@ def _serialize_cell(cell, conn=None) -> Any: :return: The cell """ return cell + + +def exasol_fetch_all_handler(statement: ExaStatement) -> list[tuple] | None: + if statement.result_type == "resultSet": + return statement.fetchall() + else: + return None + + +def exasol_fetch_one_handler(statement: ExaStatement) -> list[tuple] | None: + if statement.result_type == "resultSet": + return statement.fetchone() + else: + return None diff --git a/airflow/providers/exasol/operators/exasol.py b/airflow/providers/exasol/operators/exasol.py index 253e443b8ee6c..e1e6e7dc97ad9 100644 --- a/airflow/providers/exasol/operators/exasol.py +++ b/airflow/providers/exasol/operators/exasol.py @@ -17,10 +17,10 @@ # under the License. from __future__ import annotations -import warnings from typing import Sequence from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator +from airflow.providers.exasol.hooks.exasol import exasol_fetch_all_handler class ExasolOperator(SQLExecuteQueryOperator): @@ -35,6 +35,7 @@ class ExasolOperator(SQLExecuteQueryOperator): (default value: False) :param parameters: (optional) the parameters to render the SQL query with. :param schema: (optional) name of the schema which overwrite defined one in connection + :param handler: (optional) handler to process the results of the query """ template_fields: Sequence[str] = ("sql",) @@ -43,16 +44,14 @@ class ExasolOperator(SQLExecuteQueryOperator): ui_color = "#ededed" def __init__( - self, *, exasol_conn_id: str = "exasol_default", schema: str | None = None, **kwargs + self, + *, + exasol_conn_id: str = "exasol_default", + schema: str | None = None, + handler=exasol_fetch_all_handler, + **kwargs, ) -> None: if schema is not None: hook_params = kwargs.pop("hook_params", {}) kwargs["hook_params"] = {"schema": schema, **hook_params} - - super().__init__(conn_id=exasol_conn_id, **kwargs) - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", - DeprecationWarning, - stacklevel=2, - ) + super().__init__(conn_id=exasol_conn_id, handler=handler, **kwargs) diff --git a/tests/providers/exasol/hooks/test_sql.py b/tests/providers/exasol/hooks/test_sql.py index 17bb33db907a6..465edc61314b6 100644 --- a/tests/providers/exasol/hooks/test_sql.py +++ b/tests/providers/exasol/hooks/test_sql.py @@ -18,6 +18,7 @@ # from __future__ import annotations +from typing import Any from unittest import mock from unittest.mock import MagicMock, patch @@ -57,8 +58,11 @@ def exasol_hook(): return ExasolHook() -def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: - return [(field,) for field in fields] +def get_columns(fields: list[str]) -> dict[str, dict[str, Any]]: + return { + field: {"type": "VARCHAR", "nullable": True, "precision": None, "scale": None, "length": None} + for field in fields + } index = 0 @@ -75,7 +79,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test"], [["id", "value"]], ([[1, 2], [11, 12]],), - [[("id",), ("value",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ] + ], [[1, 2], [11, 12]], id="The return_last set and no split statements set on single query in string", ), @@ -86,7 +95,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test;"], [["id", "value"]], ([[1, 2], [11, 12]],), - [[("id",), ("value",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ] + ], [[1, 2], [11, 12]], id="The return_last not set and no split statements set on single query in string", ), @@ -97,7 +111,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test;"], [["id", "value"]], ([[1, 2], [11, 12]],), - [[("id",), ("value",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ] + ], [[1, 2], [11, 12]], id="The return_last set and split statements set on single query in string", ), @@ -108,7 +127,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test;"], [["id", "value"]], ([[1, 2], [11, 12]],), - [[("id",), ("value",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ] + ], [[[1, 2], [11, 12]]], id="The return_last not set and split statements set on single query in string", ), @@ -119,7 +143,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test;", "select * from test.test2;"], [["id", "value"], ["id2", "value2"]], ([[1, 2], [11, 12]], [[3, 4], [13, 14]]), - [[("id2",), ("value2",)]], + [ + [ + ("id2", "VARCHAR", None, None, None, None, True), + ("value2", "VARCHAR", None, None, None, None, True), + ] + ], [[3, 4], [13, 14]], id="The return_last set and split statements set on multiple queries in string", ), # Failing @@ -130,7 +159,16 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test;", "select * from test.test2;"], [["id", "value"], ["id2", "value2"]], ([[1, 2], [11, 12]], [[3, 4], [13, 14]]), - [[("id",), ("value",)], [("id2",), ("value2",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ], + [ + ("id2", "VARCHAR", None, None, None, None, True), + ("value2", "VARCHAR", None, None, None, None, True), + ], + ], [[[1, 2], [11, 12]], [[3, 4], [13, 14]]], id="The return_last not set and split statements set on multiple queries in string", ), @@ -141,7 +179,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test"], [["id", "value"]], ([[1, 2], [11, 12]],), - [[("id",), ("value",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ] + ], [[[1, 2], [11, 12]]], id="The return_last set on single query in list", ), @@ -152,7 +195,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test"], [["id", "value"]], ([[1, 2], [11, 12]],), - [[("id",), ("value",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ] + ], [[[1, 2], [11, 12]]], id="The return_last not set on single query in list", ), @@ -163,7 +211,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test", "select * from test.test2"], [["id", "value"], ["id2", "value2"]], ([[1, 2], [11, 12]], [[3, 4], [13, 14]]), - [[("id2",), ("value2",)]], + [ + [ + ("id2", "VARCHAR", None, None, None, None, True), + ("value2", "VARCHAR", None, None, None, None, True), + ] + ], [[3, 4], [13, 14]], id="The return_last set set on multiple queries in list", ), @@ -174,7 +227,16 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test", "select * from test.test2"], [["id", "value"], ["id2", "value2"]], ([[1, 2], [11, 12]], [[3, 4], [13, 14]]), - [[("id",), ("value",)], [("id2",), ("value2",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ], + [ + ("id2", "VARCHAR", None, None, None, None, True), + ("value2", "VARCHAR", None, None, None, None, True), + ], + ], [[[1, 2], [11, 12]], [[3, 4], [13, 14]]], id="The return_last not set on multiple queries not set", ), @@ -196,8 +258,8 @@ def test_query( for index in range(len(cursor_descriptions)): cur = mock.MagicMock( rowcount=len(cursor_results[index]), - description=get_cursor_descriptions(cursor_descriptions[index]), ) + cur.columns.return_value = get_columns(cursor_descriptions[index]) cur.fetchall.return_value = cursor_results[index] cursors.append(cur) mock_conn.execute.side_effect = cursors diff --git a/tests/providers/exasol/operators/test_exasol.py b/tests/providers/exasol/operators/test_exasol.py index d893177f21c5d..c805684aa97c0 100644 --- a/tests/providers/exasol/operators/test_exasol.py +++ b/tests/providers/exasol/operators/test_exasol.py @@ -19,7 +19,7 @@ from unittest import mock -from airflow.providers.common.sql.hooks.sql import fetch_all_handler +from airflow.providers.exasol.hooks.exasol import exasol_fetch_all_handler from airflow.providers.exasol.operators.exasol import ExasolOperator @@ -32,7 +32,7 @@ def test_overwrite_autocommit(self, mock_get_db_hook): sql="SELECT 1", autocommit=True, parameters=None, - handler=fetch_all_handler, + handler=exasol_fetch_all_handler, return_last=True, ) @@ -44,7 +44,7 @@ def test_pass_parameters(self, mock_get_db_hook): sql="SELECT {value!s}", autocommit=False, parameters={"value": 1}, - handler=fetch_all_handler, + handler=exasol_fetch_all_handler, return_last=True, ) diff --git a/tests/providers/exasol/operators/test_exasol_sql.py b/tests/providers/exasol/operators/test_exasol_sql.py index 1c2651373b2cf..9ff999eed3a4e 100644 --- a/tests/providers/exasol/operators/test_exasol_sql.py +++ b/tests/providers/exasol/operators/test_exasol_sql.py @@ -22,7 +22,7 @@ import pytest -from airflow.providers.common.sql.hooks.sql import fetch_all_handler +from airflow.providers.exasol.hooks.exasol import exasol_fetch_all_handler from airflow.providers.exasol.operators.exasol import ExasolOperator DATE = "2017-04-20" @@ -143,7 +143,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc dbapi_hook.run.assert_called_once_with( sql=sql, parameters=None, - handler=fetch_all_handler, + handler=exasol_fetch_all_handler, autocommit=False, return_last=return_last, split_statements=split_statement,