diff --git a/providers/openlineage/src/airflow/providers/openlineage/extractors/base.py b/providers/openlineage/src/airflow/providers/openlineage/extractors/base.py index 2b85825d8c6ee..b5f8a93f20daa 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/extractors/base.py +++ b/providers/openlineage/src/airflow/providers/openlineage/extractors/base.py @@ -29,14 +29,16 @@ from openlineage.client.facet import BaseFacet as BaseFacet_V1 from openlineage.client.facet_v2 import JobFacet, RunFacet -from airflow.providers.openlineage.utils.utils import AIRFLOW_V_2_10_PLUS from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.state import TaskInstanceState # this is not to break static checks compatibility with v1 OpenLineage facet classes DatasetSubclass = TypeVar("DatasetSubclass", bound=OLDataset) BaseFacetSubclass = TypeVar("BaseFacetSubclass", bound=Union[BaseFacet_V1, RunFacet, JobFacet]) +OL_METHOD_NAME_START = "get_openlineage_facets_on_start" +OL_METHOD_NAME_COMPLETE = "get_openlineage_facets_on_complete" +OL_METHOD_NAME_FAIL = "get_openlineage_facets_on_failure" + @define class OperatorLineage(Generic[DatasetSubclass, BaseFacetSubclass]): @@ -81,6 +83,9 @@ def extract(self) -> OperatorLineage | None: def extract_on_complete(self, task_instance) -> OperatorLineage | None: return self.extract() + def extract_on_failure(self, task_instance) -> OperatorLineage | None: + return self.extract() + class DefaultExtractor(BaseExtractor): """Extractor that uses `get_openlineage_facets_on_start/complete/failure` methods.""" @@ -96,46 +101,41 @@ def get_operator_classnames(cls) -> list[str]: return [] def _execute_extraction(self) -> OperatorLineage | None: - # OpenLineage methods are optional - if there's no method, return None - try: - self.log.debug( - "Trying to execute `get_openlineage_facets_on_start` for %s.", self.operator.task_type - ) - return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start) # type: ignore - except ImportError: - self.log.error( - "OpenLineage provider method failed to import OpenLineage integration. " - "This should not happen. Please report this bug to developers." - ) - return None - except AttributeError: + method = getattr(self.operator, OL_METHOD_NAME_START, None) + if callable(method): self.log.debug( - "Operator %s does not have the get_openlineage_facets_on_start method.", - self.operator.task_type, + "Trying to execute '%s' method of '%s'.", OL_METHOD_NAME_START, self.operator.task_type ) - return OperatorLineage() + return self._get_openlineage_facets(method) + self.log.debug( + "Operator '%s' does not have '%s' method.", self.operator.task_type, OL_METHOD_NAME_START + ) + return OperatorLineage() def extract_on_complete(self, task_instance) -> OperatorLineage | None: - failed_states = [TaskInstanceState.FAILED, TaskInstanceState.UP_FOR_RETRY] - if not AIRFLOW_V_2_10_PLUS: # todo: remove when min airflow version >= 2.10.0 - # Before fix (#41053) implemented in Airflow 2.10 TaskInstance's state was still RUNNING when - # being passed to listener's on_failure method. Since `extract_on_complete()` is only called - # after task completion, RUNNING state means that we are dealing with FAILED task in < 2.10 - failed_states = [TaskInstanceState.RUNNING] - - if task_instance.state in failed_states: - on_failed = getattr(self.operator, "get_openlineage_facets_on_failure", None) - if on_failed and callable(on_failed): - self.log.debug( - "Executing `get_openlineage_facets_on_failure` for %s.", self.operator.task_type - ) - return self._get_openlineage_facets(on_failed, task_instance) - on_complete = getattr(self.operator, "get_openlineage_facets_on_complete", None) - if on_complete and callable(on_complete): - self.log.debug("Executing `get_openlineage_facets_on_complete` for %s.", self.operator.task_type) - return self._get_openlineage_facets(on_complete, task_instance) + method = getattr(self.operator, OL_METHOD_NAME_COMPLETE, None) + if callable(method): + self.log.debug( + "Trying to execute '%s' method of '%s'.", OL_METHOD_NAME_COMPLETE, self.operator.task_type + ) + return self._get_openlineage_facets(method, task_instance) + self.log.debug( + "Operator '%s' does not have '%s' method.", self.operator.task_type, OL_METHOD_NAME_COMPLETE + ) return self.extract() + def extract_on_failure(self, task_instance) -> OperatorLineage | None: + method = getattr(self.operator, OL_METHOD_NAME_FAIL, None) + if callable(method): + self.log.debug( + "Trying to execute '%s' method of '%s'.", OL_METHOD_NAME_FAIL, self.operator.task_type + ) + return self._get_openlineage_facets(method, task_instance) + self.log.debug( + "Operator '%s' does not have '%s' method.", self.operator.task_type, OL_METHOD_NAME_FAIL + ) + return self.extract_on_complete(task_instance) + def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage | None: try: facets: OperatorLineage = get_facets_method(*args) @@ -153,5 +153,5 @@ def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage | "This should not happen." ) except Exception: - self.log.warning("OpenLineage provider method failed to extract data from provider. ") + self.log.warning("OpenLineage provider method failed to extract data from provider.") return None diff --git a/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py b/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py index f07014885eaff..964616382f095 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py +++ b/providers/openlineage/src/airflow/providers/openlineage/extractors/manager.py @@ -24,7 +24,11 @@ ) from airflow.providers.openlineage import conf from airflow.providers.openlineage.extractors import BaseExtractor, OperatorLineage -from airflow.providers.openlineage.extractors.base import DefaultExtractor +from airflow.providers.openlineage.extractors.base import ( + OL_METHOD_NAME_COMPLETE, + OL_METHOD_NAME_START, + DefaultExtractor, +) from airflow.providers.openlineage.extractors.bash import BashExtractor from airflow.providers.openlineage.extractors.python import PythonExtractor from airflow.providers.openlineage.utils.utils import ( @@ -32,6 +36,7 @@ try_import_from_string, ) from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: from openlineage.client.event_v2 import Dataset @@ -87,7 +92,9 @@ def __init__(self): def add_extractor(self, operator_class: str, extractor: type[BaseExtractor]): self.extractors[operator_class] = extractor - def extract_metadata(self, dagrun, task, complete: bool = False, task_instance=None) -> OperatorLineage: + def extract_metadata( + self, dagrun, task, task_instance_state: TaskInstanceState, task_instance=None + ) -> OperatorLineage: extractor = self._get_extractor(task) task_info = ( f"task_type={task.task_type} " @@ -104,10 +111,15 @@ def extract_metadata(self, dagrun, task, complete: bool = False, task_instance=N extractor.__class__.__name__, str(task_info), ) - if complete: - task_metadata = extractor.extract_on_complete(task_instance) - else: + if task_instance_state == TaskInstanceState.RUNNING: task_metadata = extractor.extract() + elif task_instance_state == TaskInstanceState.FAILED: + if callable(getattr(extractor, "extract_on_failure", None)): + task_metadata = extractor.extract_on_failure(task_instance) + else: + task_metadata = extractor.extract_on_complete(task_instance) + else: + task_metadata = extractor.extract_on_complete(task_instance) self.log.debug( "Found task metadata for operation %s: %s", @@ -155,13 +167,9 @@ def get_extractor_class(self, task: Operator) -> type[BaseExtractor] | None: return self.extractors[task.task_type] def method_exists(method_name): - method = getattr(task, method_name, None) - if method: - return callable(method) + return callable(getattr(task, method_name, None)) - if method_exists("get_openlineage_facets_on_start") or method_exists( - "get_openlineage_facets_on_complete" - ): + if method_exists(OL_METHOD_NAME_START) or method_exists(OL_METHOD_NAME_COMPLETE): return self.default_extractor return None @@ -191,7 +199,8 @@ def extract_inlets_and_outlets( if d: task_metadata.outputs.append(d) - def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None: + @staticmethod + def get_hook_lineage() -> tuple[list[Dataset], list[Dataset]] | None: try: from airflow.providers.common.compat.lineage.hook import ( get_hook_lineage_collector, diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py index 8350b2a0d517c..a7c7d5e1f9f8b 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py +++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/adapter.py @@ -85,7 +85,7 @@ def get_or_create_openlineage_client(self) -> OpenLineageClient: if config: self.log.debug( "OpenLineage configuration found. Transport type: `%s`", - config.get("type", "no type provided"), + config.get("transport", {}).get("type", "no type provided"), ) self._client = OpenLineageClient(config=config) # type: ignore[call-arg] else: diff --git a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py index 3af06538ce7fe..b46c65de64af2 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py +++ b/providers/openlineage/src/airflow/providers/openlineage/plugins/listener.py @@ -200,7 +200,9 @@ def on_running(): operator_name = task.task_type.lower() with Stats.timer(f"ol.extract.{event_type}.{operator_name}"): - task_metadata = self.extractor_manager.extract_metadata(dagrun, task) + task_metadata = self.extractor_manager.extract_metadata( + dagrun=dagrun, task=task, task_instance_state=TaskInstanceState.RUNNING + ) redacted_event = self.adapter.start_task( run_id=task_uuid, @@ -303,7 +305,10 @@ def on_success(): with Stats.timer(f"ol.extract.{event_type}.{operator_name}"): task_metadata = self.extractor_manager.extract_metadata( - dagrun, task, complete=True, task_instance=task_instance + dagrun=dagrun, + task=task, + task_instance_state=TaskInstanceState.SUCCESS, + task_instance=task_instance, ) redacted_event = self.adapter.complete_task( @@ -424,7 +429,10 @@ def on_failure(): with Stats.timer(f"ol.extract.{event_type}.{operator_name}"): task_metadata = self.extractor_manager.extract_metadata( - dagrun, task, complete=True, task_instance=task_instance + dagrun=dagrun, + task=task, + task_instance_state=TaskInstanceState.FAILED, + task_instance=task_instance, ) redacted_event = self.adapter.fail_task( diff --git a/providers/openlineage/tests/unit/openlineage/extractors/test_base.py b/providers/openlineage/tests/unit/openlineage/extractors/test_base.py index c85f75b375112..a120537a4144a 100644 --- a/providers/openlineage/tests/unit/openlineage/extractors/test_base.py +++ b/providers/openlineage/tests/unit/openlineage/extractors/test_base.py @@ -56,16 +56,43 @@ class CompleteRunFacet(JobFacet): finished: bool = field(default=False) +@define +class FailRunFacet(JobFacet): + failed: bool = field(default=False) + + FINISHED_FACETS: dict[str, JobFacet] = {"complete": CompleteRunFacet(True)} +FAILED_FACETS: dict[str, JobFacet] = {"failure": FailRunFacet(True)} class ExampleExtractor(BaseExtractor): @classmethod def get_operator_classnames(cls): - return ["ExampleOperator"] + return ["OperatorWithoutFailure"] + + +class OperatorWithoutFailure(BaseOperator): + def execute(self, context) -> Any: + pass + + def get_openlineage_facets_on_start(self) -> OperatorLineage: + return OperatorLineage( + inputs=INPUTS, + outputs=OUTPUTS, + run_facets=RUN_FACETS, + job_facets=JOB_FACETS, + ) + + def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage: + return OperatorLineage( + inputs=INPUTS, + outputs=OUTPUTS, + run_facets=RUN_FACETS, + job_facets=FINISHED_FACETS, + ) -class ExampleOperator(BaseOperator): +class OperatorWithAllOlMethods(BaseOperator): def execute(self, context) -> Any: pass @@ -85,6 +112,14 @@ def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage: job_facets=FINISHED_FACETS, ) + def get_openlineage_facets_on_failure(self, task_instance) -> OperatorLineage: + return OperatorLineage( + inputs=INPUTS, + outputs=OUTPUTS, + run_facets=RUN_FACETS, + job_facets=FAILED_FACETS, + ) + class OperatorWithoutComplete(BaseOperator): def execute(self, context) -> Any: @@ -162,14 +197,14 @@ def execute(self, context) -> Any: def test_default_extraction(): - extractor = ExtractorManager().get_extractor_class(ExampleOperator) + extractor = ExtractorManager().get_extractor_class(OperatorWithoutFailure) assert extractor is DefaultExtractor - metadata = extractor(ExampleOperator(task_id="test")).extract() + metadata = extractor(OperatorWithoutFailure(task_id="test")).extract() task_instance = mock.MagicMock() - metadata_on_complete = extractor(ExampleOperator(task_id="test")).extract_on_complete( + metadata_on_complete = extractor(OperatorWithoutFailure(task_id="test")).extract_on_complete( task_instance=task_instance ) @@ -235,50 +270,59 @@ def test_extraction_without_on_start(): @pytest.mark.parametrize( - "task_state, is_airflow_2_10_or_higher, should_call_on_failure", + "operator_class, task_state, expected_job_facets", ( - # Airflow >= 2.10 - (TaskInstanceState.FAILED, True, True), - (TaskInstanceState.UP_FOR_RETRY, True, True), - (TaskInstanceState.RUNNING, True, False), - (TaskInstanceState.SUCCESS, True, False), - # Airflow < 2.10 - (TaskInstanceState.RUNNING, False, True), - (TaskInstanceState.SUCCESS, False, False), - (TaskInstanceState.FAILED, False, False), # should never happen, fixed in #41053 - (TaskInstanceState.UP_FOR_RETRY, False, False), # should never happen, fixed in #41053 + (OperatorWithAllOlMethods, TaskInstanceState.FAILED, FAILED_FACETS), + (OperatorWithAllOlMethods, TaskInstanceState.RUNNING, JOB_FACETS), + (OperatorWithAllOlMethods, TaskInstanceState.SUCCESS, FINISHED_FACETS), + (OperatorWithAllOlMethods, TaskInstanceState.UP_FOR_RETRY, FINISHED_FACETS), # Should never happen + (OperatorWithAllOlMethods, None, FINISHED_FACETS), # Should never happen + (OperatorWithoutFailure, TaskInstanceState.FAILED, FINISHED_FACETS), + (OperatorWithoutFailure, TaskInstanceState.RUNNING, JOB_FACETS), + (OperatorWithoutFailure, TaskInstanceState.SUCCESS, FINISHED_FACETS), + (OperatorWithoutFailure, TaskInstanceState.UP_FOR_RETRY, FINISHED_FACETS), # Should never happen + (OperatorWithoutFailure, None, FINISHED_FACETS), # Should never happen + (OperatorWithoutStart, TaskInstanceState.FAILED, FINISHED_FACETS), + (OperatorWithoutStart, TaskInstanceState.RUNNING, {}), + (OperatorWithoutStart, TaskInstanceState.SUCCESS, FINISHED_FACETS), + (OperatorWithoutStart, TaskInstanceState.UP_FOR_RETRY, FINISHED_FACETS), # Should never happen + (OperatorWithoutStart, None, FINISHED_FACETS), # Should never happen + (OperatorWithoutComplete, TaskInstanceState.FAILED, JOB_FACETS), + (OperatorWithoutComplete, TaskInstanceState.RUNNING, JOB_FACETS), + (OperatorWithoutComplete, TaskInstanceState.SUCCESS, JOB_FACETS), + (OperatorWithoutComplete, TaskInstanceState.UP_FOR_RETRY, JOB_FACETS), # Should never happen + (OperatorWithoutComplete, None, JOB_FACETS), # Should never happen ), ) -def test_extract_on_failure(task_state, is_airflow_2_10_or_higher, should_call_on_failure): - task_instance = mock.Mock(state=task_state) - operator = mock.Mock() - operator.get_openlineage_facets_on_failure = mock.Mock( - return_value=OperatorLineage(run_facets={"failed": True}) +def test_extractor_manager_calls_appropriate_extractor_method( + operator_class, task_state, expected_job_facets +): + extractor_manager = ExtractorManager() + + ti = mock.MagicMock() + + metadata = extractor_manager.extract_metadata( + dagrun=mock.MagicMock(run_id="dagrun_run_id"), + task=operator_class(task_id="task_id"), + task_instance_state=task_state, + task_instance=ti, ) - operator.get_openlineage_facets_on_complete = mock.Mock(return_value=None) - - extractor = DefaultExtractor(operator=operator) - with mock.patch( - "airflow.providers.openlineage.extractors.base.AIRFLOW_V_2_10_PLUS", is_airflow_2_10_or_higher - ): - result = extractor.extract_on_complete(task_instance) - - if should_call_on_failure: - operator.get_openlineage_facets_on_failure.assert_called_once_with(task_instance) - operator.get_openlineage_facets_on_complete.assert_not_called() - assert isinstance(result, OperatorLineage) - assert result.run_facets == {"failed": True} - else: - operator.get_openlineage_facets_on_failure.assert_not_called() - operator.get_openlineage_facets_on_complete.assert_called_once_with(task_instance) - assert result is None + assert metadata.job_facets == expected_job_facets + if not expected_job_facets: # Empty OperatorLineage() is expected + assert not metadata.inputs + assert not metadata.outputs + assert not metadata.run_facets + else: + assert metadata.inputs == INPUTS + assert metadata.outputs == OUTPUTS + assert metadata.run_facets == RUN_FACETS @mock.patch("airflow.providers.openlineage.conf.custom_extractors") def test_extractors_env_var(custom_extractors): custom_extractors.return_value = {"unit.openlineage.extractors.test_base.ExampleExtractor"} - extractor = ExtractorManager().get_extractor_class(ExampleOperator(task_id="example")) + extractor = ExtractorManager().get_extractor_class(OperatorWithoutFailure(task_id="example")) assert extractor is ExampleExtractor @@ -292,7 +336,7 @@ def test_does_not_use_default_extractor_when_no_get_openlineage_facets(): assert extractor_class is None -def test_does_not_use_default_extractor_when_explicite_extractor(): +def test_does_not_use_default_extractor_when_explicit_extractor(): extractor_class = ExtractorManager().get_extractor_class( PythonOperator(task_id="c", python_callable=lambda: 7) ) @@ -316,6 +360,4 @@ def test_default_extractor_uses_wrong_operatorlineage_class(): operator = OperatorWrongOperatorLineageClass(task_id="task_id") # If extractor returns lineage class that can't be changed into OperatorLineage, just return # empty OperatorLineage - assert ( - ExtractorManager().extract_metadata(mock.MagicMock(), operator, complete=False) == OperatorLineage() - ) + assert ExtractorManager().extract_metadata(mock.MagicMock(), operator, None) == OperatorLineage() diff --git a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py index 256bc133f1a41..04739f22633c7 100644 --- a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py +++ b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py @@ -293,7 +293,9 @@ def test_extractor_manager_uses_hook_level_lineage(hook_lineage_collector): hook_lineage_collector.add_input_asset(None, uri="s3://bucket/input_key") hook_lineage_collector.add_output_asset(None, uri="s3://bucket/output_key") extractor_manager = ExtractorManager() - metadata = extractor_manager.extract_metadata(dagrun=dagrun, task=task, complete=True, task_instance=ti) + metadata = extractor_manager.extract_metadata( + dagrun=dagrun, task=task, task_instance_state=None, task_instance=ti + ) assert metadata.inputs == [OpenLineageDataset(namespace="s3://bucket", name="input_key")] assert metadata.outputs == [OpenLineageDataset(namespace="s3://bucket", name="output_key")] @@ -318,7 +320,9 @@ def get_openlineage_facets_on_start(self): hook_lineage_collector.add_input_asset(None, uri="s3://bucket/input_key") extractor_manager = ExtractorManager() - metadata = extractor_manager.extract_metadata(dagrun=dagrun, task=task, complete=True, task_instance=ti) + metadata = extractor_manager.extract_metadata( + dagrun=dagrun, task=task, task_instance_state=None, task_instance=ti + ) # s3://bucket/input_key not here - use data from operator assert metadata.inputs == [OpenLineageDataset(namespace="s3://bucket", name="proper_input_key")]