diff --git a/providers/src/airflow/providers/common/sql/sensors/sql.py b/providers/src/airflow/providers/common/sql/sensors/sql.py index ece2ea241c948..e9cac16a00eae 100644 --- a/providers/src/airflow/providers/common/sql/sensors/sql.py +++ b/providers/src/airflow/providers/common/sql/sensors/sql.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +from operator import itemgetter from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence from airflow.exceptions import AirflowException @@ -46,10 +47,12 @@ class SqlSensor(BaseSensorOperator): :param sql: The SQL to run. To pass, it needs to return at least one cell that contains a non-zero / empty string value. :param parameters: The parameters to render the SQL query with (optional). - :param success: Success criteria for the sensor is a Callable that takes the first_cell's value - as the only argument, and returns a boolean (optional). - :param failure: Failure criteria for the sensor is a Callable that takes the first_cell's value - as the only argument and returns a boolean (optional). + :param success: Success criteria for the sensor is a Callable that takes the output + of selector as the only argument, and returns a boolean (optional). + :param failure: Failure criteria for the sensor is a Callable that takes the output + of selector as the only argument and returns a boolean (optional). + :param selector: Function which takes the resulting row and transforms it before + it is passed to success or failure (optional). Takes the first cell by default. :param fail_on_empty: Explicitly fail on no rows returned. :param hook_params: Extra config params to be passed to the underlying hook. Should match the desired hook constructor params. @@ -67,6 +70,7 @@ def __init__( parameters: Mapping[str, Any] | None = None, success: Callable[[Any], bool] | None = None, failure: Callable[[Any], bool] | None = None, + selector: Callable[[tuple[Any]], Any] | None = itemgetter(0), fail_on_empty: bool = False, hook_params: Mapping[str, Any] | None = None, **kwargs, @@ -76,6 +80,7 @@ def __init__( self.parameters = parameters self.success = success self.failure = failure + self.selector = selector self.fail_on_empty = fail_on_empty self.hook_params = hook_params super().__init__(**kwargs) @@ -102,11 +107,11 @@ def poke(self, context: Context) -> bool: else: return False - first_cell = records[0][0] + condition = self.selector(records[0]) if self.failure is not None: if callable(self.failure): - if self.failure(first_cell): - message = f"Failure criteria met. self.failure({first_cell}) returned True" + if self.failure(condition): + message = f"Failure criteria met. self.failure({condition}) returned True" raise AirflowException(message) else: message = f"self.failure is present, but not callable -> {self.failure}" @@ -114,8 +119,8 @@ def poke(self, context: Context) -> bool: if self.success is not None: if callable(self.success): - return self.success(first_cell) + return self.success(condition) else: message = f"self.success is present, but not callable -> {self.success}" raise AirflowException(message) - return bool(first_cell) + return bool(condition) diff --git a/providers/tests/common/sql/sensors/test_sql.py b/providers/tests/common/sql/sensors/test_sql.py index 3b46af7be578d..377a51e8295ec 100644 --- a/providers/tests/common/sql/sensors/test_sql.py +++ b/providers/tests/common/sql/sensors/test_sql.py @@ -264,6 +264,31 @@ def test_sql_sensor_postgres_poke_invalid_success( op.poke({}) assert "self.success is present, but not callable -> [1]" == str(ctx.value) + @pytest.mark.backend("postgres") + def test_sql_sensor_postgres_with_selector(self): + op1 = SqlSensor( + task_id="sql_sensor_check_1", + conn_id="postgres_default", + sql="SELECT 0, 1", + dag=self.dag, + success=lambda x: x in [1], + failure=lambda x: x in [0], + selector=lambda x: x[1], + ) + op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + op2 = SqlSensor( + task_id="sql_sensor_check_2", + conn_id="postgres_default", + sql="SELECT 0, 1", + dag=self.dag, + success=lambda x: x in [1], + failure=lambda x: x in [0], + selector=lambda x: x[0], + ) + with pytest.raises(AirflowException): + op2.poke({}) + @pytest.mark.db_test def test_sql_sensor_hook_params(self): op = SqlSensor(