diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 76260612a435a..b195d440c7944 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -60,6 +60,11 @@ def return_single_query_results(sql: str | Iterable[str], return_last: bool, spl def fetch_all_handler(cursor) -> list[tuple] | None: """Handler for DbApiHook.run() to return results""" + if not hasattr(cursor, "description"): + raise RuntimeError( + "The database we interact with does not support DBAPI 2.0. Use operator and " + "handlers that are specifically designed for your database." + ) if cursor.description is not None: return cursor.fetchall() else: @@ -67,7 +72,12 @@ def fetch_all_handler(cursor) -> list[tuple] | None: def fetch_one_handler(cursor) -> list[tuple] | None: - """Handler for DbApiHook.run() to return results""" + """Handler for DbApiHook.run() to return first result""" + if not hasattr(cursor, "description"): + raise RuntimeError( + "The database we interact with does not support DBAPI 2.0. Use operator and " + "handlers that are specifically designed for your database." 
+ ) if cursor.description is not None: return cursor.fetchone() else: diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py index 9a89cf5fa9cb3..9530f88dab2ac 100644 --- a/airflow/providers/exasol/hooks/exasol.py +++ b/airflow/providers/exasol/hooks/exasol.py @@ -18,11 +18,11 @@ from __future__ import annotations from contextlib import closing -from typing import Any, Callable, Iterable, Mapping +from typing import Any, Callable, Iterable, Mapping, Sequence import pandas as pd import pyexasol -from pyexasol import ExaConnection +from pyexasol import ExaConnection, ExaStatement from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results @@ -133,6 +133,29 @@ def export_to_file( ) self.log.info("Data saved to %s", filename) + @staticmethod + def get_description(statement: ExaStatement) -> Sequence[Sequence]: + """ + Copied implementation from the DB-API 2.0 wrapper. + More info https://github.com/exasol/pyexasol/blob/master/docs/DBAPI_COMPAT.md#db-api-20-wrapper + :param statement: Exasol statement + :return: description sequence of tuples, one per result column + """ + cols = [] + for k, v in statement.columns().items(): + cols.append( + ( + k, + v.get("type", None), + v.get("size", None), + v.get("size", None), + v.get("precision", None), + v.get("scale", None), + True, + ) + ) + return cols + def run( self, sql: str | Iterable[str], @@ -176,18 +199,17 @@ def run( self.set_autocommit(conn, autocommit) results = [] for sql_statement in sql_list: - with closing(conn.execute(sql_statement, parameters)) as cur: + with closing(conn.execute(sql_statement, parameters)) as exa_statement: self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) if handler is not None: - result = handler(cur) + result = handler(exa_statement) if return_single_query_results(sql, return_last, split_statements): _last_result = result - _last_description = cur.description + _last_columns = self.get_description(exa_statement) else: 
results.append(result) - self.descriptions.append(cur.description) - - self.log.info("Rows affected: %s", cur.rowcount) + self.descriptions.append(self.get_description(exa_statement)) + self.log.info("Rows affected: %s", exa_statement.rowcount) # If autocommit was set to False or db does not support autocommit, we do a manual commit. if not self.get_autocommit(conn): @@ -196,7 +218,7 @@ def run( if handler is None: return None if return_single_query_results(sql, return_last, split_statements): - self.descriptions = [_last_description] + self.descriptions = [_last_columns] return _last_result else: return results @@ -241,3 +263,17 @@ def _serialize_cell(cell, conn=None) -> Any: :return: The cell """ return cell + + +def exasol_fetch_all_handler(statement: ExaStatement) -> list[tuple] | None: + if statement.result_type == "resultSet": + return statement.fetchall() + else: + return None + + +def exasol_fetch_one_handler(statement: ExaStatement) -> list[tuple] | None: + if statement.result_type == "resultSet": + return statement.fetchone() + else: + return None diff --git a/airflow/providers/exasol/operators/exasol.py b/airflow/providers/exasol/operators/exasol.py index 253e443b8ee6c..e1e6e7dc97ad9 100644 --- a/airflow/providers/exasol/operators/exasol.py +++ b/airflow/providers/exasol/operators/exasol.py @@ -17,10 +17,10 @@ # under the License. from __future__ import annotations -import warnings from typing import Sequence from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator +from airflow.providers.exasol.hooks.exasol import exasol_fetch_all_handler class ExasolOperator(SQLExecuteQueryOperator): @@ -35,6 +35,7 @@ class ExasolOperator(SQLExecuteQueryOperator): (default value: False) :param parameters: (optional) the parameters to render the SQL query with. 
:param schema: (optional) name of the schema which overwrite defined one in connection + :param handler: (optional) handler to process the results of the query """ template_fields: Sequence[str] = ("sql",) @@ -43,16 +44,14 @@ class ExasolOperator(SQLExecuteQueryOperator): ui_color = "#ededed" def __init__( - self, *, exasol_conn_id: str = "exasol_default", schema: str | None = None, **kwargs + self, + *, + exasol_conn_id: str = "exasol_default", + schema: str | None = None, + handler=exasol_fetch_all_handler, + **kwargs, ) -> None: if schema is not None: hook_params = kwargs.pop("hook_params", {}) kwargs["hook_params"] = {"schema": schema, **hook_params} - - super().__init__(conn_id=exasol_conn_id, **kwargs) - warnings.warn( - """This class is deprecated. - Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""", - DeprecationWarning, - stacklevel=2, - ) + super().__init__(conn_id=exasol_conn_id, handler=handler, **kwargs) diff --git a/tests/providers/exasol/hooks/test_sql.py b/tests/providers/exasol/hooks/test_sql.py index 17bb33db907a6..465edc61314b6 100644 --- a/tests/providers/exasol/hooks/test_sql.py +++ b/tests/providers/exasol/hooks/test_sql.py @@ -18,6 +18,7 @@ # from __future__ import annotations +from typing import Any from unittest import mock from unittest.mock import MagicMock, patch @@ -57,8 +58,11 @@ def exasol_hook(): return ExasolHook() -def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: - return [(field,) for field in fields] +def get_columns(fields: list[str]) -> dict[str, dict[str, Any]]: + return { + field: {"type": "VARCHAR", "nullable": True, "precision": None, "scale": None, "length": None} + for field in fields + } index = 0 @@ -75,7 +79,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test"], [["id", "value"]], ([[1, 2], [11, 12]],), - [[("id",), ("value",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, 
None, None, None, True), + ] + ], [[1, 2], [11, 12]], id="The return_last set and no split statements set on single query in string", ), @@ -86,7 +95,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test;"], [["id", "value"]], ([[1, 2], [11, 12]],), - [[("id",), ("value",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ] + ], [[1, 2], [11, 12]], id="The return_last not set and no split statements set on single query in string", ), @@ -97,7 +111,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test;"], [["id", "value"]], ([[1, 2], [11, 12]],), - [[("id",), ("value",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ] + ], [[1, 2], [11, 12]], id="The return_last set and split statements set on single query in string", ), @@ -108,7 +127,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test;"], [["id", "value"]], ([[1, 2], [11, 12]],), - [[("id",), ("value",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ] + ], [[[1, 2], [11, 12]]], id="The return_last not set and split statements set on single query in string", ), @@ -119,7 +143,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test;", "select * from test.test2;"], [["id", "value"], ["id2", "value2"]], ([[1, 2], [11, 12]], [[3, 4], [13, 14]]), - [[("id2",), ("value2",)]], + [ + [ + ("id2", "VARCHAR", None, None, None, None, True), + ("value2", "VARCHAR", None, None, None, None, True), + ] + ], [[3, 4], [13, 14]], id="The return_last set and split statements set on multiple queries in string", ), # Failing @@ -130,7 +159,16 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test;", "select * 
from test.test2;"], [["id", "value"], ["id2", "value2"]], ([[1, 2], [11, 12]], [[3, 4], [13, 14]]), - [[("id",), ("value",)], [("id2",), ("value2",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ], + [ + ("id2", "VARCHAR", None, None, None, None, True), + ("value2", "VARCHAR", None, None, None, None, True), + ], + ], [[[1, 2], [11, 12]], [[3, 4], [13, 14]]], id="The return_last not set and split statements set on multiple queries in string", ), @@ -141,7 +179,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test"], [["id", "value"]], ([[1, 2], [11, 12]],), - [[("id",), ("value",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ] + ], [[[1, 2], [11, 12]]], id="The return_last set on single query in list", ), @@ -152,7 +195,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test"], [["id", "value"]], ([[1, 2], [11, 12]],), - [[("id",), ("value",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ] + ], [[[1, 2], [11, 12]]], id="The return_last not set on single query in list", ), @@ -163,7 +211,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test", "select * from test.test2"], [["id", "value"], ["id2", "value2"]], ([[1, 2], [11, 12]], [[3, 4], [13, 14]]), - [[("id2",), ("value2",)]], + [ + [ + ("id2", "VARCHAR", None, None, None, None, True), + ("value2", "VARCHAR", None, None, None, None, True), + ] + ], [[3, 4], [13, 14]], id="The return_last set set on multiple queries in list", ), @@ -174,7 +227,16 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]: ["select * from test.test", "select * from test.test2"], [["id", "value"], ["id2", "value2"]], ([[1, 2], [11, 12]], [[3, 4], [13, 14]]), - [[("id",), 
("value",)], [("id2",), ("value2",)]], + [ + [ + ("id", "VARCHAR", None, None, None, None, True), + ("value", "VARCHAR", None, None, None, None, True), + ], + [ + ("id2", "VARCHAR", None, None, None, None, True), + ("value2", "VARCHAR", None, None, None, None, True), + ], + ], [[[1, 2], [11, 12]], [[3, 4], [13, 14]]], id="The return_last not set on multiple queries not set", ), @@ -196,8 +258,8 @@ def test_query( for index in range(len(cursor_descriptions)): cur = mock.MagicMock( rowcount=len(cursor_results[index]), - description=get_cursor_descriptions(cursor_descriptions[index]), ) + cur.columns.return_value = get_columns(cursor_descriptions[index]) cur.fetchall.return_value = cursor_results[index] cursors.append(cur) mock_conn.execute.side_effect = cursors diff --git a/tests/providers/exasol/operators/test_exasol.py b/tests/providers/exasol/operators/test_exasol.py index d893177f21c5d..c805684aa97c0 100644 --- a/tests/providers/exasol/operators/test_exasol.py +++ b/tests/providers/exasol/operators/test_exasol.py @@ -19,7 +19,7 @@ from unittest import mock -from airflow.providers.common.sql.hooks.sql import fetch_all_handler +from airflow.providers.exasol.hooks.exasol import exasol_fetch_all_handler from airflow.providers.exasol.operators.exasol import ExasolOperator @@ -32,7 +32,7 @@ def test_overwrite_autocommit(self, mock_get_db_hook): sql="SELECT 1", autocommit=True, parameters=None, - handler=fetch_all_handler, + handler=exasol_fetch_all_handler, return_last=True, ) @@ -44,7 +44,7 @@ def test_pass_parameters(self, mock_get_db_hook): sql="SELECT {value!s}", autocommit=False, parameters={"value": 1}, - handler=fetch_all_handler, + handler=exasol_fetch_all_handler, return_last=True, ) diff --git a/tests/providers/exasol/operators/test_exasol_sql.py b/tests/providers/exasol/operators/test_exasol_sql.py index 1c2651373b2cf..9ff999eed3a4e 100644 --- a/tests/providers/exasol/operators/test_exasol_sql.py +++ b/tests/providers/exasol/operators/test_exasol_sql.py @@ 
-22,7 +22,7 @@ import pytest -from airflow.providers.common.sql.hooks.sql import fetch_all_handler +from airflow.providers.exasol.hooks.exasol import exasol_fetch_all_handler from airflow.providers.exasol.operators.exasol import ExasolOperator DATE = "2017-04-20" @@ -143,7 +143,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc dbapi_hook.run.assert_called_once_with( sql=sql, parameters=None, - handler=fetch_all_handler, + handler=exasol_fetch_all_handler, autocommit=False, return_last=return_last, split_statements=split_statement,