Skip to content

Commit

Permalink
openlineage: adjust default extractor's on_failure detection for airflow 2.10 fix
Browse files Browse the repository at this point in the history

Signed-off-by: Kacper Muda <mudakacper@gmail.com>
  • Loading branch information
kacpermuda committed Aug 1, 2024
1 parent b28e8bf commit 79923a3
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 9 deletions.
10 changes: 9 additions & 1 deletion airflow/providers/openlineage/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from openlineage.client.facet import BaseFacet as BaseFacet_V1
from openlineage.client.facet_v2 import JobFacet, RunFacet

from airflow.providers.openlineage.utils.utils import IS_AIRFLOW_2_10_OR_HIGHER
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import TaskInstanceState

Expand Down Expand Up @@ -115,7 +116,14 @@ def _execute_extraction(self) -> OperatorLineage | None:
return None

def extract_on_complete(self, task_instance) -> OperatorLineage | None:
if task_instance.state == TaskInstanceState.FAILED:
failed_states = [TaskInstanceState.FAILED, TaskInstanceState.UP_FOR_RETRY]
if not IS_AIRFLOW_2_10_OR_HIGHER: # todo: remove when min airflow version >= 2.10.0
# Before the fix (#41053) implemented in Airflow 2.10, the TaskInstance's state was still RUNNING
# when passed to the listener's on_failure method. Since `extract_on_complete()` is only called
# after task completion, a RUNNING state means we are dealing with a FAILED task in Airflow < 2.10.
failed_states = [TaskInstanceState.RUNNING]

if task_instance.state in failed_states:
on_failed = getattr(self.operator, "get_openlineage_facets_on_failure", None)
if on_failed and callable(on_failed):
self.log.debug(
Expand Down
9 changes: 4 additions & 5 deletions airflow/providers/openlineage/plugins/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@

import psutil
from openlineage.client.serde import Serde
from packaging.version import Version
from setproctitle import getproctitle, setproctitle

from airflow import __version__ as AIRFLOW_VERSION, settings
from airflow import settings
from airflow.listeners import hookimpl
from airflow.providers.openlineage import conf
from airflow.providers.openlineage.extractors import ExtractorManager
from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, RunState
from airflow.providers.openlineage.utils.utils import (
IS_AIRFLOW_2_10_OR_HIGHER,
get_airflow_job_facet,
get_airflow_mapped_task_facet,
get_airflow_run_facet,
Expand All @@ -53,12 +53,11 @@
from airflow.models import DagRun, TaskInstance

_openlineage_listener: OpenLineageListener | None = None
_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")


def _get_try_number_success(val):
# todo: remove when min airflow version >= 2.10.0
if _IS_AIRFLOW_2_10_OR_HIGHER:
if IS_AIRFLOW_2_10_OR_HIGHER:
return val.try_number
return val.try_number - 1

Expand Down Expand Up @@ -247,7 +246,7 @@ def on_success():

self._execute(on_success, "on_success", use_fork=True)

if _IS_AIRFLOW_2_10_OR_HIGHER:
if IS_AIRFLOW_2_10_OR_HIGHER:

@hookimpl
def on_task_instance_failed(
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/openlineage/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

log = logging.getLogger(__name__)
_NOMINAL_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")
IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")


def try_import_from_string(string: str) -> Any:
Expand Down Expand Up @@ -663,7 +663,7 @@ def normalize_sql(sql: str | Iterable[str]):

def should_use_external_connection(hook) -> bool:
# If we're at Airflow 2.10, the execution is process-isolated, so we can safely run those again.
if not _IS_AIRFLOW_2_10_OR_HIGHER:
if not IS_AIRFLOW_2_10_OR_HIGHER:
return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook", "RedshiftSQLHook"]
return True

Expand Down
2 changes: 1 addition & 1 deletion tests/providers/amazon/aws/operators/test_redshift_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class TestRedshiftSQLOpenLineage:
"airflow.providers.amazon.aws.hooks.redshift_sql._IS_AIRFLOW_2_10_OR_HIGHER",
new_callable=PropertyMock,
)
@patch("airflow.providers.openlineage.utils.utils._IS_AIRFLOW_2_10_OR_HIGHER", new_callable=PropertyMock)
@patch("airflow.providers.openlineage.utils.utils.IS_AIRFLOW_2_10_OR_HIGHER", new_callable=PropertyMock)
@patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
def test_execute_openlineage_events(
self,
Expand Down
42 changes: 42 additions & 0 deletions tests/providers/openlineage/extractors/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from openlineage.client.facet_v2 import BaseFacet, JobFacet, parent_run, sql_job

from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstance import TaskInstanceState
from airflow.operators.python import PythonOperator
from airflow.providers.openlineage.extractors.base import (
BaseExtractor,
Expand Down Expand Up @@ -233,6 +234,47 @@ def test_extraction_without_on_start():
)


@pytest.mark.parametrize(
    "task_state, is_airflow_2_10_or_higher, should_call_on_failure",
    (
        # Airflow >= 2.10
        (TaskInstanceState.FAILED, True, True),
        (TaskInstanceState.UP_FOR_RETRY, True, True),
        (TaskInstanceState.RUNNING, True, False),
        (TaskInstanceState.SUCCESS, True, False),
        # Airflow < 2.10
        (TaskInstanceState.RUNNING, False, True),
        (TaskInstanceState.SUCCESS, False, False),
        (TaskInstanceState.FAILED, False, False),  # should never happen, fixed in #41053
        (TaskInstanceState.UP_FOR_RETRY, False, False),  # should never happen, fixed in #41053
    ),
)
def test_extract_on_failure(task_state, is_airflow_2_10_or_higher, should_call_on_failure):
    """Check which operator hook ``extract_on_complete`` dispatches to.

    The version flag imported by the extractors module is patched to
    ``is_airflow_2_10_or_higher``. A mocked task instance in ``task_state`` is
    then passed to ``DefaultExtractor.extract_on_complete``. Depending on
    ``should_call_on_failure``, exactly one of the operator's
    ``get_openlineage_facets_on_failure`` / ``get_openlineage_facets_on_complete``
    mocks must have been called, and the result must match that mock's return value.
    """
    # Operator double exposing both lineage hooks with distinguishable return values.
    operator = mock.Mock()
    operator.get_openlineage_facets_on_failure = mock.Mock(
        return_value=OperatorLineage(run_facets={"failed": True}),
    )
    operator.get_openlineage_facets_on_complete = mock.Mock(return_value=None)

    task_instance = mock.Mock(state=task_state)
    extractor = DefaultExtractor(operator=operator)

    # Patch the name as looked up by the extractor module, not where it is defined.
    version_flag = "airflow.providers.openlineage.extractors.base.IS_AIRFLOW_2_10_OR_HIGHER"
    with mock.patch(version_flag, is_airflow_2_10_or_higher):
        result = extractor.extract_on_complete(task_instance)

    on_failure_hook = operator.get_openlineage_facets_on_failure
    on_complete_hook = operator.get_openlineage_facets_on_complete

    if not should_call_on_failure:
        # Regular completion path: only the on_complete hook may run.
        on_failure_hook.assert_not_called()
        on_complete_hook.assert_called_once_with(task_instance)
        assert result is None
        return

    # Failure path: only the on_failure hook may run and its lineage is returned.
    on_failure_hook.assert_called_once_with(task_instance)
    on_complete_hook.assert_not_called()
    assert isinstance(result, OperatorLineage)
    assert result.run_facets == {"failed": True}


@mock.patch("airflow.providers.openlineage.conf.custom_extractors")
def test_extractors_env_var(custom_extractors):
custom_extractors.return_value = {"tests.providers.openlineage.extractors.test_base.ExampleExtractor"}
Expand Down

0 comments on commit 79923a3

Please sign in to comment.