Skip to content

Commit

Permalink
37437
Browse files Browse the repository at this point in the history
Allow SqlSensor to inspect the entire result row by adding a selector parameter.
This makes it possible to base the success/failure criteria on any part of the row instead of only the first cell.
  • Loading branch information
Jasmin committed Oct 17, 2024
1 parent 3f8ac22 commit fd000f1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
23 changes: 14 additions & 9 deletions providers/src/airflow/providers/common/sql/sensors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

from operator import itemgetter
from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -46,10 +47,12 @@ class SqlSensor(BaseSensorOperator):
:param sql: The SQL to run. To pass, it needs to return at least one cell
that contains a non-zero / non-empty string value.
:param parameters: The parameters to render the SQL query with (optional).
:param success: Success criteria for the sensor is a Callable that takes the first_cell's value
as the only argument, and returns a boolean (optional).
:param failure: Failure criteria for the sensor is a Callable that takes the first_cell's value
as the only argument and returns a boolean (optional).
:param success: Success criteria for the sensor is a Callable that takes the output
of selector as the only argument, and returns a boolean (optional).
:param failure: Failure criteria for the sensor is a Callable that takes the output
of selector as the only argument and returns a boolean (optional).
:param selector: Function which takes the resulting row and transforms it before
it is passed to success or failure (optional). Takes the first cell by default.
:param fail_on_empty: Explicitly fail on no rows returned.
:param hook_params: Extra config params to be passed to the underlying hook.
Should match the desired hook constructor params.
Expand All @@ -67,6 +70,7 @@ def __init__(
parameters: Mapping[str, Any] | None = None,
success: Callable[[Any], bool] | None = None,
failure: Callable[[Any], bool] | None = None,
selector: Callable[[tuple[Any]], Any] | None = itemgetter(0),
fail_on_empty: bool = False,
hook_params: Mapping[str, Any] | None = None,
**kwargs,
Expand All @@ -76,6 +80,7 @@ def __init__(
self.parameters = parameters
self.success = success
self.failure = failure
self.selector = selector
self.fail_on_empty = fail_on_empty
self.hook_params = hook_params
super().__init__(**kwargs)
Expand All @@ -102,20 +107,20 @@ def poke(self, context: Context) -> bool:
else:
return False

first_cell = records[0][0]
condition = self.selector(records[0])
if self.failure is not None:
if callable(self.failure):
if self.failure(first_cell):
message = f"Failure criteria met. self.failure({first_cell}) returned True"
if self.failure(condition):
message = f"Failure criteria met. self.failure({condition}) returned True"
raise AirflowException(message)
else:
message = f"self.failure is present, but not callable -> {self.failure}"
raise AirflowException(message)

if self.success is not None:
if callable(self.success):
return self.success(first_cell)
return self.success(condition)
else:
message = f"self.success is present, but not callable -> {self.success}"
raise AirflowException(message)
return bool(first_cell)
return bool(condition)
25 changes: 25 additions & 0 deletions providers/tests/common/sql/sensors/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,31 @@ def test_sql_sensor_postgres_poke_invalid_success(
op.poke({})
assert "self.success is present, but not callable -> [1]" == str(ctx.value)

@pytest.mark.backend("postgres")
def test_sql_sensor_postgres_with_selector(self):
    """Check that ``selector`` chooses which cell of the row is evaluated."""
    # Both sensors issue the same query, which returns the single row (0, 1),
    # and share the same success/failure criteria; only the selector differs.
    shared_kwargs = {
        "conn_id": "postgres_default",
        "sql": "SELECT 0, 1",
        "dag": self.dag,
        "success": lambda value: value in [1],
        "failure": lambda value: value in [0],
    }

    # Selecting the second cell (1) satisfies the success criterion.
    second_cell_sensor = SqlSensor(
        task_id="sql_sensor_check_1",
        selector=lambda row: row[1],
        **shared_kwargs,
    )
    second_cell_sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    # Selecting the first cell (0) satisfies the failure criterion, so poking raises.
    first_cell_sensor = SqlSensor(
        task_id="sql_sensor_check_2",
        selector=lambda row: row[0],
        **shared_kwargs,
    )
    with pytest.raises(AirflowException):
        first_cell_sensor.poke({})

@pytest.mark.db_test
def test_sql_sensor_hook_params(self):
op = SqlSensor(
Expand Down

0 comments on commit fd000f1

Please sign in to comment.