38 changes: 35 additions & 3 deletions airflow/providers/openlineage/extractors/manager.py
@@ -25,6 +25,7 @@
from airflow.providers.openlineage.extractors.python import PythonExtractor
from airflow.providers.openlineage.utils.utils import (
    get_unknown_source_attribute_run_facet,
    translate_airflow_dataset,
    try_import_from_string,
)
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -90,7 +91,6 @@ def extract_metadata(self, dagrun, task, complete: bool = False, task_instance=N
f"task_id={task.task_id} "
f"airflow_run_id={dagrun.run_id} "
)

if extractor:
# Extracting advanced metadata is only possible when extractor for particular operator
# is defined. Without it, we can't extract any input or output data.
@@ -105,14 +105,22 @@ def extract_metadata(self, dagrun, task, complete: bool = False, task_instance=N
                task_metadata = self.validate_task_metadata(task_metadata)
                if task_metadata:
                    if (not task_metadata.inputs) and (not task_metadata.outputs):
                        self.extract_inlets_and_outlets(task_metadata, task.inlets, task.outlets)

                        if (hook_lineage := self.get_hook_lineage()) is not None:
                            inputs, outputs = hook_lineage
                            task_metadata.inputs = inputs
                            task_metadata.outputs = outputs
                        else:
                            self.extract_inlets_and_outlets(task_metadata, task.inlets, task.outlets)
                    return task_metadata

            except Exception as e:
                self.log.warning(
                    "Failed to extract metadata using found extractor %s - %s %s", extractor, e, task_info
                )
        elif (hook_lineage := self.get_hook_lineage()) is not None:
            inputs, outputs = hook_lineage
            task_metadata = OperatorLineage(inputs=inputs, outputs=outputs)
            return task_metadata
        else:
            self.log.debug("Unable to find an extractor %s", task_info)

Expand Down Expand Up @@ -168,6 +176,30 @@ def extract_inlets_and_outlets(
            if d:
                task_metadata.outputs.append(d)

    def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None:
        try:
            from airflow.lineage.hook import get_hook_lineage_collector
        except ImportError:
            return None

        if not get_hook_lineage_collector().has_collected:
            return None

        return (
            [
                dataset
                for dataset_info in get_hook_lineage_collector().collected_datasets.inputs
                if (dataset := translate_airflow_dataset(dataset_info.dataset, dataset_info.context))
                is not None
            ],
            [
                dataset
                for dataset_info in get_hook_lineage_collector().collected_datasets.outputs
                if (dataset := translate_airflow_dataset(dataset_info.dataset, dataset_info.context))
                is not None
            ],
        )

    @staticmethod
    def convert_to_ol_dataset_from_object_storage_uri(uri: str) -> Dataset | None:
        from urllib.parse import urlparse
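For context, the new `get_hook_lineage()` path only yields data when a hook has reported datasets to Airflow's collector during task execution. Below is a minimal sketch of that producing side, assuming Airflow 2.10+; `MyCopyHook` and both URIs are hypothetical, while `get_hook_lineage_collector()` and the `add_*_dataset` calls mirror the usage in the tests further down:

```python
from airflow.hooks.base import BaseHook
from airflow.lineage.hook import get_hook_lineage_collector


class MyCopyHook(BaseHook):  # hypothetical hook, for illustration only
    def copy_file(self, source_uri: str, dest_uri: str) -> None:
        # ... perform the actual transfer ...

        # Report what was read and written; the first argument is the
        # lineage context (here, the hook itself).
        collector = get_hook_lineage_collector()
        collector.add_input_dataset(self, uri=source_uri)
        collector.add_output_dataset(self, uri=dest_uri)
```

Once `has_collected` is true, `extract_metadata()` prefers these collected datasets over the task's static `inlets`/`outlets`, while metadata produced by an operator's own extractor still takes precedence.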
5 changes: 5 additions & 0 deletions airflow/providers/openlineage/plugins/openlineage.py
@@ -25,6 +25,7 @@
    lineage_parent_id,
    lineage_run_id,
)
from airflow.providers.openlineage.utils.utils import IS_AIRFLOW_2_10_OR_HIGHER


class OpenLineageProviderPlugin(AirflowPlugin):
@@ -39,6 +40,10 @@ class OpenLineageProviderPlugin(AirflowPlugin):
    if not conf.is_disabled():
        macros = [lineage_job_namespace, lineage_job_name, lineage_run_id, lineage_parent_id]
        listeners = [get_openlineage_listener()]
        if IS_AIRFLOW_2_10_OR_HIGHER:
            from airflow.lineage.hook import HookLineageReader

            hook_lineage_readers = [HookLineageReader]
    else:
        macros = []
        listeners = []
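Registering `HookLineageReader` here is what turns collection on: in Airflow 2.10+ the hook lineage collector effectively stays a no-op unless at least one plugin advertises a reader. A sketch of the same opt-in for a hypothetical third-party plugin, using a guarded import as an alternative to the `IS_AIRFLOW_2_10_OR_HIGHER` flag used above:

```python
from airflow.plugins_manager import AirflowPlugin

try:
    from airflow.lineage.hook import HookLineageReader  # Airflow >= 2.10

    _hook_lineage_readers = [HookLineageReader]
except ImportError:  # Airflow < 2.10: hook-level lineage unavailable
    _hook_lineage_readers = []


class MyProviderPlugin(AirflowPlugin):  # hypothetical plugin, for illustration
    name = "my-provider-plugin"
    hook_lineage_readers = _hook_lineage_readers
```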
1 change: 1 addition & 0 deletions airflow/providers/openlineage/provider.yaml
@@ -48,6 +48,7 @@ versions:
dependencies:
  - apache-airflow>=2.8.0
  - apache-airflow-providers-common-sql>=1.6.0
  - apache-airflow-providers-common-compat>=1.2.0
  - attrs>=22.2
  - openlineage-integration-common>=1.16.0
  - openlineage-python>=1.16.0
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
@@ -939,6 +939,7 @@
  },
  "openlineage": {
    "deps": [
      "apache-airflow-providers-common-compat>=1.2.0",
      "apache-airflow-providers-common-sql>=1.6.0",
      "apache-airflow>=2.8.0",
      "attrs>=22.2",
148 changes: 127 additions & 21 deletions tests/providers/openlineage/extractors/test_manager.py
@@ -17,26 +17,47 @@
# under the License.
from __future__ import annotations

import tempfile
from typing import TYPE_CHECKING, Any
from unittest.mock import MagicMock

import pytest
from openlineage.client.event_v2 import Dataset
from openlineage.client.event_v2 import Dataset as OpenLineageDataset
from openlineage.client.facet_v2 import documentation_dataset, ownership_dataset, schema_dataset

from airflow.datasets import Dataset
from airflow.io.path import ObjectStoragePath
from airflow.lineage.entities import Column, File, Table, User
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskinstance import TaskInstance
from airflow.operators.python import PythonOperator
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.extractors.manager import ExtractorManager
from airflow.utils.state import State
from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS

if TYPE_CHECKING:
    from airflow.utils.context import Context


@pytest.mark.parametrize(
    ("uri", "dataset"),
    (
        ("s3://bucket1/dir1/file1", Dataset(namespace="s3://bucket1", name="dir1/file1")),
        ("gs://bucket2/dir2/file2", Dataset(namespace="gs://bucket2", name="dir2/file2")),
        ("gcs://bucket3/dir3/file3", Dataset(namespace="gs://bucket3", name="dir3/file3")),
        ("hdfs://namenodehost:8020/file1", Dataset(namespace="hdfs://namenodehost:8020", name="file1")),
        ("hdfs://namenodehost/file2", Dataset(namespace="hdfs://namenodehost", name="file2")),
        ("file://localhost/etc/fstab", Dataset(namespace="file://localhost", name="etc/fstab")),
        ("file:///etc/fstab", Dataset(namespace="file://", name="etc/fstab")),
        ("https://test.com", Dataset(namespace="https", name="test.com")),
        ("https://test.com?param1=test1&param2=test2", Dataset(namespace="https", name="test.com")),
        ("s3://bucket1/dir1/file1", OpenLineageDataset(namespace="s3://bucket1", name="dir1/file1")),
        ("gs://bucket2/dir2/file2", OpenLineageDataset(namespace="gs://bucket2", name="dir2/file2")),
        ("gcs://bucket3/dir3/file3", OpenLineageDataset(namespace="gs://bucket3", name="dir3/file3")),
        (
            "hdfs://namenodehost:8020/file1",
            OpenLineageDataset(namespace="hdfs://namenodehost:8020", name="file1"),
        ),
        ("hdfs://namenodehost/file2", OpenLineageDataset(namespace="hdfs://namenodehost", name="file2")),
        ("file://localhost/etc/fstab", OpenLineageDataset(namespace="file://localhost", name="etc/fstab")),
        ("file:///etc/fstab", OpenLineageDataset(namespace="file://", name="etc/fstab")),
        ("https://test.com", OpenLineageDataset(namespace="https", name="test.com")),
        (
            "https://test.com?param1=test1&param2=test2",
            OpenLineageDataset(namespace="https", name="test.com"),
        ),
        ("file:test.csv", None),
        ("not_an_url", None),
    ),
@@ -50,21 +71,36 @@ def test_convert_to_ol_dataset_from_object_storage_uri(uri, dataset):
("obj", "dataset"),
(
(
Dataset(namespace="n1", name="f1"),
Dataset(namespace="n1", name="f1"),
OpenLineageDataset(namespace="n1", name="f1"),
OpenLineageDataset(namespace="n1", name="f1"),
),
(
File(url="s3://bucket1/dir1/file1"),
OpenLineageDataset(namespace="s3://bucket1", name="dir1/file1"),
),
(
File(url="gs://bucket2/dir2/file2"),
OpenLineageDataset(namespace="gs://bucket2", name="dir2/file2"),
),
(
File(url="gcs://bucket3/dir3/file3"),
OpenLineageDataset(namespace="gs://bucket3", name="dir3/file3"),
),
(File(url="s3://bucket1/dir1/file1"), Dataset(namespace="s3://bucket1", name="dir1/file1")),
(File(url="gs://bucket2/dir2/file2"), Dataset(namespace="gs://bucket2", name="dir2/file2")),
(File(url="gcs://bucket3/dir3/file3"), Dataset(namespace="gs://bucket3", name="dir3/file3")),
(
File(url="hdfs://namenodehost:8020/file1"),
Dataset(namespace="hdfs://namenodehost:8020", name="file1"),
OpenLineageDataset(namespace="hdfs://namenodehost:8020", name="file1"),
),
(
File(url="hdfs://namenodehost/file2"),
OpenLineageDataset(namespace="hdfs://namenodehost", name="file2"),
),
(
File(url="file://localhost/etc/fstab"),
OpenLineageDataset(namespace="file://localhost", name="etc/fstab"),
),
(File(url="hdfs://namenodehost/file2"), Dataset(namespace="hdfs://namenodehost", name="file2")),
(File(url="file://localhost/etc/fstab"), Dataset(namespace="file://localhost", name="etc/fstab")),
(File(url="file:///etc/fstab"), Dataset(namespace="file://", name="etc/fstab")),
(File(url="https://test.com"), Dataset(namespace="https", name="test.com")),
(Table(cluster="c1", database="d1", name="t1"), Dataset(namespace="c1", name="d1.t1")),
(File(url="file:///etc/fstab"), OpenLineageDataset(namespace="file://", name="etc/fstab")),
(File(url="https://test.com"), OpenLineageDataset(namespace="https", name="test.com")),
(Table(cluster="c1", database="d1", name="t1"), OpenLineageDataset(namespace="c1", name="d1.t1")),
("gs://bucket2/dir2/file2", None),
("not_an_url", None),
),
@@ -167,3 +203,73 @@ def test_convert_to_ol_dataset_table():
    assert result.namespace == "c1"
    assert result.name == "d1.t1"
    assert result.facets == expected_facets


@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
def test_extractor_manager_uses_hook_level_lineage(hook_lineage_collector):
    dagrun = MagicMock()
    task = MagicMock()
    del task.get_openlineage_facets_on_start
    del task.get_openlineage_facets_on_complete
    ti = MagicMock()

    hook_lineage_collector.add_input_dataset(None, uri="s3://bucket/input_key")
    hook_lineage_collector.add_output_dataset(None, uri="s3://bucket/output_key")
    extractor_manager = ExtractorManager()
    metadata = extractor_manager.extract_metadata(dagrun=dagrun, task=task, complete=True, task_instance=ti)

    assert metadata.inputs == [OpenLineageDataset(namespace="s3://bucket", name="input_key")]
    assert metadata.outputs == [OpenLineageDataset(namespace="s3://bucket", name="output_key")]


@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
def test_extractor_manager_does_not_use_hook_level_lineage_when_operator(hook_lineage_collector):
    class FakeSupportedOperator(BaseOperator):
        def execute(self, context: Context) -> Any:
            pass

        def get_openlineage_facets_on_start(self):
            return OperatorLineage(
                inputs=[OpenLineageDataset(namespace="s3://bucket", name="proper_input_key")]
            )

    dagrun = MagicMock()
    task = FakeSupportedOperator(task_id="test_task_extractor")
    ti = MagicMock()
    hook_lineage_collector.add_input_dataset(None, uri="s3://bucket/input_key")

    extractor_manager = ExtractorManager()
    metadata = extractor_manager.extract_metadata(dagrun=dagrun, task=task, complete=True, task_instance=ti)

    # s3://bucket/input_key is not present here; the data from the operator takes precedence
    assert metadata.inputs == [OpenLineageDataset(namespace="s3://bucket", name="proper_input_key")]
    assert metadata.outputs == []


@pytest.mark.db_test
@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
def test_extractor_manager_gets_data_from_pythonoperator(session, dag_maker, hook_lineage_collector):
    path = None
    with tempfile.NamedTemporaryFile() as f:
        path = f.name
    with dag_maker():

        def use_read():
            storage_path = ObjectStoragePath(path)
            with storage_path.open("w") as out:
                out.write("test")

        task = PythonOperator(task_id="test_task_extractor_pythonoperator", python_callable=use_read)

    dr = dag_maker.create_dagrun()
    ti = TaskInstance(task=task, run_id=dr.run_id)
    ti.state = State.QUEUED
    session.merge(ti)
    session.commit()

    ti.run()

    datasets = hook_lineage_collector.collected_datasets

    assert len(datasets.outputs) == 1
    assert datasets.outputs[0].dataset == Dataset(uri=path)
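
For reference, the translation these assertions depend on amounts to splitting an object-storage URI into a scheme://authority namespace and a path name. A minimal sketch, where `translate_s3_like_uri` is a simplified, hypothetical stand-in for the provider's `translate_airflow_dataset` helper (which, as seen in manager.py above, additionally receives the hook context):

```python
from urllib.parse import urlparse

from openlineage.client.event_v2 import Dataset as OpenLineageDataset


def translate_s3_like_uri(uri: str) -> OpenLineageDataset:
    # "s3://bucket/input_key" -> namespace "s3://bucket", name "input_key"
    parsed = urlparse(uri)
    return OpenLineageDataset(
        namespace=f"{parsed.scheme}://{parsed.netloc}",
        name=parsed.path.lstrip("/"),
    )


assert translate_s3_like_uri("s3://bucket/input_key") == OpenLineageDataset(
    namespace="s3://bucket", name="input_key"
)
```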