Skip to content

Commit

Permalink
FIX: Don't raise a warning in ExecutorSafeguard when execute is calle…
Browse files Browse the repository at this point in the history
…d from an extended operator (apache#42849)

* refactor: Don't raise a warning when execute is called from an extended operator, as this should always be allowed.

* refactored: Fixed import of test_utils in test_dag_run

---------

Co-authored-by: David Blain <david.blain@infrabel.be>
  • Loading branch information
2 people authored and PaulKobow7536 committed Oct 24, 2024
1 parent 2822335 commit feb8115
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
11 changes: 10 additions & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import warnings
from datetime import datetime, timedelta
from functools import total_ordering, wraps
from threading import local
from types import FunctionType
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -392,14 +393,22 @@ class ExecutorSafeguard:
"""

test_mode = conf.getboolean("core", "unit_test_mode")
_sentinel = local()
_sentinel.callers = {}

@classmethod
def decorator(cls, func):
@wraps(func)
def wrapper(self, *args, **kwargs):
from airflow.decorators.base import DecoratedOperator

sentinel = kwargs.pop(f"{self.__class__.__name__}__sentinel", None)
sentinel_key = f"{self.__class__.__name__}__sentinel"
sentinel = kwargs.pop(sentinel_key, None)

if sentinel:
cls._sentinel.callers[sentinel_key] = sentinel
else:
sentinel = cls._sentinel.callers.pop(f"{func.__qualname__.split('.')[0]}__sentinel", None)

if not cls.test_mode and not sentinel == _sentinel and not isinstance(self, DecoratedOperator):
message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside TaskInstance!"
Expand Down
24 changes: 23 additions & 1 deletion tests/models/test_baseoperatormeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def execute(self, context: Context) -> Any:
return f"Hello {self.owner}!"


class ExtendedHelloWorldOperator(HelloWorldOperator):
def execute(self, context: Context) -> Any:
return super().execute(context)


class TestExecutorSafeguard:
def setup_method(self):
ExecutorSafeguard.test_mode = False
Expand All @@ -49,12 +54,29 @@ def teardown_method(self, method):

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
@pytest.mark.db_test
def test_executor_when_classic_operator_called_from_dag(self, dag_maker):
@patch.object(HelloWorldOperator, "log")
def test_executor_when_classic_operator_called_from_dag(self, mock_log, dag_maker):
with dag_maker() as dag:
HelloWorldOperator(task_id="hello_operator")

dag_run = dag.test()
assert dag_run.state == DagRunState.SUCCESS
mock_log.warning.assert_not_called()

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
@pytest.mark.db_test
@patch.object(HelloWorldOperator, "log")
def test_executor_when_extended_classic_operator_called_from_dag(
self,
mock_log,
dag_maker,
):
with dag_maker() as dag:
ExtendedHelloWorldOperator(task_id="hello_operator")

dag_run = dag.test()
assert dag_run.state == DagRunState.SUCCESS
mock_log.warning.assert_not_called()

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
@pytest.mark.parametrize(
Expand Down

0 comments on commit feb8115

Please sign in to comment.