diff --git a/airflow-core/pyproject.toml b/airflow-core/pyproject.toml index 218ed550412b3..73ee1cd450831 100644 --- a/airflow-core/pyproject.toml +++ b/airflow-core/pyproject.toml @@ -161,6 +161,9 @@ dependencies = [ # Start of shared configuration dependencies "pyyaml>=6.0.3", # End of shared configuration dependencies + # Start of shared listeners dependencies + "pluggy>=1.5.0", + # End of shared listeners dependencies ] @@ -235,6 +238,7 @@ exclude = [ "../shared/secrets_backend/src/airflow_shared/secrets_backend" = "src/airflow/_shared/secrets_backend" "../shared/secrets_masker/src/airflow_shared/secrets_masker" = "src/airflow/_shared/secrets_masker" "../shared/timezones/src/airflow_shared/timezones" = "src/airflow/_shared/timezones" +"../shared/listeners/src/airflow_shared/listeners" = "src/airflow/_shared/listeners" "../shared/plugins_manager/src/airflow_shared/plugins_manager" = "src/airflow/_shared/plugins_manager" [tool.hatch.build.targets.custom] @@ -305,6 +309,7 @@ apache-airflow-devel-common = { workspace = true } shared_distributions = [ "apache-airflow-shared-configuration", "apache-airflow-shared-dagnode", + "apache-airflow-shared-listeners", "apache-airflow-shared-logging", "apache-airflow-shared-module-loading", "apache-airflow-shared-observability", diff --git a/airflow-core/src/airflow/_shared/listeners b/airflow-core/src/airflow/_shared/listeners new file mode 120000 index 0000000000000..54346425d3717 --- /dev/null +++ b/airflow-core/src/airflow/_shared/listeners @@ -0,0 +1 @@ +../../../../shared/listeners/src/airflow_shared/listeners \ No newline at end of file diff --git a/airflow-core/src/airflow/listeners/__init__.py b/airflow-core/src/airflow/listeners/__init__.py index 87840b50e2fa5..670ecde854c6d 100644 --- a/airflow-core/src/airflow/listeners/__init__.py +++ b/airflow-core/src/airflow/listeners/__init__.py @@ -17,6 +17,6 @@ # under the License. from __future__ import annotations -from pluggy import HookimplMarker +from airflow._shared.listeners import hookimpl -hookimpl = HookimplMarker("airflow") +__all__ = ["hookimpl"] diff --git a/airflow-core/src/airflow/listeners/listener.py b/airflow-core/src/airflow/listeners/listener.py index 08869f5094750..06be7a1b9b908 100644 --- a/airflow-core/src/airflow/listeners/listener.py +++ b/airflow-core/src/airflow/listeners/listener.py @@ -17,72 +17,36 @@ # under the License. from __future__ import annotations -import logging from functools import cache -from typing import TYPE_CHECKING - -import pluggy +from airflow._shared.listeners.listener import ListenerManager +from airflow._shared.listeners.spec import lifecycle, taskinstance +from airflow.listeners.spec import asset, dagrun, importerrors from airflow.plugins_manager import integrate_listener_plugins -if TYPE_CHECKING: - from pluggy._hooks import _HookRelay - -log = logging.getLogger(__name__) - - -def _before_hookcall(hook_name, hook_impls, kwargs): - log.debug("Calling %r with %r", hook_name, kwargs) - log.debug("Hook impls: %s", hook_impls) - - -def _after_hookcall(outcome, hook_name, hook_impls, kwargs): - log.debug("Result from %r: %s", hook_name, outcome.get_result()) - - -class ListenerManager: - """Manage listener registration and provides hook property for calling them.""" - - def __init__(self): - from airflow.listeners.spec import ( - asset, - dagrun, - importerrors, - lifecycle, - taskinstance, - ) - - self.pm = pluggy.PluginManager("airflow") - self.pm.add_hookcall_monitoring(_before_hookcall, _after_hookcall) - self.pm.add_hookspecs(lifecycle) - self.pm.add_hookspecs(dagrun) - self.pm.add_hookspecs(asset) - self.pm.add_hookspecs(taskinstance) - self.pm.add_hookspecs(importerrors) - - @property - def has_listeners(self) -> bool: - return bool(self.pm.get_plugins()) - - @property - def hook(self) -> _HookRelay: - """Return hook, on which plugin methods specified in spec can be called.""" - return self.pm.hook - - def add_listener(self, listener): - if self.pm.is_registered(listener): - return - self.pm.register(listener) - - def clear(self): - """Remove registered plugins.""" - for plugin in self.pm.get_plugins(): - self.pm.unregister(plugin) - @cache def get_listener_manager() -> ListenerManager: - """Get singleton listener manager.""" + """ + Get a listener manager for Airflow core. + + Registers the following listeners: + - lifecycle: on_starting, before_stopping + - dagrun: on_dag_run_running, on_dag_run_success, on_dag_run_failed + - taskinstance: on_task_instance_running, on_task_instance_success, etc. + - asset: on_asset_created, on_asset_changed, etc. + - importerrors: on_new_dag_import_error, on_existing_dag_import_error + """ _listener_manager = ListenerManager() + + _listener_manager.add_hookspecs(lifecycle) + _listener_manager.add_hookspecs(dagrun) + _listener_manager.add_hookspecs(taskinstance) + _listener_manager.add_hookspecs(asset) + _listener_manager.add_hookspecs(importerrors) + integrate_listener_plugins(_listener_manager) return _listener_manager + + +__all__ = ["get_listener_manager", "ListenerManager"] diff --git a/airflow-core/src/airflow/listeners/spec/asset.py b/airflow-core/src/airflow/listeners/spec/asset.py index 25d1aacf15ff4..05ba0809bcd8a 100644 --- a/airflow-core/src/airflow/listeners/spec/asset.py +++ b/airflow-core/src/airflow/listeners/spec/asset.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from __future__ import annotations from typing import TYPE_CHECKING diff --git a/airflow-core/src/airflow/listeners/spec/importerrors.py b/airflow-core/src/airflow/listeners/spec/importerrors.py index 2cb2b4e454d37..048fb38ffa109 100644 --- a/airflow-core/src/airflow/listeners/spec/importerrors.py +++ b/airflow-core/src/airflow/listeners/spec/importerrors.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from __future__ import annotations from pluggy import HookspecMarker diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py index 87a7b932b749b..ef0e4f792e950 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py @@ -27,7 +27,6 @@ from airflow._shared.timezones import timezone from airflow.api_fastapi.core_api.datamodels.dag_versions import DagVersionResponse -from airflow.listeners.listener import get_listener_manager from airflow.models import DagModel, DagRun, Log from airflow.models.asset import AssetEvent, AssetModel from airflow.providers.standard.operators.empty import EmptyOperator @@ -1285,12 +1284,6 @@ def test_patch_dag_run_bad_request(self, test_client): body = response.json() assert body["detail"][0]["msg"] == "Input should be 'queued', 'success' or 'failed'" - @pytest.fixture(autouse=True) - def clean_listener_manager(self): - get_listener_manager().clear() - yield - get_listener_manager().clear() - @pytest.mark.parametrize( ("state", "listener_state"), [ @@ -1300,11 +1293,11 @@ def clean_listener_manager(self): ], ) @pytest.mark.usefixtures("configure_git_connection_for_dag_bundle") - def test_patch_dag_run_notifies_listeners(self, test_client, state, listener_state): + def test_patch_dag_run_notifies_listeners(self, test_client, state, listener_state, listener_manager): from unit.listeners.class_listener import ClassBasedListener listener = ClassBasedListener() - get_listener_manager().add_listener(listener) + listener_manager(listener) response = test_client.patch(f"/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", json={"state": state}) assert response.status_code == 200 assert listener.state == listener_state diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py index 136978e238101..63601f5cb02c9 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py @@ -33,7 +33,6 @@ from airflow.dag_processing.dagbag import DagBag, sync_bag_to_db from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner -from airflow.listeners.listener import get_listener_manager from airflow.models import DagRun, Log, TaskInstance from airflow.models.dag_version import DagVersion from airflow.models.hitl import HITLDetail @@ -4084,12 +4083,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint): TASK_ID = "print_the_context" RUN_ID = "TEST_DAG_RUN_ID" - @pytest.fixture(autouse=True) - def clean_listener_manager(self): - get_listener_manager().clear() - yield - get_listener_manager().clear() - @pytest.mark.parametrize( ("state", "listener_state"), [ @@ -4098,13 +4091,15 @@ def clean_listener_manager(self): ("skipped", []), ], ) - def test_patch_task_instance_notifies_listeners(self, test_client, session, state, listener_state): + def test_patch_task_instance_notifies_listeners( + self, test_client, session, state, listener_state, listener_manager + ): from unit.listeners.class_listener import ClassBasedListener self.create_task_instances(session) listener = ClassBasedListener() - get_listener_manager().add_listener(listener) + listener_manager(listener) test_client.patch( self.ENDPOINT_URL, json={ diff --git a/airflow-core/tests/unit/assets/test_manager.py b/airflow-core/tests/unit/assets/test_manager.py index 8036b6d8352b8..7929ae7c0c5f7 100644 --- a/airflow-core/tests/unit/assets/test_manager.py +++ b/airflow-core/tests/unit/assets/test_manager.py @@ -30,7 +30,6 @@ from airflow import settings from airflow.assets.manager import AssetManager -from airflow.listeners.listener import get_listener_manager from airflow.models.asset import ( AssetAliasModel, AssetDagRunQueue, @@ -183,11 +182,11 @@ def test_register_asset_change_no_downstreams(self, session, mock_task_instance) assert session.scalar(select(func.count()).select_from(AssetDagRunQueue)) == 0 def test_register_asset_change_notifies_asset_listener( - self, session, mock_task_instance, testing_dag_bundle + self, session, mock_task_instance, testing_dag_bundle, listener_manager ): asset_manager = AssetManager() asset_listener.clear() - get_listener_manager().add_listener(asset_listener) + listener_manager(asset_listener) bundle_name = "testing" @@ -207,10 +206,10 @@ def test_register_asset_change_notifies_asset_listener( assert len(asset_listener.changed) == 1 assert asset_listener.changed[0].uri == asset.uri - def test_create_assets_notifies_asset_listener(self, session): + def test_create_assets_notifies_asset_listener(self, session, listener_manager): asset_manager = AssetManager() asset_listener.clear() - get_listener_manager().add_listener(asset_listener) + listener_manager(asset_listener) asset = Asset(uri="test://asset1", name="test_asset_1") diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py b/airflow-core/tests/unit/dag_processing/test_collection.py index ad4a8a73382e6..458f773046ef3 100644 --- a/airflow-core/tests/unit/dag_processing/test_collection.py +++ b/airflow-core/tests/unit/dag_processing/test_collection.py @@ -41,7 +41,6 @@ update_dag_parsing_results_in_db, ) from airflow.exceptions import SerializationError -from airflow.listeners.listener import get_listener_manager from airflow.models import DagModel, DagRun from airflow.models.asset import ( AssetActive, @@ -321,12 +320,11 @@ def clean_db(self, session): clear_db_import_errors() @pytest.fixture(name="dag_import_error_listener") - def _dag_import_error_listener(self): + def _dag_import_error_listener(self, listener_manager): from unit.listeners import dag_import_error_listener - get_listener_manager().add_listener(dag_import_error_listener) + listener_manager(dag_import_error_listener) yield dag_import_error_listener - get_listener_manager().clear() dag_import_error_listener.clear() @mark_fab_auth_manager_test diff --git a/airflow-core/tests/unit/jobs/test_base_job.py b/airflow-core/tests/unit/jobs/test_base_job.py index f8f780fea0773..fbae2a96837ee 100644 --- a/airflow-core/tests/unit/jobs/test_base_job.py +++ b/airflow-core/tests/unit/jobs/test_base_job.py @@ -28,7 +28,6 @@ from airflow._shared.timezones import timezone from airflow.executors.local_executor import LocalExecutor from airflow.jobs.job import Job, health_check_threshold, most_recent_job, perform_heartbeat, run_job -from airflow.listeners.listener import get_listener_manager from airflow.utils.session import create_session from airflow.utils.state import State @@ -68,11 +67,11 @@ def test_base_job_respects_plugin_hooks(self): assert job.state == State.SUCCESS assert job.end_date is not None - def test_base_job_respects_plugin_lifecycle(self, dag_maker): + def test_base_job_respects_plugin_lifecycle(self, dag_maker, listener_manager): """ Test if DagRun is successful, and if Success callbacks is defined, it is sent to DagFileProcessor. """ - get_listener_manager().add_listener(lifecycle_listener) + listener_manager(lifecycle_listener) job = Job() job_runner = MockJobRunner(job=job, func=lambda: sys.exit(0)) diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 74932d3d1ab83..2b9ed36c7cb4d 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -114,7 +114,6 @@ from tests_common.test_utils.mock_operators import CustomOperator from tests_common.test_utils.taskinstance import create_task_instance, run_task_instance from unit.listeners import dag_listener -from unit.listeners.test_listeners import get_listener_manager from unit.models import TEST_DAGS_FOLDER if TYPE_CHECKING: @@ -3190,7 +3189,9 @@ def test_dagrun_callbacks_are_called(self, state, expected_callback_msg, dag_mak ("state", "expected_callback_msg"), [(State.SUCCESS, "success"), (State.FAILED, "task_failure")] ) @conf_vars({("scheduler", "use_job_schedule"): "False"}) - def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_maker, session): + def test_dagrun_plugins_are_notified( + self, state, expected_callback_msg, dag_maker, session, listener_manager + ): """ Test if DagRun is successful, and if Success callbacks is defined, it is sent to DagFileProcessor. """ @@ -3203,7 +3204,7 @@ def test_dagrun_plugins_are_notified(self, state, expected_callback_msg, dag_mak EmptyOperator(task_id="dummy") dag_listener.clear() - get_listener_manager().add_listener(dag_listener) + listener_manager(dag_listener) scheduler_job = Job(executor=self.null_exec) self.job_runner = SchedulerJobRunner(job=scheduler_job) @@ -3374,7 +3375,7 @@ def test_dagrun_callbacks_are_added_when_callbacks_are_defined(self, state, msg, session.close() @conf_vars({("scheduler", "use_job_schedule"): "False"}) - def test_dagrun_notify_called_success(self, dag_maker): + def test_dagrun_notify_called_success(self, dag_maker, listener_manager): with dag_maker( dag_id="test_dagrun_notify_called", on_success_callback=lambda x: print("success"), @@ -3383,7 +3384,7 @@ def test_dagrun_notify_called_success(self, dag_maker): EmptyOperator(task_id="dummy") dag_listener.clear() - get_listener_manager().add_listener(dag_listener) + listener_manager(dag_listener) executor = MockExecutor(do_update=False) diff --git a/airflow-core/tests/unit/listeners/test_asset_listener.py b/airflow-core/tests/unit/listeners/test_asset_listener.py index 3b5e933f7d806..b2ce78c244361 100644 --- a/airflow-core/tests/unit/listeners/test_asset_listener.py +++ b/airflow-core/tests/unit/listeners/test_asset_listener.py @@ -18,7 +18,6 @@ import pytest -from airflow.listeners.listener import get_listener_manager from airflow.models.asset import AssetModel from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk.definitions.asset import Asset @@ -28,19 +27,18 @@ @pytest.fixture(autouse=True) -def clean_listener_manager(): - lm = get_listener_manager() - lm.clear() - lm.add_listener(asset_listener) +def clean_listener_state(): + """Clear listener state after each test.""" yield - lm = get_listener_manager() - lm.clear() asset_listener.clear() @pytest.mark.db_test @provide_session -def test_asset_listener_on_asset_changed_gets_calls(create_task_instance_of_operator, session): +def test_asset_listener_on_asset_changed_gets_calls( + create_task_instance_of_operator, session, listener_manager +): + listener_manager(asset_listener) asset_uri = "test://asset/" asset_name = "test_asset_uri" asset_group = "test-group" diff --git a/airflow-core/tests/unit/listeners/test_listeners.py b/airflow-core/tests/unit/listeners/test_listeners.py index 5a2a9ff8bb7c2..aad2ea7b6e863 100644 --- a/airflow-core/tests/unit/listeners/test_listeners.py +++ b/airflow-core/tests/unit/listeners/test_listeners.py @@ -25,7 +25,6 @@ from airflow._shared.timezones import timezone from airflow.exceptions import AirflowException from airflow.jobs.job import Job, run_job -from airflow.listeners.listener import get_listener_manager from airflow.providers.standard.operators.bash import BashOperator from airflow.utils.session import provide_session from airflow.utils.state import DagRunState, TaskInstanceState @@ -58,20 +57,16 @@ @pytest.fixture(autouse=True) -def clean_listener_manager(): - lm = get_listener_manager() - lm.clear() +def clean_listener_state(): + """Clear listener state after each test.""" yield - lm = get_listener_manager() - lm.clear() for listener in LISTENERS: listener.clear() @provide_session -def test_listener_gets_calls(create_task_instance, session): - lm = get_listener_manager() - lm.add_listener(full_listener) +def test_listener_gets_calls(create_task_instance, session, listener_manager): + listener_manager(full_listener) ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED) # Using ti.run() instead of ti._run_raw_task() to capture state change to RUNNING @@ -84,12 +79,11 @@ def test_listener_gets_calls(create_task_instance, session): @provide_session -def test_multiple_listeners(create_task_instance, session): - lm = get_listener_manager() - lm.add_listener(full_listener) - lm.add_listener(lifecycle_listener) +def test_multiple_listeners(create_task_instance, session, listener_manager): + listener_manager(full_listener) + listener_manager(lifecycle_listener) class_based_listener = class_listener.ClassBasedListener() - lm.add_listener(class_based_listener) + listener_manager(class_based_listener) job = Job() job_runner = MockJobRunner(job=job) @@ -105,9 +99,8 @@ def test_multiple_listeners(create_task_instance, session): @provide_session -def test_listener_gets_only_subscribed_calls(create_task_instance, session): - lm = get_listener_manager() - lm.add_listener(partial_listener) +def test_listener_gets_only_subscribed_calls(create_task_instance, session, listener_manager): + listener_manager(partial_listener) ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED) # Using ti.run() instead of ti._run_raw_task() to capture state change to RUNNING @@ -120,9 +113,8 @@ def test_listener_gets_only_subscribed_calls(create_task_instance, session): @provide_session -def test_listener_suppresses_exceptions(create_task_instance, session, cap_structlog): - lm = get_listener_manager() - lm.add_listener(throwing_listener) +def test_listener_suppresses_exceptions(create_task_instance, session, cap_structlog, listener_manager): + listener_manager(throwing_listener) ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED) ti.run() @@ -130,9 +122,8 @@ def test_listener_suppresses_exceptions(create_task_instance, session, cap_struc @provide_session -def test_listener_captures_failed_taskinstances(create_task_instance_of_operator, session): - lm = get_listener_manager() - lm.add_listener(full_listener) +def test_listener_captures_failed_taskinstances(create_task_instance_of_operator, session, listener_manager): + listener_manager(full_listener) ti = create_task_instance_of_operator( BashOperator, dag_id=DAG_ID, logical_date=LOGICAL_DATE, task_id=TASK_ID, bash_command="exit 1" @@ -145,9 +136,10 @@ def test_listener_captures_failed_taskinstances(create_task_instance_of_operator @provide_session -def test_listener_captures_longrunning_taskinstances(create_task_instance_of_operator, session): - lm = get_listener_manager() - lm.add_listener(full_listener) +def test_listener_captures_longrunning_taskinstances( + create_task_instance_of_operator, session, listener_manager +): + listener_manager(full_listener) ti = create_task_instance_of_operator( BashOperator, dag_id=DAG_ID, logical_date=LOGICAL_DATE, task_id=TASK_ID, bash_command="sleep 5" @@ -159,10 +151,9 @@ def test_listener_captures_longrunning_taskinstances(create_task_instance_of_ope @provide_session -def test_class_based_listener(create_task_instance, session): - lm = get_listener_manager() +def test_class_based_listener(create_task_instance, session, listener_manager): listener = class_listener.ClassBasedListener() - lm.add_listener(listener) + listener_manager(listener) ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED) ti.run() @@ -170,16 +161,15 @@ def test_class_based_listener(create_task_instance, session): assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS, DagRunState.SUCCESS] -def test_listener_logs_call(caplog, create_task_instance, session): - caplog.set_level(logging.DEBUG, logger="airflow.listeners.listener") - lm = get_listener_manager() - lm.add_listener(full_listener) +def test_listener_logs_call(caplog, create_task_instance, session, listener_manager): + caplog.set_level(logging.DEBUG, logger="airflow.sdk._shared.listeners.listener") + listener_manager(full_listener) ti = create_task_instance(session=session, state=TaskInstanceState.QUEUED) ti.run() - listener_logs = [r for r in caplog.record_tuples if r[0] == "airflow.listeners.listener"] - assert all(r[:-1] == ("airflow.listeners.listener", logging.DEBUG) for r in listener_logs) + listener_logs = [r for r in caplog.record_tuples if r[0] == "airflow.sdk._shared.listeners.listener"] + assert all(r[:-1] == ("airflow.sdk._shared.listeners.listener", logging.DEBUG) for r in listener_logs) assert listener_logs[0][-1].startswith("Calling 'on_task_instance_running' with {'") assert listener_logs[1][-1].startswith("Hook impls: [=1.5.0", + "structlog>=25.4.0", +] + +[dependency-groups] +dev = [ + "apache-airflow-devel-common", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/airflow_shared"] + +[tool.ruff] +extend = "../../pyproject.toml" +src = ["src"] + +[tool.ruff.lint.per-file-ignores] +# Ignore Doc rules et al for anything outside of tests +"!src/*" = ["D", "S101", "TRY002"] + +[tool.ruff.lint.flake8-tidy-imports] +# Override the workspace level default +ban-relative-imports = "parents" diff --git a/shared/listeners/src/airflow_shared/listeners/__init__.py b/shared/listeners/src/airflow_shared/listeners/__init__.py new file mode 100644 index 0000000000000..87840b50e2fa5 --- /dev/null +++ b/shared/listeners/src/airflow_shared/listeners/__init__.py @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from pluggy import HookimplMarker + +hookimpl = HookimplMarker("airflow") diff --git a/shared/listeners/src/airflow_shared/listeners/listener.py b/shared/listeners/src/airflow_shared/listeners/listener.py new file mode 100644 index 0000000000000..d4b36c059d480 --- /dev/null +++ b/shared/listeners/src/airflow_shared/listeners/listener.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pluggy +import structlog + +if TYPE_CHECKING: + from pluggy._hooks import _HookRelay + +log = structlog.get_logger(__name__) + + +def _before_hookcall(hook_name, hook_impls, kwargs): + log.debug("Calling %r with %r", hook_name, kwargs) + log.debug("Hook impls: %s", hook_impls) + + +def _after_hookcall(outcome, hook_name, hook_impls, kwargs): + log.debug("Result from %r: %s", hook_name, outcome.get_result()) + + +class ListenerManager: + """ + Manage listener registration and provides hook property for calling them. + + This class provides base infra for listener system. The consumers / components + wanting to register listeners should initialise its own ListenerManager and + register the hook specs relevant to that component using add_hookspecs. + """ + + def __init__(self): + self.pm = pluggy.PluginManager("airflow") + self.pm.add_hookcall_monitoring(_before_hookcall, _after_hookcall) + + def add_hookspecs(self, spec_module) -> None: + """ + Register hook specs from a module. + + :param spec_module: A module containing functions decorated with @hookspec. + """ + self.pm.add_hookspecs(spec_module) + + @property + def has_listeners(self) -> bool: + return bool(self.pm.get_plugins()) + + @property + def hook(self) -> _HookRelay: + """Return hook, on which plugin methods specified in spec can be called.""" + return self.pm.hook + + def add_listener(self, listener): + if self.pm.is_registered(listener): + return + self.pm.register(listener) + + def clear(self): + """Remove registered plugins.""" + for plugin in self.pm.get_plugins(): + self.pm.unregister(plugin) diff --git a/shared/listeners/src/airflow_shared/listeners/spec/__init__.py b/shared/listeners/src/airflow_shared/listeners/spec/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/shared/listeners/src/airflow_shared/listeners/spec/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow-core/src/airflow/listeners/spec/lifecycle.py b/shared/listeners/src/airflow_shared/listeners/spec/lifecycle.py similarity index 100% rename from airflow-core/src/airflow/listeners/spec/lifecycle.py rename to shared/listeners/src/airflow_shared/listeners/spec/lifecycle.py diff --git a/airflow-core/src/airflow/listeners/spec/taskinstance.py b/shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py similarity index 90% rename from airflow-core/src/airflow/listeners/spec/taskinstance.py rename to shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py index 75b98e8a7b505..d3450d6b05aa7 100644 --- a/airflow-core/src/airflow/listeners/spec/taskinstance.py +++ b/shared/listeners/src/airflow_shared/listeners/spec/taskinstance.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from __future__ import annotations from typing import TYPE_CHECKING @@ -22,6 +23,7 @@ from pluggy import HookspecMarker if TYPE_CHECKING: + # These imports are for type checking only - no runtime dependency from airflow.models.taskinstance import TaskInstance from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance from airflow.utils.state import TaskInstanceState @@ -30,13 +32,17 @@ @hookspec -def on_task_instance_running(previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance): +def on_task_instance_running( + previous_state: TaskInstanceState | None, + task_instance: RuntimeTaskInstance | TaskInstance, +): """Execute when task state changes to RUNNING. previous_state can be None.""" @hookspec def on_task_instance_success( - previous_state: TaskInstanceState | None, task_instance: RuntimeTaskInstance | TaskInstance + previous_state: TaskInstanceState | None, + task_instance: RuntimeTaskInstance | TaskInstance, ): """Execute when task state changes to SUCCESS. previous_state can be None.""" diff --git a/shared/listeners/tests/conftest.py b/shared/listeners/tests/conftest.py new file mode 100644 index 0000000000000..8b61b1b99b90d --- /dev/null +++ b/shared/listeners/tests/conftest.py @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os + +os.environ["_AIRFLOW__AS_LIBRARY"] = "true" diff --git a/shared/listeners/tests/listeners/__init__.py b/shared/listeners/tests/listeners/__init__.py new file mode 100644 index 0000000000000..03cb33c14c40e --- /dev/null +++ b/shared/listeners/tests/listeners/__init__.py @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations diff --git a/shared/listeners/tests/listeners/test_listener_manager.py b/shared/listeners/tests/listeners/test_listener_manager.py new file mode 100644 index 0000000000000..ebf360dade08d --- /dev/null +++ b/shared/listeners/tests/listeners/test_listener_manager.py @@ -0,0 +1,164 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow_shared.listeners import hookimpl +from airflow_shared.listeners.listener import ListenerManager +from airflow_shared.listeners.spec import lifecycle, taskinstance + + +class TestListenerManager: + def test_initial_state_has_no_listeners(self): + """Test that a new ListenerManager has no listeners.""" + lm = ListenerManager() + assert not lm.has_listeners + assert len(lm.pm.get_plugins()) == 0 + + def test_add_hookspecs_registers_hooks(self): + """Test that add_hookspecs makes hooks available.""" + lm = ListenerManager() + lm.add_hookspecs(lifecycle) + + # Verify lifecycle hooks are now available + assert hasattr(lm.hook, "on_starting") + assert hasattr(lm.hook, "before_stopping") + + def test_add_multiple_hookspecs(self): + """Test that multiple hookspecs can be registered.""" + lm = ListenerManager() + lm.add_hookspecs(lifecycle) + lm.add_hookspecs(taskinstance) + + # Verify hooks from both specs are available + assert hasattr(lm.hook, "on_starting") + assert hasattr(lm.hook, "on_task_instance_running") + + def test_add_listener(self): + """Test listener registration.""" + + class TestListener: + def __init__(self): + self.called = False + + @hookimpl + def on_starting(self, component): + self.called = True + + lm = ListenerManager() + lm.add_hookspecs(lifecycle) + listener = TestListener() + lm.add_listener(listener) + + assert lm.has_listeners + assert lm.pm.is_registered(listener) + + def test_duplicate_listener_registration(self): + """Test adding same listener twice doesn't duplicate.""" + + class TestListener: + @hookimpl + def on_starting(self, component): + pass + + lm = ListenerManager() + lm.add_hookspecs(lifecycle) + listener = TestListener() + lm.add_listener(listener) + lm.add_listener(listener) + + # Should only be registered once + assert len(lm.pm.get_plugins()) == 1 + + def test_clear_listeners(self): + """Test clearing listeners removes all registered listeners.""" + + class TestListener: + @hookimpl + def on_starting(self, component): + pass + + lm = ListenerManager() + lm.add_hookspecs(lifecycle) + listener1 = TestListener() + listener2 = TestListener() + lm.add_listener(listener1) + lm.add_listener(listener2) + + assert lm.has_listeners + assert len(lm.pm.get_plugins()) == 2 + + lm.clear() + + assert not lm.has_listeners + assert len(lm.pm.get_plugins()) == 0 + + def test_hook_calling(self): + """Test hooks can be called and listeners receive them.""" + + class TestListener: + def __init__(self): + self.component_received = None + + @hookimpl + def on_starting(self, component): + self.component_received = component + + lm = ListenerManager() + lm.add_hookspecs(lifecycle) + listener = TestListener() + lm.add_listener(listener) + + test_component = "test_component" + lm.hook.on_starting(component=test_component) + + assert listener.component_received == test_component + + def test_taskinstance_hooks(self): + """Test taskinstance hook specs work correctly.""" + + class TaskInstanceListener: + def __init__(self): + self.events = [] + + @hookimpl + def on_task_instance_running(self, previous_state, task_instance): + self.events.append(("running", task_instance)) + + @hookimpl + def on_task_instance_success(self, previous_state, task_instance): + self.events.append(("success", task_instance)) + + @hookimpl + def on_task_instance_failed(self, previous_state, task_instance, error): + self.events.append(("failed", task_instance, error)) + + lm = ListenerManager() + lm.add_hookspecs(taskinstance) + listener = TaskInstanceListener() + lm.add_listener(listener) + + mock_ti = "mock_task_instance" + lm.hook.on_task_instance_running(previous_state=None, task_instance=mock_ti) + lm.hook.on_task_instance_success(previous_state=None, task_instance=mock_ti) + lm.hook.on_task_instance_failed(previous_state=None, task_instance=mock_ti, error="test error") + + assert listener.events == [ + ("running", mock_ti), + ("success", mock_ti), + ("failed", mock_ti, "test error"), + ] diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index e1ece8f283110..fc989724391d8 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -77,6 +77,9 @@ dependencies = [ "packaging>=24.0", "typing-extensions>=4.14.1", # End of shared configuration dependencies + # Start of shared listeners dependencies + "pluggy>=1.5.0", + # End of shared listeners dependencies # Start of shared module-loading dependencies 'importlib_metadata>=6.5;python_version<"3.12"', "pathspec>=0.9.0", @@ -123,10 +126,11 @@ path = "src/airflow/sdk/__init__.py" "../shared/dagnode/src/airflow_shared/dagnode" = "src/airflow/sdk/_shared/dagnode" "../shared/logging/src/airflow_shared/logging" = "src/airflow/sdk/_shared/logging" "../shared/module_loading/src/airflow_shared/module_loading" = "src/airflow/sdk/_shared/module_loading" -"../shared/observability/src/airflow_shared/observability" = "src/airflow/_shared/observability" +"../shared/observability/src/airflow_shared/observability" = "src/airflow/sdk/_shared/observability" "../shared/secrets_backend/src/airflow_shared/secrets_backend" = "src/airflow/sdk/_shared/secrets_backend" "../shared/secrets_masker/src/airflow_shared/secrets_masker" = "src/airflow/sdk/_shared/secrets_masker" "../shared/timezones/src/airflow_shared/timezones" = "src/airflow/sdk/_shared/timezones" +"../shared/listeners/src/airflow_shared/listeners" = "src/airflow/sdk/_shared/listeners" "../shared/plugins_manager/src/airflow_shared/plugins_manager" = "src/airflow/sdk/_shared/plugins_manager" [tool.hatch.build.targets.wheel] @@ -271,6 +275,7 @@ tmp_path_retention_policy = "failed" shared_distributions = [ "apache-airflow-shared-configuration", "apache-airflow-shared-dagnode", + "apache-airflow-shared-listeners", "apache-airflow-shared-logging", "apache-airflow-shared-module-loading", "apache-airflow-shared-secrets-backend", diff --git a/task-sdk/src/airflow/sdk/_shared/listeners b/task-sdk/src/airflow/sdk/_shared/listeners new file mode 120000 index 0000000000000..fa2743732061f --- /dev/null +++ b/task-sdk/src/airflow/sdk/_shared/listeners @@ -0,0 +1 @@ +../../../../../shared/listeners/src/airflow_shared/listeners \ No newline at end of file diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 814673d93fe27..cb76097b9fefc 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -40,7 +40,6 @@ from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager -from airflow.listeners.listener import get_listener_manager from airflow.sdk.api.client import get_hostname, getuser from airflow.sdk.api.datamodels._generated import ( AssetProfile, @@ -118,6 +117,7 @@ ) from airflow.sdk.execution_time.sentry import Sentry from airflow.sdk.execution_time.xcom import XCom +from airflow.sdk.listener import get_listener_manager from airflow.sdk.observability.stats import Stats from airflow.sdk.timezone import coerce_datetime from airflow.triggers.base import BaseEventTrigger diff --git a/task-sdk/src/airflow/sdk/listener.py b/task-sdk/src/airflow/sdk/listener.py new file mode 100644 index 0000000000000..62c36753ce30a --- /dev/null +++ b/task-sdk/src/airflow/sdk/listener.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from functools import cache + +from airflow.sdk._shared.listeners.listener import ListenerManager +from airflow.sdk._shared.listeners.spec import lifecycle, taskinstance +from airflow.sdk.plugins_manager import integrate_listener_plugins + + +@cache +def get_listener_manager() -> ListenerManager: + """ + Get a listener manager for task sdk. + + Registers the following listeners: + - lifecycle: on_starting, before_stopping + - taskinstance: on_task_instance_running, on_task_instance_success, etc. + """ + _listener_manager = ListenerManager() + + _listener_manager.add_hookspecs(lifecycle) + _listener_manager.add_hookspecs(taskinstance) + + integrate_listener_plugins(_listener_manager) # type: ignore[arg-type] + return _listener_manager + + +__all__ = ["get_listener_manager", "ListenerManager"] diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py index eadb2e7d59411..c1ef3b72c92f4 100644 --- a/task-sdk/tests/conftest.py +++ b/task-sdk/tests/conftest.py @@ -164,14 +164,14 @@ def _disable_ol_plugin(): # 3.12+ issues a warning when os.fork happens. So for this plugin we disable it # And we load plugins when setting the priority_weight field - import airflow.plugins_manager + import airflow.sdk.plugins_manager - old = airflow.plugins_manager._get_plugins - airflow.plugins_manager._get_plugins = lambda: ([], {}) + old = airflow.sdk.plugins_manager._get_plugins + airflow.sdk.plugins_manager._get_plugins = lambda: ([], {}) yield - airflow.plugins_manager._get_plugins = old + airflow.sdk.plugins_manager._get_plugins = old @pytest.fixture(autouse=True) diff --git a/task-sdk/tests/task_sdk/docs/test_public_api.py b/task-sdk/tests/task_sdk/docs/test_public_api.py index f02b9ccebf91f..e7f653d76a263 100644 --- a/task-sdk/tests/task_sdk/docs/test_public_api.py +++ b/task-sdk/tests/task_sdk/docs/test_public_api.py @@ -60,6 +60,7 @@ def test_airflow_sdk_no_unexpected_exports(): "serde", "observability", "plugins_manager", + "listener", } unexpected = actual - public - ignore assert not unexpected, f"Unexpected exports in airflow.sdk: {sorted(unexpected)}" diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 4a5c54ba3a467..e6a179f7189c5 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -35,8 +35,6 @@ from task_sdk import FAKE_BUNDLE from uuid6 import uuid7 -from airflow.listeners import hookimpl -from airflow.listeners.listener import get_listener_manager from airflow.providers.standard.operators.python import PythonOperator from airflow.sdk import ( DAG, @@ -48,6 +46,7 @@ task as task_decorator, timezone, ) +from airflow.sdk._shared.listeners import hookimpl from airflow.sdk.api.datamodels._generated import ( AssetProfile, AssetResponse, @@ -459,9 +458,9 @@ def test_defer_task_queue_assignment( ) -def test_run_downstream_skipped(mocked_parse, create_runtime_ti, mock_supervisor_comms): +def test_run_downstream_skipped(mocked_parse, create_runtime_ti, mock_supervisor_comms, listener_manager): listener = TestTaskRunnerCallsListeners.CustomListener() - get_listener_manager().add_listener(listener) + listener_manager(listener) class CustomOperator(BaseOperator): def execute(self, context): @@ -3269,19 +3268,11 @@ def on_task_instance_failed(self, previous_state, task_instance, error): self._add_outlet_events(context) self.error = error - @pytest.fixture(autouse=True) - def clean_listener_manager(self): - lm = get_listener_manager() - lm.clear() - yield - lm = get_listener_manager() - lm.clear() - def test_task_runner_calls_on_startup_before_stopping( - self, make_ti_context, mocked_parse, mock_supervisor_comms + self, make_ti_context, mocked_parse, mock_supervisor_comms, listener_manager ): listener = self.CustomListener() - get_listener_manager().add_listener(listener) + listener_manager(listener) class CustomOperator(BaseOperator): def execute(self, context): @@ -3318,9 +3309,9 @@ def execute(self, context): finalize(runtime_ti, state, context, log) assert isinstance(listener.component, TaskRunnerMarker) - def test_task_runner_calls_listeners_success(self, mocked_parse, mock_supervisor_comms): + def test_task_runner_calls_listeners_success(self, mocked_parse, mock_supervisor_comms, listener_manager): listener = self.CustomListener() - get_listener_manager().add_listener(listener) + listener_manager(listener) class CustomOperator(BaseOperator): def execute(self, context): @@ -3357,9 +3348,11 @@ def execute(self, context): AirflowException("oops"), ], ) - def test_task_runner_calls_listeners_failed(self, mocked_parse, mock_supervisor_comms, exception): + def test_task_runner_calls_listeners_failed( + self, mocked_parse, mock_supervisor_comms, exception, listener_manager + ): listener = self.CustomListener() - get_listener_manager().add_listener(listener) + listener_manager(listener) class CustomOperator(BaseOperator): def execute(self, context): @@ -3389,9 +3382,9 @@ def execute(self, context): assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.FAILED] assert listener.error == error - def test_task_runner_calls_listeners_skipped(self, mocked_parse, mock_supervisor_comms): + def test_task_runner_calls_listeners_skipped(self, mocked_parse, mock_supervisor_comms, listener_manager): listener = self.CustomListener() - get_listener_manager().add_listener(listener) + listener_manager(listener) class CustomOperator(BaseOperator): def execute(self, context): @@ -3420,10 +3413,12 @@ def execute(self, context): assert listener.state == [TaskInstanceState.RUNNING, TaskInstanceState.SKIPPED] - def test_listener_access_outlet_event_on_running_and_success(self, mocked_parse, mock_supervisor_comms): + def test_listener_access_outlet_event_on_running_and_success( + self, mocked_parse, mock_supervisor_comms, listener_manager + ): """Test listener can access outlet events through invoking get_template_context() while task running and success""" listener = self.CustomOutletEventsListener() - get_listener_manager().add_listener(listener) + listener_manager(listener) test_asset = Asset("test-asset") test_key = AssetUniqueKey(name="test-asset", uri="test-asset") @@ -3480,10 +3475,12 @@ def execute(self, context): ], ids=["ValueError", "SystemExit", "AirflowException"], ) - def test_listener_access_outlet_event_on_failed(self, mocked_parse, mock_supervisor_comms, exception): + def test_listener_access_outlet_event_on_failed( + self, mocked_parse, mock_supervisor_comms, exception, listener_manager + ): """Test listener can access outlet events through invoking get_template_context() while task failed""" listener = self.CustomOutletEventsListener() - get_listener_manager().add_listener(listener) + listener_manager(listener) test_asset = Asset("test-asset") test_key = AssetUniqueKey(name="test-asset", uri="test-asset")