Skip to content

Commit

Permalink
Add Dataset, Model asset subclasses (apache#43142)
Browse files Browse the repository at this point in the history
* feat(assets): add dataset subclass
* feat(assets): add model subclass
* feat(assets): make group a default instead of overwriting user input
* feat(assets): allow importing "airflow.Dataset", "airflow.datasets.Dataset", and "airflow.datasets.DatasetAlias" for backward compatibility
  • Loading branch information
Lee-W authored and harjeevanmaan committed Oct 23, 2024
1 parent 14bdaba commit 1ef0b28
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 4 deletions.
5 changes: 4 additions & 1 deletion airflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@
"DAG",
"Asset",
"XComArg",
# TODO: Remove this module in Airflow 3.2
"Dataset",
]

# Perform side-effects unless someone has explicitly opted out before import
Expand All @@ -83,12 +85,13 @@
"version": (".version", "", False),
# Deprecated lazy imports
"AirflowException": (".exceptions", "AirflowException", True),
"Dataset": (".assets", "Dataset", True),
}
if TYPE_CHECKING:
# These objects are imported by PEP-562, however, static analyzers and IDE's
# have no idea about typing of these objects.
# Add it under TYPE_CHECKING block should help with it.
from airflow.models.asset import Asset
from airflow.assets import Asset, Dataset
from airflow.models.dag import DAG
from airflow.models.xcom_arg import XComArg

Expand Down
19 changes: 16 additions & 3 deletions airflow/assets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

from airflow.configuration import conf

__all__ = ["Asset", "AssetAll", "AssetAny"]
__all__ = ["Asset", "AssetAll", "AssetAny", "Dataset"]


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -275,13 +275,14 @@ def _set_extra_default(extra: dict | None) -> dict:

@attr.define(init=False, unsafe_hash=False)
class Asset(os.PathLike, BaseAsset):
"""A representation of data dependencies between workflows."""
"""A representation of data asset dependencies between workflows."""

name: str
uri: str
group: str
extra: dict[str, Any]

asset_type: ClassVar[str] = ""
__version__: ClassVar[int] = 1

@overload
Expand Down Expand Up @@ -313,7 +314,7 @@ def __init__(
fields = attr.fields_dict(Asset)
self.name = _validate_non_empty_identifier(self, fields["name"], name)
self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
self.group = _validate_identifier(self, fields["group"], group)
self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type
self.extra = _set_extra_default(extra)

def __fspath__(self) -> str:
Expand Down Expand Up @@ -372,6 +373,18 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
)


class Dataset(Asset):
    """A representation of dataset dependencies between workflows.

    Behaves exactly like :class:`Asset`; the only difference is that
    ``asset_type`` supplies ``"dataset"`` as the default ``group`` when the
    caller does not pass one (see ``Asset.__init__``).
    """

    # Consumed by Asset.__init__ as the fallback value for ``group``.
    asset_type: ClassVar[str] = "dataset"


class Model(Asset):
    """A representation of model dependencies between workflows.

    Behaves exactly like :class:`Asset`; the only difference is that
    ``asset_type`` supplies ``"model"`` as the default ``group`` when the
    caller does not pass one (see ``Asset.__init__``).
    """

    # Consumed by Asset.__init__ as the fallback value for ``group``.
    asset_type: ClassVar[str] = "model"


class _AssetBooleanCondition(BaseAsset):
"""Base class for asset boolean logic."""

Expand Down
45 changes: 45 additions & 0 deletions airflow/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Backward-compatibility shim: ``Dataset`` and ``DatasetAlias`` now live in
# ``airflow.assets``; importing them from here re-exports them and emits a
# DeprecationWarning once at first import of this module.
#
# NOTE(review): the original header comment here (copied from the airflow root
# package) claimed this file does not use ``from __future__ import
# annotations``, but the import below does exactly that — the stale comment
# has been removed.
from __future__ import annotations

import warnings

from airflow.assets import AssetAlias as DatasetAlias, Dataset

# TODO: Remove this module in Airflow 3.2

# NOTE(review): the message says "airflow.dataset" (singular) while this
# module is ``airflow.datasets`` — looks like a typo, but the test suite
# asserts this exact string, so both must be updated together if fixed.
warnings.warn(
    "Import from the airflow.dataset module is deprecated and "
    "will be removed in the Airflow 3.2. Please import it from 'airflow.assets'.",
    DeprecationWarning,
    stacklevel=2,
)


__all__ = [
    "Dataset",
    "DatasetAlias",
]
65 changes: 65 additions & 0 deletions tests/assets/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
AssetAll,
AssetAny,
BaseAsset,
Dataset,
Model,
_AssetAliasCondition,
_get_normalized_scheme,
_sanitize_uri,
Expand Down Expand Up @@ -612,3 +614,66 @@ def test_evalute(self, asset_alias_1, resolved_asset_alias_2, asset_1):

cond = _AssetAliasCondition(resolved_asset_alias_2.name)
assert cond.evaluate({asset_1.uri: True}) is True


class TestAssetSubclasses:
    """Shared behavior checks for the ``Asset`` subclasses.

    Each subclass must fall back to its ``asset_type`` as the ``group`` when
    none is supplied, and mirror ``name``/``uri`` onto each other when only
    one of the two is provided.
    """

    @pytest.mark.parametrize("subcls, group", ((Model, "model"), (Dataset, "dataset")))
    def test_only_name(self, subcls, group):
        # uri falls back to name when only name is given
        asset = subcls(name="foobar")
        assert (asset.name, asset.uri, asset.group) == ("foobar", "foobar", group)

    @pytest.mark.parametrize("subcls, group", ((Model, "model"), (Dataset, "dataset")))
    def test_only_uri(self, subcls, group):
        # name falls back to uri when only uri is given
        asset = subcls(uri="s3://bucket/key/path")
        expected = ("s3://bucket/key/path", "s3://bucket/key/path", group)
        assert (asset.name, asset.uri, asset.group) == expected

    @pytest.mark.parametrize("subcls, group", ((Model, "model"), (Dataset, "dataset")))
    def test_both_name_and_uri(self, subcls, group):
        # explicit name and uri are kept as-is
        asset = subcls("foobar", "s3://bucket/key/path")
        expected = ("foobar", "s3://bucket/key/path", group)
        assert (asset.name, asset.uri, asset.group) == expected

    @pytest.mark.parametrize("arg", ["foobar", "s3://bucket/key/path"])
    @pytest.mark.parametrize("subcls, group", ((Model, "model"), (Dataset, "dataset")))
    def test_only_posarg(self, subcls, group, arg):
        # a single positional argument populates both name and uri
        asset = subcls(arg)
        assert (asset.name, asset.uri, asset.group) == (arg, arg, group)


@pytest.mark.parametrize(
    "module_path, attr_name, warning_message",
    [
        (
            "airflow",
            "Dataset",
            (
                "Import 'Dataset' directly from the airflow module is deprecated and will be removed in the future. "
                "Please import it from 'airflow.assets.Dataset'."
            ),
        ),
        (
            "airflow.datasets",
            "Dataset",
            (
                "Import from the airflow.dataset module is deprecated and "
                "will be removed in the Airflow 3.2. Please import it from 'airflow.assets'."
            ),
        ),
    ],
)
def test_backward_compat_import_before_airflow_3_2(module_path, attr_name, warning_message):
    """Deprecated ``Dataset`` import paths still resolve, but emit a DeprecationWarning."""
    import importlib

    with pytest.warns() as record:
        module = importlib.import_module(module_path, __name__)
        getattr(module, attr_name)

    first_warning = record[0]
    assert first_warning.category is DeprecationWarning
    assert str(first_warning.message) == warning_message

0 comments on commit 1ef0b28

Please sign in to comment.