Skip to content

Commit

Permalink
Handle non-compliant behaviour of Exasol cursor
Browse files Browse the repository at this point in the history
Exasol is not DBAPI-2 compliant, as described in:
https://github.com/exasol/pyexasol/blob/master/docs/DBAPI_COMPAT.md

This means that SQLOperator cannot be used for it and we should
undeprecate the original operator and use specific Exasol handler
as default for the returned values.

We also add explicit exception in case non-compliant DB2-API
database (with missing description) is used with default Operator
and handlers.

Fixes: #28731
  • Loading branch information
potiuk committed Jan 8, 2023
1 parent ce67786 commit f2ab384
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 38 deletions.
12 changes: 11 additions & 1 deletion airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,24 @@ def return_single_query_results(sql: str | Iterable[str], return_last: bool, spl

def fetch_all_handler(cursor) -> list[tuple] | None:
"""Handler for DbApiHook.run() to return results"""
if not hasattr(cursor, "description"):
raise RuntimeError(
"The database we interact with does not support DBAPI 2.0. Use operator and "
"handlers that are specifically designed for your database."
)
if cursor.description is not None:
return cursor.fetchall()
else:
return None


def fetch_one_handler(cursor) -> list[tuple] | None:
"""Handler for DbApiHook.run() to return results"""
"""Handler for DbApiHook.run() to return first result"""
if not hasattr(cursor, "description"):
raise RuntimeError(
"The database we interact with does not support DBAPI 2.0. Use operator and "
"handlers that are specifically designed for your database."
)
if cursor.description is not None:
return cursor.fetchone()
else:
Expand Down
54 changes: 45 additions & 9 deletions airflow/providers/exasol/hooks/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from __future__ import annotations

from contextlib import closing
from typing import Any, Callable, Iterable, Mapping
from typing import Any, Callable, Iterable, Mapping, Sequence

import pandas as pd
import pyexasol
from pyexasol import ExaConnection
from pyexasol import ExaConnection, ExaStatement

from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results

Expand Down Expand Up @@ -133,6 +133,29 @@ def export_to_file(
)
self.log.info("Data saved to %s", filename)

@staticmethod
def get_description(statement: ExaStatement) -> Sequence[Sequence]:
"""
Copied implementation from DB2-API wrapper.
More info https://github.com/exasol/pyexasol/blob/master/docs/DBAPI_COMPAT.md#db-api-20-wrapper
:param statement: Exasol statement
:return: description sequence of t
"""
cols = []
for k, v in statement.columns().items():
cols.append(
(
k,
v.get("type", None),
v.get("size", None),
v.get("size", None),
v.get("precision", None),
v.get("scale", None),
True,
)
)
return cols

def run(
self,
sql: str | Iterable[str],
Expand Down Expand Up @@ -176,18 +199,17 @@ def run(
self.set_autocommit(conn, autocommit)
results = []
for sql_statement in sql_list:
with closing(conn.execute(sql_statement, parameters)) as cur:
with closing(conn.execute(sql_statement, parameters)) as exa_statement:
self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
if handler is not None:
result = handler(cur)
result = handler(exa_statement)
if return_single_query_results(sql, return_last, split_statements):
_last_result = result
_last_description = cur.description
_last_columns = self.get_description(exa_statement)
else:
results.append(result)
self.descriptions.append(cur.description)

self.log.info("Rows affected: %s", cur.rowcount)
self.descriptions.append(self.get_description(exa_statement))
self.log.info("Rows affected: %s", exa_statement.rowcount)

# If autocommit was set to False or db does not support autocommit, we do a manual commit.
if not self.get_autocommit(conn):
Expand All @@ -196,7 +218,7 @@ def run(
if handler is None:
return None
if return_single_query_results(sql, return_last, split_statements):
self.descriptions = [_last_description]
self.descriptions = [_last_columns]
return _last_result
else:
return results
Expand Down Expand Up @@ -241,3 +263,17 @@ def _serialize_cell(cell, conn=None) -> Any:
:return: The cell
"""
return cell


def exasol_fetch_all_handler(statement: ExaStatement) -> list[tuple] | None:
if statement.result_type == "resultSet":
return statement.fetchall()
else:
return None


def exasol_fetch_one_handler(statement: ExaStatement) -> list[tuple] | None:
if statement.result_type == "resultSet":
return statement.fetchone()
else:
return None
19 changes: 9 additions & 10 deletions airflow/providers/exasol/operators/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
# under the License.
from __future__ import annotations

import warnings
from typing import Sequence

from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.providers.exasol.hooks.exasol import exasol_fetch_all_handler


class ExasolOperator(SQLExecuteQueryOperator):
Expand All @@ -35,6 +35,7 @@ class ExasolOperator(SQLExecuteQueryOperator):
(default value: False)
:param parameters: (optional) the parameters to render the SQL query with.
:param schema: (optional) name of the schema which overwrite defined one in connection
:param handler: (optional) handler to process the results of the query
"""

template_fields: Sequence[str] = ("sql",)
Expand All @@ -43,16 +44,14 @@ class ExasolOperator(SQLExecuteQueryOperator):
ui_color = "#ededed"

def __init__(
self, *, exasol_conn_id: str = "exasol_default", schema: str | None = None, **kwargs
self,
*,
exasol_conn_id: str = "exasol_default",
schema: str | None = None,
handler=exasol_fetch_all_handler,
**kwargs,
) -> None:
if schema is not None:
hook_params = kwargs.pop("hook_params", {})
kwargs["hook_params"] = {"schema": schema, **hook_params}

super().__init__(conn_id=exasol_conn_id, **kwargs)
warnings.warn(
"""This class is deprecated.
Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""",
DeprecationWarning,
stacklevel=2,
)
super().__init__(conn_id=exasol_conn_id, handler=handler, **kwargs)
88 changes: 75 additions & 13 deletions tests/providers/exasol/hooks/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#
from __future__ import annotations

from typing import Any
from unittest import mock
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -57,8 +58,11 @@ def exasol_hook():
return ExasolHook()


def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
return [(field,) for field in fields]
def get_columns(fields: list[str]) -> dict[str, dict[str, Any]]:
return {
field: {"type": "VARCHAR", "nullable": True, "precision": None, "scale": None, "length": None}
for field in fields
}


index = 0
Expand All @@ -75,7 +79,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
["select * from test.test"],
[["id", "value"]],
([[1, 2], [11, 12]],),
[[("id",), ("value",)]],
[
[
("id", "VARCHAR", None, None, None, None, True),
("value", "VARCHAR", None, None, None, None, True),
]
],
[[1, 2], [11, 12]],
id="The return_last set and no split statements set on single query in string",
),
Expand All @@ -86,7 +95,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
["select * from test.test;"],
[["id", "value"]],
([[1, 2], [11, 12]],),
[[("id",), ("value",)]],
[
[
("id", "VARCHAR", None, None, None, None, True),
("value", "VARCHAR", None, None, None, None, True),
]
],
[[1, 2], [11, 12]],
id="The return_last not set and no split statements set on single query in string",
),
Expand All @@ -97,7 +111,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
["select * from test.test;"],
[["id", "value"]],
([[1, 2], [11, 12]],),
[[("id",), ("value",)]],
[
[
("id", "VARCHAR", None, None, None, None, True),
("value", "VARCHAR", None, None, None, None, True),
]
],
[[1, 2], [11, 12]],
id="The return_last set and split statements set on single query in string",
),
Expand All @@ -108,7 +127,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
["select * from test.test;"],
[["id", "value"]],
([[1, 2], [11, 12]],),
[[("id",), ("value",)]],
[
[
("id", "VARCHAR", None, None, None, None, True),
("value", "VARCHAR", None, None, None, None, True),
]
],
[[[1, 2], [11, 12]]],
id="The return_last not set and split statements set on single query in string",
),
Expand All @@ -119,7 +143,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
["select * from test.test;", "select * from test.test2;"],
[["id", "value"], ["id2", "value2"]],
([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
[[("id2",), ("value2",)]],
[
[
("id2", "VARCHAR", None, None, None, None, True),
("value2", "VARCHAR", None, None, None, None, True),
]
],
[[3, 4], [13, 14]],
id="The return_last set and split statements set on multiple queries in string",
), # Failing
Expand All @@ -130,7 +159,16 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
["select * from test.test;", "select * from test.test2;"],
[["id", "value"], ["id2", "value2"]],
([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
[[("id",), ("value",)], [("id2",), ("value2",)]],
[
[
("id", "VARCHAR", None, None, None, None, True),
("value", "VARCHAR", None, None, None, None, True),
],
[
("id2", "VARCHAR", None, None, None, None, True),
("value2", "VARCHAR", None, None, None, None, True),
],
],
[[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
id="The return_last not set and split statements set on multiple queries in string",
),
Expand All @@ -141,7 +179,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
["select * from test.test"],
[["id", "value"]],
([[1, 2], [11, 12]],),
[[("id",), ("value",)]],
[
[
("id", "VARCHAR", None, None, None, None, True),
("value", "VARCHAR", None, None, None, None, True),
]
],
[[[1, 2], [11, 12]]],
id="The return_last set on single query in list",
),
Expand All @@ -152,7 +195,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
["select * from test.test"],
[["id", "value"]],
([[1, 2], [11, 12]],),
[[("id",), ("value",)]],
[
[
("id", "VARCHAR", None, None, None, None, True),
("value", "VARCHAR", None, None, None, None, True),
]
],
[[[1, 2], [11, 12]]],
id="The return_last not set on single query in list",
),
Expand All @@ -163,7 +211,12 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
["select * from test.test", "select * from test.test2"],
[["id", "value"], ["id2", "value2"]],
([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
[[("id2",), ("value2",)]],
[
[
("id2", "VARCHAR", None, None, None, None, True),
("value2", "VARCHAR", None, None, None, None, True),
]
],
[[3, 4], [13, 14]],
id="The return_last set set on multiple queries in list",
),
Expand All @@ -174,7 +227,16 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
["select * from test.test", "select * from test.test2"],
[["id", "value"], ["id2", "value2"]],
([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
[[("id",), ("value",)], [("id2",), ("value2",)]],
[
[
("id", "VARCHAR", None, None, None, None, True),
("value", "VARCHAR", None, None, None, None, True),
],
[
("id2", "VARCHAR", None, None, None, None, True),
("value2", "VARCHAR", None, None, None, None, True),
],
],
[[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
id="The return_last not set on multiple queries not set",
),
Expand All @@ -196,8 +258,8 @@ def test_query(
for index in range(len(cursor_descriptions)):
cur = mock.MagicMock(
rowcount=len(cursor_results[index]),
description=get_cursor_descriptions(cursor_descriptions[index]),
)
cur.columns.return_value = get_columns(cursor_descriptions[index])
cur.fetchall.return_value = cursor_results[index]
cursors.append(cur)
mock_conn.execute.side_effect = cursors
Expand Down
6 changes: 3 additions & 3 deletions tests/providers/exasol/operators/test_exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from unittest import mock

from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.exasol.hooks.exasol import exasol_fetch_all_handler
from airflow.providers.exasol.operators.exasol import ExasolOperator


Expand All @@ -32,7 +32,7 @@ def test_overwrite_autocommit(self, mock_get_db_hook):
sql="SELECT 1",
autocommit=True,
parameters=None,
handler=fetch_all_handler,
handler=exasol_fetch_all_handler,
return_last=True,
)

Expand All @@ -44,7 +44,7 @@ def test_pass_parameters(self, mock_get_db_hook):
sql="SELECT {value!s}",
autocommit=False,
parameters={"value": 1},
handler=fetch_all_handler,
handler=exasol_fetch_all_handler,
return_last=True,
)

Expand Down
4 changes: 2 additions & 2 deletions tests/providers/exasol/operators/test_exasol_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import pytest

from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.exasol.hooks.exasol import exasol_fetch_all_handler
from airflow.providers.exasol.operators.exasol import ExasolOperator

DATE = "2017-04-20"
Expand Down Expand Up @@ -143,7 +143,7 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
dbapi_hook.run.assert_called_once_with(
sql=sql,
parameters=None,
handler=fetch_all_handler,
handler=exasol_fetch_all_handler,
autocommit=False,
return_last=return_last,
split_statements=split_statement,
Expand Down

0 comments on commit f2ab384

Please sign in to comment.