feat(providers/openlineage): Use asset in common provider (#43111)
Lee-W authored Oct 18, 2024
1 parent b540eb0 · commit 1f0bba2
Showing 4 changed files with 104 additions and 156 deletions.
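In short: call sites that previously probed for Airflow 3's asset module with find_spec and fell back to a provider-local shim now use a single import from the common.compat provider. A minimal before/after sketch of the pattern; the wrapper function and its name are hypothetical, for illustration only:

    # Before: each call site carried its own version probe and fallback.
    from importlib.util import find_spec

    def legacy_get_collector():  # hypothetical wrapper, for illustration
        if find_spec("airflow.assets"):
            from airflow.lineage.hook import get_hook_lineage_collector
        else:
            from airflow.providers.openlineage.utils.asset_compat_lineage_collector import (
                get_hook_lineage_collector,
            )
        return get_hook_lineage_collector()

    # After: the version probe lives in common.compat, so one import suffices.
    from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector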
37 changes: 22 additions & 15 deletions providers/src/airflow/providers/openlineage/extractors/manager.py
@@ -18,7 +18,9 @@
 
 from typing import TYPE_CHECKING, Iterator
 
-from airflow.providers.common.compat.openlineage.utils.utils import translate_airflow_asset
+from airflow.providers.common.compat.openlineage.utils.utils import (
+    translate_airflow_asset,
+)
 from airflow.providers.openlineage import conf
 from airflow.providers.openlineage.extractors import BaseExtractor, OperatorLineage
 from airflow.providers.openlineage.extractors.base import DefaultExtractor
@@ -61,7 +63,8 @@ def __init__(self):
             extractor: type[BaseExtractor] | None = try_import_from_string(extractor_path)
             if not extractor:
                 self.log.warning(
-                    "OpenLineage is unable to import custom extractor `%s`; will ignore it.", extractor_path
+                    "OpenLineage is unable to import custom extractor `%s`; will ignore it.",
+                    extractor_path,
                 )
                 continue
             for operator_class in extractor.get_operator_classnames():
@@ -95,13 +98,21 @@ def extract_metadata(self, dagrun, task, complete: bool = False, task_instance=N
             # Extracting advanced metadata is only possible when extractor for particular operator
             # is defined. Without it, we can't extract any input or output data.
             try:
-                self.log.debug("Using extractor %s %s", extractor.__class__.__name__, str(task_info))
+                self.log.debug(
+                    "Using extractor %s %s",
+                    extractor.__class__.__name__,
+                    str(task_info),
+                )
                 if complete:
                     task_metadata = extractor.extract_on_complete(task_instance)
                 else:
                     task_metadata = extractor.extract()
 
-                self.log.debug("Found task metadata for operation %s: %s", task.task_id, str(task_metadata))
+                self.log.debug(
+                    "Found task metadata for operation %s: %s",
+                    task.task_id,
+                    str(task_metadata),
+                )
                 task_metadata = self.validate_task_metadata(task_metadata)
                 if task_metadata:
                     if (not task_metadata.inputs) and (not task_metadata.outputs):
@@ -115,7 +126,10 @@ def extract_metadata(self, dagrun, task, complete: bool = False, task_instance=N
 
             except Exception as e:
                 self.log.warning(
-                    "Failed to extract metadata using found extractor %s - %s %s", extractor, e, task_info
+                    "Failed to extract metadata using found extractor %s - %s %s",
+                    extractor,
+                    e,
+                    task_info,
                 )
         elif (hook_lineage := self.get_hook_lineage()) is not None:
             inputs, outputs = hook_lineage
@@ -178,16 +192,9 @@ def extract_inlets_and_outlets(
 
     def get_hook_lineage(self) -> tuple[list[Dataset], list[Dataset]] | None:
        try:
-            from importlib.util import find_spec
-
-            if find_spec("airflow.assets"):
-                from airflow.lineage.hook import get_hook_lineage_collector
-            else:
-                # TODO: import from common.compat directly after common.compat providers with
-                # asset_compat_lineage_collector released
-                from airflow.providers.openlineage.utils.asset_compat_lineage_collector import (
-                    get_hook_lineage_collector,
-                )
+            from airflow.providers.common.compat.lineage.hook import (
+                get_hook_lineage_collector,
+            )
         except ImportError:
             return None
 
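With the compat import above, get_hook_lineage no longer cares whether the running Airflow ships airflow.assets (3.0) or airflow.datasets (2.x). A rough sketch, not part of this diff, of how a hook registers lineage that ExtractorManager.get_hook_lineage() later returns; the method names and signatures follow the Airflow 2.10 hook-lineage API and should be read as assumptions:

    from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector

    collector = get_hook_lineage_collector()
    # Assumed 2.10-style calls; real hooks pass themselves as `context`.
    collector.add_input_dataset(context=None, uri="s3://bucket/input.csv")
    collector.add_output_dataset(context=None, uri="s3://bucket/output.csv")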

108 changes: 0 additions & 108 deletions providers/src/airflow/providers/openlineage/utils/asset_compat_lineage_collector.py

This file was deleted.

47 changes: 35 additions & 12 deletions providers/src/airflow/providers/openlineage/utils/utils.py
@@ -31,8 +31,13 @@
 from packaging.version import Version
 
 from airflow import __version__ as AIRFLOW_VERSION
-from airflow.exceptions import AirflowProviderDeprecationWarning  # TODO: move this maybe to Airflow's logic?
+from airflow.exceptions import (
+    AirflowProviderDeprecationWarning,
+)
+
+# TODO: move this maybe to Airflow's logic?
 from airflow.models import DAG, BaseOperator, DagRun, MappedOperator
+from airflow.providers.common.compat.assets import Asset
 from airflow.providers.openlineage import conf
 from airflow.providers.openlineage.plugins.facets import (
     AirflowDagRunFacet,
@@ -50,14 +55,14 @@
 )
 from airflow.serialization.serialized_objects import SerializedBaseOperator
 from airflow.utils.context import AirflowContextDeprecationWarning
-from airflow.utils.log.secrets_masker import Redactable, Redacted, SecretsMasker, should_hide_value_for_key
+from airflow.utils.log.secrets_masker import (
+    Redactable,
+    Redacted,
+    SecretsMasker,
+    should_hide_value_for_key,
+)
 from airflow.utils.module_loading import import_string
 
-try:
-    from airflow.assets import Asset
-except ModuleNotFoundError:
-    from airflow.datasets import Dataset as Asset  # type: ignore[no-redef]
-
 if TYPE_CHECKING:
     from openlineage.client.event_v2 import Dataset as OpenLineageDataset
     from openlineage.client.facet_v2 import RunFacet
@@ -501,7 +506,11 @@ def _emits_ol_events(task: BaseOperator | MappedOperator) -> bool:
     )
 
     emits_ol_events = all(
-        (config_selective_enabled, not config_disabled_for_operators, not is_skipped_as_empty_operator)
+        (
+            config_selective_enabled,
+            not config_disabled_for_operators,
+            not is_skipped_as_empty_operator,
+        )
     )
     return emits_ol_events
 
@@ -567,7 +576,12 @@ def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int
                             setattr(
                                 item,
                                 dict_key,
-                                self._redact(subval, name=dict_key, depth=(depth + 1), max_depth=max_depth),
+                                self._redact(
+                                    subval,
+                                    name=dict_key,
+                                    depth=(depth + 1),
+                                    max_depth=max_depth,
+                                ),
                             )
                     return item
                 elif is_json_serializable(item) and hasattr(item, "__dict__"):
@@ -578,7 +592,12 @@ def _redact(self, item: Redactable, name: str | None, depth: int, max_depth: int
                             setattr(
                                 item,
                                 dict_key,
-                                self._redact(subval, name=dict_key, depth=(depth + 1), max_depth=max_depth),
+                                self._redact(
+                                    subval,
+                                    name=dict_key,
+                                    depth=(depth + 1),
+                                    max_depth=max_depth,
+                                ),
                             )
                     return item
                 else:
@@ -641,13 +660,17 @@ def normalize_sql(sql: str | Iterable[str]):
 def should_use_external_connection(hook) -> bool:
     # If we're at Airflow 2.10, the execution is process-isolated, so we can safely run those again.
     if not IS_AIRFLOW_2_10_OR_HIGHER:
-        return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook", "RedshiftSQLHook"]
+        return hook.__class__.__name__ not in [
+            "SnowflakeHook",
+            "SnowflakeSqlApiHook",
+            "RedshiftSQLHook",
+        ]
     return True
 
 
 def translate_airflow_asset(asset: Asset, lineage_context) -> OpenLineageDataset | None:
     """
-    Convert a Asset with an AIP-60 compliant URI to an OpenLineageDataset.
+    Convert an Asset with an AIP-60 compliant URI to an OpenLineageDataset.
 
     This function returns None if no URI normalizer is defined, no asset converter is found or
     some core Airflow changes are missing and ImportError is raised.
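A minimal usage sketch for the function above, not part of this diff; the None lineage_context is a placeholder (real call sites pass a hook or task), and the expected mapping follows the tests later in this commit:

    from airflow.providers.common.compat.assets import Asset
    from airflow.providers.openlineage.utils.utils import translate_airflow_asset

    asset = Asset("s3://bucket1/dir1/file1")
    # Expected: OpenLineageDataset(namespace="s3://bucket1", name="dir1/file1");
    # None is returned when no URI normalizer or asset converter is available.
    ol_dataset = translate_airflow_asset(asset, lineage_context=None)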
68 changes: 47 additions & 21 deletions providers/tests/openlineage/extractors/test_manager.py
@@ -23,7 +23,11 @@
 
 import pytest
 from openlineage.client.event_v2 import Dataset as OpenLineageDataset
-from openlineage.client.facet_v2 import documentation_dataset, ownership_dataset, schema_dataset
+from openlineage.client.facet_v2 import (
+    documentation_dataset,
+    ownership_dataset,
+    schema_dataset,
+)
 
 from airflow.io.path import ObjectStoragePath
 from airflow.lineage.entities import Column, File, Table, User
@@ -44,17 +48,10 @@
 
 @pytest.fixture
 def hook_lineage_collector():
-    from importlib.util import find_spec
-
     from airflow.lineage import hook
-
-    if find_spec("airflow.assets"):
-        # Dataset has been renamed as Asset in 3.0
-        from airflow.lineage.hook import get_hook_lineage_collector
-    else:
-        from airflow.providers.openlineage.utils.asset_compat_lineage_collector import (
-            get_hook_lineage_collector,
-        )
+    from airflow.providers.common.compat.lineage.hook import (
+        get_hook_lineage_collector,
+    )
 
     hook._hook_lineage_collector = None
     hook._hook_lineage_collector = hook.HookLineageCollector()
@@ -67,16 +64,34 @@ def hook_lineage_collector():
 @pytest.mark.parametrize(
     ("uri", "dataset"),
     (
-        ("s3://bucket1/dir1/file1", OpenLineageDataset(namespace="s3://bucket1", name="dir1/file1")),
-        ("gs://bucket2/dir2/file2", OpenLineageDataset(namespace="gs://bucket2", name="dir2/file2")),
-        ("gcs://bucket3/dir3/file3", OpenLineageDataset(namespace="gs://bucket3", name="dir3/file3")),
+        (
+            "s3://bucket1/dir1/file1",
+            OpenLineageDataset(namespace="s3://bucket1", name="dir1/file1"),
+        ),
+        (
+            "gs://bucket2/dir2/file2",
+            OpenLineageDataset(namespace="gs://bucket2", name="dir2/file2"),
+        ),
+        (
+            "gcs://bucket3/dir3/file3",
+            OpenLineageDataset(namespace="gs://bucket3", name="dir3/file3"),
+        ),
         (
             "hdfs://namenodehost:8020/file1",
             OpenLineageDataset(namespace="hdfs://namenodehost:8020", name="file1"),
         ),
-        ("hdfs://namenodehost/file2", OpenLineageDataset(namespace="hdfs://namenodehost", name="file2")),
-        ("file://localhost/etc/fstab", OpenLineageDataset(namespace="file://localhost", name="etc/fstab")),
-        ("file:///etc/fstab", OpenLineageDataset(namespace="file://", name="etc/fstab")),
+        (
+            "hdfs://namenodehost/file2",
+            OpenLineageDataset(namespace="hdfs://namenodehost", name="file2"),
+        ),
+        (
+            "file://localhost/etc/fstab",
+            OpenLineageDataset(namespace="file://localhost", name="etc/fstab"),
+        ),
+        (
+            "file:///etc/fstab",
+            OpenLineageDataset(namespace="file://", name="etc/fstab"),
+        ),
         ("https://test.com", OpenLineageDataset(namespace="https", name="test.com")),
         (
             "https://test.com?param1=test1&param2=test2",
@@ -122,9 +137,18 @@ def test_convert_to_ol_dataset_from_object_storage_uri(uri, dataset):
             File(url="file://localhost/etc/fstab"),
             OpenLineageDataset(namespace="file://localhost", name="etc/fstab"),
         ),
-        (File(url="file:///etc/fstab"), OpenLineageDataset(namespace="file://", name="etc/fstab")),
-        (File(url="https://test.com"), OpenLineageDataset(namespace="https", name="test.com")),
-        (Table(cluster="c1", database="d1", name="t1"), OpenLineageDataset(namespace="c1", name="d1.t1")),
+        (
+            File(url="file:///etc/fstab"),
+            OpenLineageDataset(namespace="file://", name="etc/fstab"),
+        ),
+        (
+            File(url="https://test.com"),
+            OpenLineageDataset(namespace="https", name="test.com"),
+        ),
+        (
+            Table(cluster="c1", database="d1", name="t1"),
+            OpenLineageDataset(namespace="c1", name="d1.t1"),
+        ),
         ("gs://bucket2/dir2/file2", None),
         ("not_an_url", None),
     ),
@@ -247,7 +271,9 @@ def test_extractor_manager_uses_hook_level_lineage(hook_lineage_collector):
 
 
 @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
-def test_extractor_manager_does_not_use_hook_level_lineage_when_operator(hook_lineage_collector):
+def test_extractor_manager_does_not_use_hook_level_lineage_when_operator(
+    hook_lineage_collector,
+):
     class FakeSupportedOperator(BaseOperator):
         def execute(self, context: Context) -> Any:
             pass
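The AIRFLOW_V_2_10_PLUS flag gating these tests is, in essence, a packaging-version comparison. An assumed reconstruction, mirroring the Version and AIRFLOW_VERSION imports already present in utils.py above; the real constant lives in the shared test utilities:

    from packaging.version import Version

    from airflow import __version__ as AIRFLOW_VERSION

    # Assumed reconstruction of the gate used by @pytest.mark.skipif above.
    AIRFLOW_V_2_10_PLUS = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")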
