Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

openlineage: adjust DefaultExtractor's on_failure detection for airflow 2.10 fix #41094

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion airflow/providers/openlineage/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from openlineage.client.facet import BaseFacet as BaseFacet_V1
from openlineage.client.facet_v2 import JobFacet, RunFacet

from airflow.providers.openlineage.utils.utils import IS_AIRFLOW_2_10_OR_HIGHER
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import TaskInstanceState

Expand Down Expand Up @@ -115,7 +116,14 @@ def _execute_extraction(self) -> OperatorLineage | None:
return None

def extract_on_complete(self, task_instance) -> OperatorLineage | None:
if task_instance.state == TaskInstanceState.FAILED:
failed_states = [TaskInstanceState.FAILED, TaskInstanceState.UP_FOR_RETRY]
if not IS_AIRFLOW_2_10_OR_HIGHER: # todo: remove when min airflow version >= 2.10.0
# Before fix (#41053) implemented in Airflow 2.10 TaskInstance's state was still RUNNING when
# being passed to listener's on_failure method. Since `extract_on_complete()` is only called
# after task completion, RUNNING state means that we are dealing with FAILED task in < 2.10
failed_states = [TaskInstanceState.RUNNING]

if task_instance.state in failed_states:
on_failed = getattr(self.operator, "get_openlineage_facets_on_failure", None)
if on_failed and callable(on_failed):
self.log.debug(
Expand Down
9 changes: 4 additions & 5 deletions airflow/providers/openlineage/plugins/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@

import psutil
from openlineage.client.serde import Serde
from packaging.version import Version
from setproctitle import getproctitle, setproctitle

from airflow import __version__ as AIRFLOW_VERSION, settings
from airflow import settings
from airflow.listeners import hookimpl
from airflow.providers.openlineage import conf
from airflow.providers.openlineage.extractors import ExtractorManager
from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState
from airflow.providers.openlineage.utils.utils import (
IS_AIRFLOW_2_10_OR_HIGHER,
get_airflow_job_facet,
get_airflow_mapped_task_facet,
get_airflow_run_facet,
Expand All @@ -53,12 +53,11 @@
from airflow.models import DagRun, TaskInstance

_openlineage_listener: OpenLineageListener | None = None
_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")


def _get_try_number_success(val):
# todo: remove when min airflow version >= 2.10.0
if _IS_AIRFLOW_2_10_OR_HIGHER:
if IS_AIRFLOW_2_10_OR_HIGHER:
return val.try_number
return val.try_number - 1

Expand Down Expand Up @@ -247,7 +246,7 @@ def on_success():

self._execute(on_success, "on_success", use_fork=True)

if _IS_AIRFLOW_2_10_OR_HIGHER:
if IS_AIRFLOW_2_10_OR_HIGHER:

@hookimpl
def on_task_instance_failed(
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/openlineage/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

log = logging.getLogger(__name__)
_NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")


def try_import_from_string(string: str) -> Any:
Expand Down Expand Up @@ -663,7 +663,7 @@ def normalize_sql(sql: str | Iterable[str]):

def should_use_external_connection(hook) -> bool:
# If we're at Airflow 2.10, the execution is process-isolated, so we can safely run those again.
if not _IS_AIRFLOW_2_10_OR_HIGHER:
if not IS_AIRFLOW_2_10_OR_HIGHER:
return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook", "RedshiftSQLHook"]
return True

Expand Down
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/operators/test_redshift_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class TestRedshiftSQLOpenLineage:
"airflow.providers.amazon.aws.hooks.redshift_sql._IS_AIRFLOW_2_10_OR_HIGHER",
new_callable=PropertyMock,
)
@patch("airflow.providers.openlineage.utils.utils._IS_AIRFLOW_2_10_OR_HIGHER", new_callable=PropertyMock)
@patch("airflow.providers.openlineage.utils.utils.IS_AIRFLOW_2_10_OR_HIGHER", new_callable=PropertyMock)
@patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
def test_execute_openlineage_events(
self,
Expand Down
42 changes: 42 additions & 0 deletions tests/providers/openlineage/extractors/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from openlineage.client.facet_v2 import BaseFacet, JobFacet, parent_run, sql_job

from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstance import TaskInstanceState
from airflow.operators.python import PythonOperator
from airflow.providers.openlineage.extractors.base import (
BaseExtractor,
Expand Down Expand Up @@ -233,6 +234,47 @@ def test_extraction_without_on_start():
)


@pytest.mark.parametrize(
"task_state, is_airflow_2_10_or_higher, should_call_on_failure",
(
# Airflow >= 2.10
(TaskInstanceState.FAILED, True, True),
(TaskInstanceState.UP_FOR_RETRY, True, True),
(TaskInstanceState.RUNNING, True, False),
(TaskInstanceState.SUCCESS, True, False),
# Airflow < 2.10
(TaskInstanceState.RUNNING, False, True),
(TaskInstanceState.SUCCESS, False, False),
(TaskInstanceState.FAILED, False, False), # should never happen, fixed in #41053
(TaskInstanceState.UP_FOR_RETRY, False, False), # should never happen, fixed in #41053
),
)
def test_extract_on_failure(task_state, is_airflow_2_10_or_higher, should_call_on_failure):
task_instance = mock.Mock(state=task_state)
operator = mock.Mock()
operator.get_openlineage_facets_on_failure = mock.Mock(
return_value=OperatorLineage(run_facets={"failed": True})
)
operator.get_openlineage_facets_on_complete = mock.Mock(return_value=None)

extractor = DefaultExtractor(operator=operator)

with mock.patch(
"airflow.providers.openlineage.extractors.base.IS_AIRFLOW_2_10_OR_HIGHER", is_airflow_2_10_or_higher
):
result = extractor.extract_on_complete(task_instance)

if should_call_on_failure:
operator.get_openlineage_facets_on_failure.assert_called_once_with(task_instance)
operator.get_openlineage_facets_on_complete.assert_not_called()
assert isinstance(result, OperatorLineage)
assert result.run_facets == {"failed": True}
else:
operator.get_openlineage_facets_on_failure.assert_not_called()
operator.get_openlineage_facets_on_complete.assert_called_once_with(task_instance)
assert result is None


@mock.patch("airflow.providers.openlineage.conf.custom_extractors")
def test_extractors_env_var(custom_extractors):
custom_extractors.return_value = {"tests.providers.openlineage.extractors.test_base.ExampleExtractor"}
Expand Down