diff --git a/providers/standard/README.rst b/providers/standard/README.rst index baa3ab09a2fb4..65d9b7bf6e61c 100644 --- a/providers/standard/README.rst +++ b/providers/standard/README.rst @@ -23,7 +23,7 @@ Package ``apache-airflow-providers-standard`` -Release: ``1.4.0`` +Release: ``1.4.1`` Airflow Standard Provider diff --git a/providers/standard/docs/changelog.rst b/providers/standard/docs/changelog.rst index bcbba97c71e3e..3fd1138a655ae 100644 --- a/providers/standard/docs/changelog.rst +++ b/providers/standard/docs/changelog.rst @@ -35,6 +35,14 @@ Changelog --------- +1.4.1 +..... + +Bug Fixes +~~~~~~~~~ + +* ``Fix sensor skipping in Airflow 3.x branching operators (#53455)`` + 1.4.0 ..... diff --git a/providers/standard/docs/index.rst b/providers/standard/docs/index.rst index d38181c03e5df..b52a262d8e915 100644 --- a/providers/standard/docs/index.rst +++ b/providers/standard/docs/index.rst @@ -66,7 +66,7 @@ apache-airflow-providers-standard package Airflow Standard Provider -Release: 1.4.0 +Release: 1.4.1 Provider package ---------------- diff --git a/providers/standard/provider.yaml b/providers/standard/provider.yaml index f932014c94073..96e0d749ad25c 100644 --- a/providers/standard/provider.yaml +++ b/providers/standard/provider.yaml @@ -27,6 +27,7 @@ source-date-epoch: 1751474457 # In such case adding >= NEW_VERSION and bumping to NEW_VERSION in a provider have # to be done in the same PR versions: + - 1.4.1 - 1.4.0 - 1.3.0 - 1.2.0 diff --git a/providers/standard/pyproject.toml b/providers/standard/pyproject.toml index 3eaba4644cc12..41dcfa12485b2 100644 --- a/providers/standard/pyproject.toml +++ b/providers/standard/pyproject.toml @@ -25,7 +25,7 @@ build-backend = "flit_core.buildapi" [project] name = "apache-airflow-providers-standard" -version = "1.4.0" +version = "1.4.1" description = "Provider package apache-airflow-providers-standard for Apache Airflow" readme = "README.rst" authors = [ @@ -94,8 +94,8 @@ apache-airflow-providers-common-sql = {workspace = true} apache-airflow-providers-standard = {workspace = true} [project.urls] -"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-standard/1.4.0" -"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-standard/1.4.0/changelog.html" +"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-standard/1.4.1" +"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-standard/1.4.1/changelog.html" "Bug Tracker" = "https://github.com/apache/airflow/issues" "Source Code" = "https://github.com/apache/airflow" "Slack Chat" = "https://s.apache.org/airflow-slack" diff --git a/providers/standard/src/airflow/providers/standard/__init__.py b/providers/standard/src/airflow/providers/standard/__init__.py index 081390646e3f2..9f833db0dcf33 100644 --- a/providers/standard/src/airflow/providers/standard/__init__.py +++ b/providers/standard/src/airflow/providers/standard/__init__.py @@ -29,7 +29,7 @@ __all__ = ["__version__"] -__version__ = "1.4.0" +__version__ = "1.4.1" if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( "2.10.0" diff --git a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py index 32eb545f47019..722ce8b6c76ac 100644 --- a/providers/standard/src/airflow/providers/standard/utils/skipmixin.py +++ b/providers/standard/src/airflow/providers/standard/utils/skipmixin.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING from airflow.exceptions import AirflowException -from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS +from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: @@ -45,7 +45,7 @@ def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]: - if AIRFLOW_V_3_1_PLUS: + if AIRFLOW_V_3_0_PLUS: from airflow.sdk import BaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator else: diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index ff0cd45a8694a..13dac0897da66 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -842,6 +842,54 @@ def test_xcom_push_skipped_tasks(self): "skipped": ["empty_task"] } + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 2 implementation is different") + def test_short_circuit_operator_skips_sensors(self): + """Test that ShortCircuitOperator properly skips sensors in Airflow 3.x.""" + from airflow.sdk.bases.sensor import BaseSensorOperator + + # Create a sensor similar to S3FileSensor to reproduce the issue + class CustomS3Sensor(BaseSensorOperator): + def __init__(self, bucket_name: str, object_key: str, **kwargs): + super().__init__(**kwargs) + self.bucket_name = bucket_name + self.object_key = object_key + self.timeout = 0 + self.poke_interval = 0 + + def poke(self, context): + # Simulate sensor logic + return True + + with self.dag_maker(self.dag_id): + # ShortCircuit that evaluates to False (should skip all downstream) + short_circuit = ShortCircuitOperator( + task_id="check_dis_is_mon_to_fri_not_holiday", + python_callable=lambda: False, # This causes skipping + ) + + sensor_task = CustomS3Sensor( + task_id="wait_for_ticker_to_secid_lookup_s3_file", + bucket_name="test-bucket", + object_key="ticker_to_secid_lookup.csv", + ) + + short_circuit >> sensor_task + + dr = self.dag_maker.create_dagrun() + + self.dag_maker.run_ti("check_dis_is_mon_to_fri_not_holiday", dr) + + # Verify the sensor is included in the skip list by checking XCom + # (this was the bug - sensors were not being included in skip list) + tis = dr.get_task_instances() + xcom_data = tis[0].xcom_pull(task_ids="check_dis_is_mon_to_fri_not_holiday", key="skipmixin_key") + + assert xcom_data is not None, "XCom data should exist" + skipped_task_ids = set(xcom_data.get("skipped", [])) + assert "wait_for_ticker_to_secid_lookup_s3_file" in skipped_task_ids, ( + "Sensor should be skipped by ShortCircuitOperator" + ) + virtualenv_string_args: list[str] = [] diff --git a/providers/standard/tests/unit/standard/utils/test_skipmixin.py b/providers/standard/tests/unit/standard/utils/test_skipmixin.py index 360a591af10fb..ac1904f14e502 100644 --- a/providers/standard/tests/unit/standard/utils/test_skipmixin.py +++ b/providers/standard/tests/unit/standard/utils/test_skipmixin.py @@ -298,3 +298,84 @@ def test_raise_exception_on_not_valid_branch_task_ids(self, dag_maker, branch_ta error_message = r"'branch_task_ids' must contain only valid task_ids. Invalid tasks found: .*" with pytest.raises(AirflowException, match=error_message): SkipMixin().skip_all_except(ti=ti1, branch_task_ids=branch_task_ids) + + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Issue only exists in Airflow 3.x") + def test_ensure_tasks_includes_sensors_airflow_3x(self, dag_maker): + """Test that sensors (inheriting from airflow.sdk.BaseOperator) are properly handled by _ensure_tasks.""" + from airflow.providers.standard.utils.skipmixin import _ensure_tasks + from airflow.sdk import BaseOperator as SDKBaseOperator + from airflow.sdk.bases.sensor import BaseSensorOperator + + class DummySensor(BaseSensorOperator): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.timeout = 0 + self.poke_interval = 0 + + def poke(self, context): + return True + + with dag_maker("dag_test_sensor_skipping") as dag: + regular_task = EmptyOperator(task_id="regular_task") + sensor_task = DummySensor(task_id="sensor_task") + downstream_task = EmptyOperator(task_id="downstream_task") + + regular_task >> [sensor_task, downstream_task] + + dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID) + + downstream_nodes = dag.get_task("regular_task").downstream_list + task_list = _ensure_tasks(downstream_nodes) + + # Verify both the regular operator and sensor are included + task_ids = [t.task_id for t in task_list] + assert "sensor_task" in task_ids, "Sensor should be included in task list" + assert "downstream_task" in task_ids, "Regular task should be included in task list" + assert len(task_list) == 2, "Both tasks should be included" + + # Also verify that the sensor is actually an instance of the correct BaseOperator + sensor_in_list = next((t for t in task_list if t.task_id == "sensor_task"), None) + assert sensor_in_list is not None, "Sensor task should be found in list" + assert isinstance(sensor_in_list, SDKBaseOperator), "Sensor should be instance of SDK BaseOperator" + + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Integration test for Airflow 3.x sensor skipping") + def test_skip_sensor_in_branching_scenario(self, dag_maker): + """Integration test: verify sensors are properly skipped by branching operators in Airflow 3.x.""" + from airflow.sdk.bases.sensor import BaseSensorOperator + + # Create a dummy sensor for testing + class DummySensor(BaseSensorOperator): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.timeout = 0 + self.poke_interval = 0 + + def poke(self, context): + return True + + with dag_maker("dag_test_branch_sensor_skipping"): + branch_task = EmptyOperator(task_id="branch_task") + regular_task = EmptyOperator(task_id="regular_task") + sensor_task = DummySensor(task_id="sensor_task") + branch_task >> [regular_task, sensor_task] + + dag_maker.create_dagrun(run_id=DEFAULT_DAG_RUN_ID) + + dag_version = DagVersion.get_latest_version(branch_task.dag_id) + ti_branch = TI(branch_task, run_id=DEFAULT_DAG_RUN_ID, dag_version_id=dag_version.id) + + # Test skipping the sensor (follow regular_task branch) + with pytest.raises(DownstreamTasksSkipped) as exc_info: + SkipMixin().skip_all_except(ti=ti_branch, branch_task_ids="regular_task") + + # Verify that the sensor task is properly marked for skipping + skipped_tasks = set(exc_info.value.tasks) + assert ("sensor_task", -1) in skipped_tasks, "Sensor task should be marked for skipping" + + # Test skipping the regular task (follow sensor_task branch) + with pytest.raises(DownstreamTasksSkipped) as exc_info: + SkipMixin().skip_all_except(ti=ti_branch, branch_task_ids="sensor_task") + + # Verify that the regular task is properly marked for skipping + skipped_tasks = set(exc_info.value.tasks) + assert ("regular_task", -1) in skipped_tasks, "Regular task should be marked for skipping"