Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@
from openlineage.client.facet import BaseFacet as BaseFacet_V1
from openlineage.client.facet_v2 import JobFacet, RunFacet

from airflow.providers.openlineage.utils.utils import AIRFLOW_V_2_10_PLUS
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import TaskInstanceState

# this is not to break static checks compatibility with v1 OpenLineage facet classes
DatasetSubclass = TypeVar("DatasetSubclass", bound=OLDataset)
BaseFacetSubclass = TypeVar("BaseFacetSubclass", bound=Union[BaseFacet_V1, RunFacet, JobFacet])

OL_METHOD_NAME_START = "get_openlineage_facets_on_start"
OL_METHOD_NAME_COMPLETE = "get_openlineage_facets_on_complete"
OL_METHOD_NAME_FAIL = "get_openlineage_facets_on_failure"


@define
class OperatorLineage(Generic[DatasetSubclass, BaseFacetSubclass]):
Expand Down Expand Up @@ -81,6 +83,9 @@ def extract(self) -> OperatorLineage | None:
def extract_on_complete(self, task_instance) -> OperatorLineage | None:
return self.extract()

def extract_on_failure(self, task_instance) -> OperatorLineage | None:
return self.extract()


class DefaultExtractor(BaseExtractor):
"""Extractor that uses `get_openlineage_facets_on_start/complete/failure` methods."""
Expand All @@ -96,46 +101,41 @@ def get_operator_classnames(cls) -> list[str]:
return []

def _execute_extraction(self) -> OperatorLineage | None:
# OpenLineage methods are optional - if there's no method, return None
try:
self.log.debug(
"Trying to execute `get_openlineage_facets_on_start` for %s.", self.operator.task_type
)
return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start) # type: ignore
except ImportError:
self.log.error(
"OpenLineage provider method failed to import OpenLineage integration. "
"This should not happen. Please report this bug to developers."
)
return None
except AttributeError:
method = getattr(self.operator, OL_METHOD_NAME_START, None)
if callable(method):
self.log.debug(
"Operator %s does not have the get_openlineage_facets_on_start method.",
self.operator.task_type,
"Trying to execute '%s' method of '%s'.", OL_METHOD_NAME_START, self.operator.task_type
)
return OperatorLineage()
return self._get_openlineage_facets(method)
self.log.debug(
"Operator '%s' does not have '%s' method.", self.operator.task_type, OL_METHOD_NAME_START
)
return OperatorLineage()

def extract_on_complete(self, task_instance) -> OperatorLineage | None:
failed_states = [TaskInstanceState.FAILED, TaskInstanceState.UP_FOR_RETRY]
if not AIRFLOW_V_2_10_PLUS: # todo: remove when min airflow version >= 2.10.0
# Before fix (#41053) implemented in Airflow 2.10 TaskInstance's state was still RUNNING when
# being passed to listener's on_failure method. Since `extract_on_complete()` is only called
# after task completion, RUNNING state means that we are dealing with FAILED task in < 2.10
failed_states = [TaskInstanceState.RUNNING]

if task_instance.state in failed_states:
on_failed = getattr(self.operator, "get_openlineage_facets_on_failure", None)
if on_failed and callable(on_failed):
self.log.debug(
"Executing `get_openlineage_facets_on_failure` for %s.", self.operator.task_type
)
return self._get_openlineage_facets(on_failed, task_instance)
on_complete = getattr(self.operator, "get_openlineage_facets_on_complete", None)
if on_complete and callable(on_complete):
self.log.debug("Executing `get_openlineage_facets_on_complete` for %s.", self.operator.task_type)
return self._get_openlineage_facets(on_complete, task_instance)
method = getattr(self.operator, OL_METHOD_NAME_COMPLETE, None)
if callable(method):
self.log.debug(
"Trying to execute '%s' method of '%s'.", OL_METHOD_NAME_COMPLETE, self.operator.task_type
)
return self._get_openlineage_facets(method, task_instance)
self.log.debug(
"Operator '%s' does not have '%s' method.", self.operator.task_type, OL_METHOD_NAME_COMPLETE
)
return self.extract()

def extract_on_failure(self, task_instance) -> OperatorLineage | None:
method = getattr(self.operator, OL_METHOD_NAME_FAIL, None)
if callable(method):
self.log.debug(
"Trying to execute '%s' method of '%s'.", OL_METHOD_NAME_FAIL, self.operator.task_type
)
return self._get_openlineage_facets(method, task_instance)
self.log.debug(
"Operator '%s' does not have '%s' method.", self.operator.task_type, OL_METHOD_NAME_FAIL
)
return self.extract_on_complete(task_instance)

def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage | None:
try:
facets: OperatorLineage = get_facets_method(*args)
Expand All @@ -153,5 +153,5 @@ def _get_openlineage_facets(self, get_facets_method, *args) -> OperatorLineage |
"This should not happen."
)
except Exception:
self.log.warning("OpenLineage provider method failed to extract data from provider. ")
self.log.warning("OpenLineage provider method failed to extract data from provider.")
return None
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,19 @@
)
from airflow.providers.openlineage import conf
from airflow.providers.openlineage.extractors import BaseExtractor, OperatorLineage
from airflow.providers.openlineage.extractors.base import DefaultExtractor
from airflow.providers.openlineage.extractors.base import (
OL_METHOD_NAME_COMPLETE,
OL_METHOD_NAME_START,
DefaultExtractor,
)
from airflow.providers.openlineage.extractors.bash import BashExtractor
from airflow.providers.openlineage.extractors.python import PythonExtractor
from airflow.providers.openlineage.utils.utils import (
get_unknown_source_attribute_run_facet,
try_import_from_string,
)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from openlineage.client.event_v2 import Dataset
Expand Down Expand Up @@ -87,7 +92,9 @@ def __init__(self):
def add_extractor(self, operator_class: str, extractor: type[BaseExtractor]):
self.extractors[operator_class] = extractor

def extract_metadata(self, dagrun, task, complete: bool = False, task_instance=None) -> OperatorLineage:
def extract_metadata(
self, dagrun, task, task_instance_state: TaskInstanceState, task_instance=None
) -> OperatorLineage:
extractor = self._get_extractor(task)
task_info = (
f"task_type={task.task_type} "
Expand All @@ -104,10 +111,15 @@ def extract_metadata(self, dagrun, task, complete: bool = False, task_instance=N
extractor.__class__.__name__,
str(task_info),
)
if complete:
task_metadata = extractor.extract_on_complete(task_instance)
else:
if task_instance_state == TaskInstanceState.RUNNING:
task_metadata = extractor.extract()
elif task_instance_state == TaskInstanceState.FAILED:
if callable(getattr(extractor, "extract_on_failure", None)):
task_metadata = extractor.extract_on_failure(task_instance)
else:
task_metadata = extractor.extract_on_complete(task_instance)
else:
task_metadata = extractor.extract_on_complete(task_instance)

self.log.debug(
"Found task metadata for operation %s: %s",
Expand Down Expand Up @@ -155,13 +167,9 @@ def get_extractor_class(self, task: Operator) -> type[BaseExtractor] | None:
return self.extractors[task.task_type]

def method_exists(method_name):
method = getattr(task, method_name, None)
if method:
return callable(method)
return callable(getattr(task, method_name, None))

if method_exists("get_openlineage_facets_on_start") or method_exists(
"get_openlineage_facets_on_complete"
):
if method_exists(OL_METHOD_NAME_START) or method_exists(OL_METHOD_NAME_COMPLETE):
return self.default_extractor
return None

Expand Down Expand Up @@ -191,7 +199,8 @@ def extract_inlets_and_outlets(
if d:
task_metadata.outputs.append(d)

def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None:
@staticmethod
def get_hook_lineage() -> tuple[list[Dataset], list[Dataset]] | None:
try:
from airflow.providers.common.compat.lineage.hook import (
get_hook_lineage_collector,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_or_create_openlineage_client(self) -> OpenLineageClient:
if config:
self.log.debug(
"OpenLineage configuration found. Transport type: `%s`",
config.get("type", "no type provided"),
config.get("transport", {}).get("type", "no type provided"),
)
self._client = OpenLineageClient(config=config) # type: ignore[call-arg]
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,9 @@ def on_running():
operator_name = task.task_type.lower()

with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
task_metadata = self.extractor_manager.extract_metadata(dagrun, task)
task_metadata = self.extractor_manager.extract_metadata(
dagrun=dagrun, task=task, task_instance_state=TaskInstanceState.RUNNING
)

redacted_event = self.adapter.start_task(
run_id=task_uuid,
Expand Down Expand Up @@ -303,7 +305,10 @@ def on_success():

with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
task_metadata = self.extractor_manager.extract_metadata(
dagrun, task, complete=True, task_instance=task_instance
dagrun=dagrun,
task=task,
task_instance_state=TaskInstanceState.SUCCESS,
task_instance=task_instance,
)

redacted_event = self.adapter.complete_task(
Expand Down Expand Up @@ -424,7 +429,10 @@ def on_failure():

with Stats.timer(f"ol.extract.{event_type}.{operator_name}"):
task_metadata = self.extractor_manager.extract_metadata(
dagrun, task, complete=True, task_instance=task_instance
dagrun=dagrun,
task=task,
task_instance_state=TaskInstanceState.FAILED,
task_instance=task_instance,
)

redacted_event = self.adapter.fail_task(
Expand Down
Loading