diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 7b5c8e7c4f8c4..0c3d119be1983 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -34,6 +34,7 @@ import warnings from datetime import datetime, timedelta from functools import total_ordering, wraps +from threading import local from types import FunctionType from typing import ( TYPE_CHECKING, @@ -392,6 +393,8 @@ class ExecutorSafeguard: """ test_mode = conf.getboolean("core", "unit_test_mode") + _sentinel = local() + _sentinel.callers = {} @classmethod def decorator(cls, func): @@ -399,7 +402,13 @@ def decorator(cls, func): def wrapper(self, *args, **kwargs): from airflow.decorators.base import DecoratedOperator - sentinel = kwargs.pop(f"{self.__class__.__name__}__sentinel", None) + sentinel_key = f"{self.__class__.__name__}__sentinel" + sentinel = kwargs.pop(sentinel_key, None) + + if sentinel: + cls._sentinel.callers[sentinel_key] = sentinel + else: + sentinel = cls._sentinel.callers.pop(f"{func.__qualname__.split('.')[0]}__sentinel", None) if not cls.test_mode and not sentinel == _sentinel and not isinstance(self, DecoratedOperator): message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside TaskInstance!" diff --git a/tests/models/test_baseoperatormeta.py b/tests/models/test_baseoperatormeta.py index 6c6567b23899e..5244e86b2c386 100644 --- a/tests/models/test_baseoperatormeta.py +++ b/tests/models/test_baseoperatormeta.py @@ -40,6 +40,11 @@ def execute(self, context: Context) -> Any: return f"Hello {self.owner}!" +class ExtendedHelloWorldOperator(HelloWorldOperator): + def execute(self, context: Context) -> Any: + return super().execute(context) + + class TestExecutorSafeguard: def setup_method(self): ExecutorSafeguard.test_mode = False @@ -49,12 +54,29 @@ def teardown_method(self, method): @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @pytest.mark.db_test - def test_executor_when_classic_operator_called_from_dag(self, dag_maker): + @patch.object(HelloWorldOperator, "log") + def test_executor_when_classic_operator_called_from_dag(self, mock_log, dag_maker): with dag_maker() as dag: HelloWorldOperator(task_id="hello_operator") dag_run = dag.test() assert dag_run.state == DagRunState.SUCCESS + mock_log.warning.assert_not_called() + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + @pytest.mark.db_test + @patch.object(HelloWorldOperator, "log") + def test_executor_when_extended_classic_operator_called_from_dag( + self, + mock_log, + dag_maker, + ): + with dag_maker() as dag: + ExtendedHelloWorldOperator(task_id="hello_operator") + + dag_run = dag.test() + assert dag_run.state == DagRunState.SUCCESS + mock_log.warning.assert_not_called() @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @pytest.mark.parametrize(