[AIP-62] Translate AIP-60 URI to OpenLineage (#40173)
* aip-62: implement translation mechanism from aip-60 to OpenLineage

Signed-off-by: Kacper Muda <mudakacper@gmail.com>

* aip-62: implement translation examples from aip-60 to OpenLineage

Signed-off-by: Kacper Muda <mudakacper@gmail.com>

---------

Signed-off-by: Kacper Muda <mudakacper@gmail.com>
kacpermuda authored Jul 23, 2024
1 parent 8829860 commit 8a912f9
Showing 11 changed files with 312 additions and 35 deletions.
30 changes: 29 additions & 1 deletion airflow/datasets/__init__.py
@@ -56,6 +56,11 @@ def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None:
    return ProvidersManager().dataset_uri_handlers.get(scheme)


def _get_normalized_scheme(uri: str) -> str:
    parsed = urllib.parse.urlsplit(uri)
    return parsed.scheme.lower()


def _sanitize_uri(uri: str) -> str:
    """
    Sanitize a dataset URI.
@@ -72,7 +77,8 @@ def _sanitize_uri(uri: str) -> str:
    parsed = urllib.parse.urlsplit(uri)
    if not parsed.scheme and not parsed.netloc:  # Does not look like a URI.
        return uri
-   normalized_scheme = parsed.scheme.lower()
    if not (normalized_scheme := _get_normalized_scheme(uri)):
        return uri
    if normalized_scheme.startswith("x-"):
        return uri
    if normalized_scheme == "airflow":
@@ -231,6 +237,28 @@ def __eq__(self, other: Any) -> bool:
    def __hash__(self) -> int:
        return hash(self.uri)

    @property
    def normalized_uri(self) -> str | None:
        """
        Return the normalized and AIP-60 compliant URI whenever possible.

        Returns None if the scheme cannot be retrieved from the URI, if no normalizer
        is registered for the scheme, or if parsing fails. Otherwise, the normalizer's
        result is returned.
        """
        if not (normalized_scheme := _get_normalized_scheme(self.uri)):
            return None

        if (normalizer := _get_uri_normalizer(normalized_scheme)) is None:
            return None
        parsed = urllib.parse.urlsplit(self.uri)
        try:
            normalized_uri = normalizer(parsed)
            return urllib.parse.urlunsplit(normalized_uri)
        except ValueError:
            return None

    def as_expression(self) -> Any:
        """
        Serialize the dataset into its scheduling expression.
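To make the new behaviour concrete, here is a minimal usage sketch of the property (illustrative only; which schemes resolve depends on the providers installed, and the s3 case assumes the Amazon provider's handler is registered):

from airflow.datasets import Dataset

# A scheme with a registered normalizer yields the AIP-60 normalized URI.
Dataset(uri="s3://some-bucket/some-key").normalized_uri  # "s3://some-bucket/some-key"

# No extractable scheme, no registered normalizer, or a parse failure yields None.
Dataset(uri="not-a-uri").normalized_uri  # None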
4 changes: 4 additions & 0 deletions airflow/provider.yaml.schema.json
@@ -216,6 +216,10 @@
"factory": {
"type": ["string", "null"],
"description": "Dataset factory for specified URI. Creates AIP-60 compliant Dataset."
},
"to_openlineage_converter": {
"type": ["string", "null"],
"description": "OpenLineage converter function for specified URI schemes. Import path to a callable accepting a Dataset and LineageContext and returning OpenLineage dataset."
}
}
}
22 changes: 22 additions & 0 deletions airflow/providers/amazon/aws/datasets/s3.py
@@ -16,8 +16,30 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.datasets import Dataset
from airflow.providers.amazon.aws.hooks.s3 import S3Hook

if TYPE_CHECKING:
    from urllib.parse import SplitResult

    from openlineage.client.run import Dataset as OpenLineageDataset


def create_dataset(*, bucket: str, key: str, extra=None) -> Dataset:
    return Dataset(uri=f"s3://{bucket}/{key}", extra=extra)


def sanitize_uri(uri: SplitResult) -> SplitResult:
    if not uri.netloc:
        raise ValueError("URI format s3:// must contain a bucket name")
    return uri


def convert_dataset_to_openlineage(dataset: Dataset, lineage_context) -> OpenLineageDataset:
    """Translate a Dataset with a valid AIP-60 URI to OpenLineage with assistance from the hook."""
    from openlineage.client.run import Dataset as OpenLineageDataset

    bucket, key = S3Hook.parse_s3_url(dataset.uri)
    return OpenLineageDataset(namespace=f"s3://{bucket}", name=key if key else "/")
3 changes: 2 additions & 1 deletion airflow/providers/amazon/provider.yaml
@@ -561,7 +561,8 @@ sensors:

dataset-uris:
  - schemes: [s3]
-   handler: null
    handler: airflow.providers.amazon.aws.datasets.s3.sanitize_uri
    to_openlineage_converter: airflow.providers.amazon.aws.datasets.s3.convert_dataset_to_openlineage
    factory: airflow.providers.amazon.aws.datasets.s3.create_dataset

filesystems:
26 changes: 26 additions & 0 deletions airflow/providers/common/io/datasets/file.py
@@ -16,9 +16,35 @@
# under the License.
from __future__ import annotations

import urllib.parse
from typing import TYPE_CHECKING

from airflow.datasets import Dataset

if TYPE_CHECKING:
    from urllib.parse import SplitResult

    from openlineage.client.run import Dataset as OpenLineageDataset


def create_dataset(*, path: str, extra=None) -> Dataset:
    # We assume that we get an absolute path starting with /
    return Dataset(uri=f"file://{path}", extra=extra)


def sanitize_uri(uri: SplitResult) -> SplitResult:
    if not uri.path:
        raise ValueError("URI format file:// must contain a non-empty path.")
    return uri


def convert_dataset_to_openlineage(dataset: Dataset, lineage_context) -> OpenLineageDataset:
    """
    Translate a Dataset with a valid AIP-60 URI to OpenLineage with assistance from the context.

    Windows paths are not standardized and can produce unexpected behaviour.
    """
    from openlineage.client.run import Dataset as OpenLineageDataset

    parsed = urllib.parse.urlsplit(dataset.uri)
    return OpenLineageDataset(namespace=f"file://{parsed.netloc}", name=parsed.path)
3 changes: 2 additions & 1 deletion airflow/providers/common/io/provider.yaml
@@ -53,7 +53,8 @@ xcom:

dataset-uris:
  - schemes: [file]
-   handler: null
    handler: airflow.providers.common.io.datasets.file.sanitize_uri
    to_openlineage_converter: airflow.providers.common.io.datasets.file.convert_dataset_to_openlineage
    factory: airflow.providers.common.io.datasets.file.create_dataset

config:
30 changes: 30 additions & 0 deletions airflow/providers/openlineage/utils/utils.py
@@ -55,6 +55,8 @@
from airflow.utils.module_loading import import_string

if TYPE_CHECKING:
    from openlineage.client.run import Dataset as OpenLineageDataset

    from airflow.models import DagRun, TaskInstance


@@ -635,3 +637,31 @@ def should_use_external_connection(hook) -> bool:
    if not _IS_AIRFLOW_2_10_OR_HIGHER:
        return hook.__class__.__name__ not in ["SnowflakeHook", "SnowflakeSqlApiHook", "RedshiftSQLHook"]
    return True


def translate_airflow_dataset(dataset: Dataset, lineage_context) -> OpenLineageDataset | None:
    """
    Convert a Dataset with an AIP-60 compliant URI to an OpenLineageDataset.

    This function returns None if no URI normalizer is defined, no dataset converter is found,
    or the required core Airflow changes are missing and an ImportError is raised.
    """
    try:
        from airflow.datasets import _get_normalized_scheme
        from airflow.providers_manager import ProvidersManager

        ol_converters = ProvidersManager().dataset_to_openlineage_converters
        normalized_uri = dataset.normalized_uri
    except (ImportError, AttributeError):
        return None

    if normalized_uri is None:
        return None

    if not (normalized_scheme := _get_normalized_scheme(normalized_uri)):
        return None

    if (airflow_to_ol_converter := ol_converters.get(normalized_scheme)) is None:
        return None

    return airflow_to_ol_converter(Dataset(uri=normalized_uri, extra=dataset.extra), lineage_context)
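Putting the pieces together, a hedged end-to-end sketch (assumes both the Amazon and OpenLineage providers are installed so that a converter is registered for the s3 scheme; on any failure along the chain the call simply returns None):

from airflow.datasets import Dataset
from airflow.providers.openlineage.utils.utils import translate_airflow_dataset

# Normalizes the URI via Dataset.normalized_uri, resolves the converter registered
# for the "s3" scheme, and delegates to the provider's convert_dataset_to_openlineage.
ol_dataset = translate_airflow_dataset(Dataset(uri="s3://my-bucket/key"), lineage_context=None)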
82 changes: 57 additions & 25 deletions airflow/providers_manager.py
@@ -428,6 +428,7 @@ def __init__(self):
        self._fs_set: set[str] = set()
        self._dataset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {}
        self._dataset_factories: dict[str, Callable[..., Dataset]] = {}
        self._dataset_to_openlineage_converters: dict[str, Callable] = {}
        self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache()  # type: ignore[assignment]
        # keeps mapping between connection_types and hook class, package they come from
        self._hook_provider_dict: dict[str, HookClassProvider] = {}
@@ -525,10 +526,10 @@ def initialize_providers_filesystems(self):
        self._discover_filesystems()

    @provider_info_cache("dataset_uris")
-   def initialize_providers_dataset_uri_handlers_and_factories(self):
-       """Lazy initialization of provider dataset URI handlers."""
    def initialize_providers_dataset_uri_resources(self):
        """Lazy initialization of provider dataset URI handlers, factories, converters etc."""
        self.initialize_providers_list()
-       self._discover_dataset_uri_handlers_and_factories()
        self._discover_dataset_uri_resources()

    @provider_info_cache("hook_lineage_writers")
    @provider_info_cache("taskflow_decorators")
@@ -881,28 +882,52 @@ def _discover_filesystems(self) -> None:
                self._fs_set.add(fs_module_name)
        self._fs_set = set(sorted(self._fs_set))

-   def _discover_dataset_uri_handlers_and_factories(self) -> None:
    def _discover_dataset_uri_resources(self) -> None:
        """Discovers and registers dataset URI handlers, factories, and converters for all providers."""
        from airflow.datasets import normalize_noop

-       for provider_package, provider in self._provider_dict.items():
-           for handler_info in provider.data.get("dataset-uris", []):
-               schemes = handler_info.get("schemes")
-               handler_path = handler_info.get("handler")
-               factory_path = handler_info.get("factory")
-               if schemes is None:
-                   continue
-
-               if handler_path is not None and (
-                   handler := _correctness_check(provider_package, handler_path, provider)
-               ):
-                   pass
-               else:
-                   handler = normalize_noop
-               self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes)
-               if factory_path is not None and (
-                   factory := _correctness_check(provider_package, factory_path, provider)
-               ):
-                   self._dataset_factories.update((scheme, factory) for scheme in schemes)

        def _safe_register_resource(
            provider_package_name: str,
            schemes_list: list[str],
            resource_path: str | None,
            resource_registry: dict,
            default_resource: Any = None,
        ):
            """
            Register a specific resource (handler, factory, or converter) for the given schemes.

            If the resolved resource (either from the path or the default) is valid, it updates
            the resource registry with the appropriate resource for each scheme.
            """
            resource = (
                _correctness_check(provider_package_name, resource_path, provider)
                if resource_path is not None
                else default_resource
            )
            if resource:
                resource_registry.update((scheme, resource) for scheme in schemes_list)

        for provider_name, provider in self._provider_dict.items():
            for uri_info in provider.data.get("dataset-uris", []):
                if "schemes" not in uri_info or "handler" not in uri_info:
                    continue  # Both schemes and handler must be explicitly set; handler can be set to null.
                common_args = {"schemes_list": uri_info["schemes"], "provider_package_name": provider_name}
                _safe_register_resource(
                    resource_path=uri_info["handler"],
                    resource_registry=self._dataset_uri_handlers,
                    default_resource=normalize_noop,
                    **common_args,
                )
                _safe_register_resource(
                    resource_path=uri_info.get("factory"),
                    resource_registry=self._dataset_factories,
                    **common_args,
                )
                _safe_register_resource(
                    resource_path=uri_info.get("to_openlineage_converter"),
                    resource_registry=self._dataset_to_openlineage_converters,
                    **common_args,
                )

    def _discover_taskflow_decorators(self) -> None:
        for name, info in self._provider_dict.items():
@@ -1301,14 +1326,21 @@ def filesystem_module_names(self) -> list[str]:

    @property
    def dataset_factories(self) -> dict[str, Callable[..., Dataset]]:
-       self.initialize_providers_dataset_uri_handlers_and_factories()
        self.initialize_providers_dataset_uri_resources()
        return self._dataset_factories

    @property
    def dataset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]:
-       self.initialize_providers_dataset_uri_handlers_and_factories()
        self.initialize_providers_dataset_uri_resources()
        return self._dataset_uri_handlers

    @property
    def dataset_to_openlineage_converters(
        self,
    ) -> dict[str, Callable]:
        self.initialize_providers_dataset_uri_resources()
        return self._dataset_to_openlineage_converters

    @property
    def provider_configs(self) -> list[tuple[str, dict[str, Any]]]:
        self.initialize_providers_configuration()
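For consumers, the registry is reachable through the new property. A minimal lookup sketch (the "s3" entry is an assumption that only holds when the Amazon provider is installed):

from airflow.providers_manager import ProvidersManager

# First access lazily initializes all dataset URI resources (handlers, factories,
# and OpenLineage converters) discovered from the providers' provider.yaml files.
converters = ProvidersManager().dataset_to_openlineage_converters
s3_converter = converters.get("s3")  # None if no converter is registered for the scheme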
48 changes: 43 additions & 5 deletions tests/datasets/test_dataset.py
@@ -31,6 +31,7 @@
    DatasetAll,
    DatasetAny,
    _DatasetAliasCondition,
    _get_normalized_scheme,
    _sanitize_uri,
)
from airflow.models.dataset import DatasetAliasModel, DatasetDagRunQueue, DatasetModel
@@ -454,31 +455,68 @@ def test_datasets_expression_error(expression: Callable[[], None], error: str) -> None:
    assert str(info.value) == error


-def mock_get_uri_normalizer(normalized_scheme):
def test_get_normalized_scheme():
    assert _get_normalized_scheme("http://example.com") == "http"
    assert _get_normalized_scheme("HTTPS://example.com") == "https"
    assert _get_normalized_scheme("ftp://example.com") == "ftp"
    assert _get_normalized_scheme("file://") == "file"

    assert _get_normalized_scheme("example.com") == ""
    assert _get_normalized_scheme("") == ""
    assert _get_normalized_scheme(" ") == ""


def _mock_get_uri_normalizer_raising_error(normalized_scheme):
    def normalizer(uri):
        raise ValueError("Incorrect URI format")

    return normalizer


@patch("airflow.datasets._get_uri_normalizer", mock_get_uri_normalizer)
def _mock_get_uri_normalizer_noop(normalized_scheme):
def normalizer(uri):
return uri

return normalizer


@patch("airflow.datasets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error)
@patch("airflow.datasets.warnings.warn")
def test__sanitize_uri_raises_warning(mock_warn):
def test_sanitize_uri_raises_warning(mock_warn):
_sanitize_uri("postgres://localhost:5432/database.schema.table")
msg = mock_warn.call_args.args[0]
assert "The dataset URI postgres://localhost:5432/database.schema.table is not AIP-60 compliant" in msg
assert "In Airflow 3, this will raise an exception." in msg


@patch("airflow.datasets._get_uri_normalizer", mock_get_uri_normalizer)
@patch("airflow.datasets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error)
@conf_vars({("core", "strict_dataset_uri_validation"): "True"})
def test__sanitize_uri_raises_exception():
def test_sanitize_uri_raises_exception():
with pytest.raises(ValueError) as e_info:
_sanitize_uri("postgres://localhost:5432/database.schema.table")
assert isinstance(e_info.value, ValueError)
assert str(e_info.value) == "Incorrect URI format"


@patch("airflow.datasets._get_uri_normalizer", lambda x: None)
def test_normalize_uri_no_normalizer_found():
dataset = Dataset(uri="any_uri_without_normalizer_defined")
assert dataset.normalized_uri is None


@patch("airflow.datasets._get_uri_normalizer", _mock_get_uri_normalizer_raising_error)
def test_normalize_uri_invalid_uri():
dataset = Dataset(uri="any_uri_not_aip60_compliant")
assert dataset.normalized_uri is None


@patch("airflow.datasets._get_uri_normalizer", _mock_get_uri_normalizer_noop)
@patch("airflow.datasets._get_normalized_scheme", lambda x: "valid_scheme")
def test_normalize_uri_valid_uri():
dataset = Dataset(uri="valid_aip60_uri")
assert dataset.normalized_uri == "valid_aip60_uri"


@pytest.mark.db_test
@pytest.mark.usefixtures("clear_datasets")
class Test_DatasetAliasCondition:
(diff for the remaining 2 changed files not shown)
