Add DatasetAlias to support dynamic Dataset Event Emission and Dataset Creation (#40478)

* feat(dataset_alias)
    * add DatasetAlias class
    * support yielding dataset aliases through datasets.Metadata
    * allow only one dataset event to be triggered for the same dataset with the same extra in a single task
    * dynamically add datasets through dataset_alias (see the usage sketch below)
* feat(datasets): add optional alias argument to dataset metadata
* feat(dag): add dataset aliases defined in DAGs to the db during dag parsing
* feat(datasets): register dataset changes through dataset aliases in outlet events
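
For context, here is a minimal usage sketch assembled from the APIs this commit adds; the task name, the partition parameter, and the s3:// URI are illustrative assumptions, not part of the change:

from airflow.datasets import Dataset, DatasetAlias
from airflow.datasets.metadata import Metadata
from airflow.decorators import task

# Declare the alias as a task outlet; the concrete dataset it resolves to is
# attached at runtime by yielding Metadata with the alias.
@task(outlets=[DatasetAlias(name="daily-export")])
def produce_export(partition: str):
    # Emits a dataset event for the given URI through the alias; per the
    # taskinstance.py changes below, the Dataset row is created on the fly
    # if it does not exist yet.
    yield Metadata(
        Dataset(f"s3://bucket/exports/{partition}"),
        extra={"partition": partition},
        alias="daily-export",
    )

Only one event is emitted per (dataset URI, extra) pair within a single task, even if several aliases attach to the same dataset.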
Lee-W committed Jul 15, 2024
1 parent 469beef commit 3805050
Showing 23 changed files with 753 additions and 76 deletions.
34 changes: 31 additions & 3 deletions airflow/datasets/__init__.py
@@ -24,6 +24,8 @@

import attr

from airflow.typing_compat import TypedDict

if TYPE_CHECKING:
from urllib.parse import SplitResult

@@ -106,16 +108,20 @@ def _sanitize_uri(uri: str) -> str:
return urllib.parse.urlunsplit(parsed)


def coerce_to_uri(value: str | Dataset) -> str:
def extract_event_key(value: str | Dataset | DatasetAlias) -> str:
"""
Coerce a user input into a sanitized URI.
Extract the key of an inlet or an outlet event.
If the input value is a string, it is treated as a URI and sanitized. If the
input is a :class:`Dataset`, the URI it contains is considered sanitized and
returned directly.
returned directly. If the input is a :class:`DatasetAlias`, the name it contains
will be returned directly.
:meta private:
"""
if isinstance(value, DatasetAlias):
return value.name

if isinstance(value, Dataset):
return value.uri
return _sanitize_uri(str(value))
@@ -159,6 +165,28 @@ def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
raise NotImplementedError


@attr.define()
class DatasetAlias(BaseDataset):
"""A represeation of dataset alias which is used to create dataset during the runtime."""

name: str

def __eq__(self, other: Any) -> bool:
if isinstance(other, DatasetAlias):
return self.name == other.name
return NotImplemented

def __hash__(self) -> int:
return hash(self.name)


class DatasetAliasEvent(TypedDict):
"""A represeation of dataset event to be triggered by a dataset alias."""

source_alias_name: str
dest_dataset_uri: str


@attr.define()
class Dataset(os.PathLike, BaseDataset):
"""A representation of data dependencies between workflows."""
7 changes: 6 additions & 1 deletion airflow/datasets/manager.py
@@ -26,7 +26,12 @@
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.listeners.listener import get_listener_manager
from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.models.dataset import (
DagScheduleDatasetReference,
DatasetDagRunQueue,
DatasetEvent,
DatasetModel,
)
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
13 changes: 10 additions & 3 deletions airflow/datasets/metadata.py
@@ -21,7 +21,7 @@

import attrs

from airflow.datasets import coerce_to_uri
from airflow.datasets import DatasetAlias, extract_event_key

if TYPE_CHECKING:
from airflow.datasets import Dataset
@@ -33,7 +33,14 @@ class Metadata:

uri: str
extra: dict[str, Any]
alias_name: str | None = None

def __init__(self, target: str | Dataset, extra: dict[str, Any]) -> None:
self.uri = coerce_to_uri(target)
def __init__(
self, target: str | Dataset, extra: dict[str, Any], alias: DatasetAlias | str | None = None
) -> None:
self.uri = extract_event_key(target)
self.extra = extra
if isinstance(alias, DatasetAlias):
self.alias_name = alias.name
else:
self.alias_name = alias
59 changes: 59 additions & 0 deletions airflow/migrations/versions/0147_2_10_0_add_dataset_alias.py
@@ -0,0 +1,59 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Add dataset_alias.
Revision ID: 05e19f3176be
Revises: d482b7261ff9
Create Date: 2024-07-05 08:17:12.017789
"""

from __future__ import annotations

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "05e19f3176be"
down_revision = "d482b7261ff9"
branch_labels = None
depends_on = None
airflow_version = "2.10.0"


def upgrade():
"""Add dataset_alias table."""
op.create_table(
"dataset_alias",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column(
"name",
sa.String(length=3000).with_variant(
sa.String(length=3000, collation="latin1_general_cs"), "mysql"
),
nullable=False,
),
sa.PrimaryKeyConstraint("id", name=op.f("dataset_alias_pkey")),
)


def downgrade():
"""Drop dataset_alias table."""
op.drop_table("dataset_alias")
52 changes: 49 additions & 3 deletions airflow/models/dag.py
@@ -67,6 +67,7 @@
Text,
and_,
case,
delete,
func,
not_,
or_,
@@ -82,7 +83,7 @@
from airflow import settings, utils
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf as airflow_conf, secrets_backend_list
from airflow.datasets import BaseDataset, Dataset, DatasetAll
from airflow.datasets import BaseDataset, Dataset, DatasetAlias, DatasetAll
from airflow.datasets.manager import dataset_manager
from airflow.exceptions import (
AirflowDagInconsistent,
@@ -103,7 +104,7 @@
from airflow.models.dagcode import DagCode
from airflow.models.dagpickle import DagPickle
from airflow.models.dagrun import RUN_ID_REGEX, DagRun
from airflow.models.dataset import DatasetDagRunQueue, DatasetModel
from airflow.models.dataset import DatasetAliasModel, DatasetDagRunQueue, DatasetModel
from airflow.models.param import DagParam, ParamsDict
from airflow.models.taskinstance import (
Context,
@@ -3298,6 +3299,7 @@ def bulk_write_to_db(
# We can't use a set here as we want to preserve order
outlet_datasets: dict[DatasetModel, None] = {}
input_datasets: dict[DatasetModel, None] = {}
outlet_dataset_alias_models: list[DatasetAliasModel] = []

# here we go through dags and tasks to check for dataset references
# if there are none now and previously there were some, we delete them
@@ -3314,7 +3316,14 @@
input_datasets[DatasetModel.from_public(dataset)] = None
curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
for task in dag.tasks:
dataset_outlets = [x for x in task.outlets or [] if isinstance(x, Dataset)]
dataset_outlets: list[Dataset] = []
dataset_alias_outlets: list[DatasetAlias] = []
for outlet in task.outlets:
if isinstance(outlet, Dataset):
dataset_outlets.append(outlet)
elif isinstance(outlet, DatasetAlias):
dataset_alias_outlets.append(outlet)

if not dataset_outlets:
if curr_outlet_references:
this_task_outlet_refs = [
@@ -3324,9 +3333,14 @@
]
for ref in this_task_outlet_refs:
curr_outlet_references.remove(ref)

for d in dataset_outlets:
outlet_references[(task.dag_id, task.task_id)].add(d.uri)
outlet_datasets[DatasetModel.from_public(d)] = None

for d_a in dataset_alias_outlets:
outlet_dataset_alias_models.append(DatasetAliasModel.from_public(d_a))

all_datasets = outlet_datasets
all_datasets.update(input_datasets)

@@ -3351,6 +3365,38 @@
del new_datasets
del all_datasets

# store dataset aliases
new_dataset_alias_models: list[DatasetAliasModel] = []
if outlet_dataset_alias_models:
outlet_dataset_alias_names = [dataset_alias.name for dataset_alias in outlet_dataset_alias_models]

stored_dataset_alias_names = session.scalars(
select(DatasetAliasModel.name).where(DatasetAliasModel.name.in_(outlet_dataset_alias_names))
).fetchall()
removed_dataset_alias_names = session.scalars(
select(DatasetAliasModel.name).where(
DatasetAliasModel.name.not_in(outlet_dataset_alias_names)
)
).fetchall()

if stored_dataset_alias_names:
new_dataset_alias_models = [
dataset_alias_model
for dataset_alias_model in outlet_dataset_alias_models
if dataset_alias_model.name not in stored_dataset_alias_names
]
else:
new_dataset_alias_models = outlet_dataset_alias_models
session.add_all(new_dataset_alias_models)

if removed_dataset_alias_names:
session.execute(
delete(DatasetAliasModel).where(DatasetAliasModel.name.in_(removed_dataset_alias_names))
)

del new_dataset_alias_models
del outlet_dataset_alias_models

# reconcile dag-schedule-on-dataset references
for dag_id, uri_list in dag_references.items():
dag_refs_needed = {
30 changes: 29 additions & 1 deletion airflow/models/dataset.py
@@ -34,13 +34,41 @@
)
from sqlalchemy.orm import relationship

from airflow.datasets import Dataset
from airflow.datasets import Dataset, DatasetAlias
from airflow.models.base import Base, StringID
from airflow.settings import json
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime


class DatasetAliasModel(Base):
"""
A table to store dataset aliases.
:param name: a string that uniquely identifies the dataset alias
"""

id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(
String(length=3000).with_variant(
String(
length=3000,
# latin1 allows for more indexed length in mysql
# and this field should only be ascii chars
collation="latin1_general_cs",
),
"mysql",
),
nullable=False,
)

__tablename__ = "dataset_alias"

@classmethod
def from_public(cls, obj: DatasetAlias) -> DatasetAliasModel:
return cls(name=obj.name)


class DatasetModel(Base):
"""
A table to store datasets.
43 changes: 42 additions & 1 deletion airflow/models/taskinstance.py
@@ -71,7 +71,7 @@
from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call
from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.datasets import Dataset, DatasetAlias
from airflow.datasets.manager import dataset_manager
from airflow.exceptions import (
AirflowException,
@@ -91,6 +91,7 @@
from airflow.listeners.listener import get_listener_manager
from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
from airflow.models.dagbag import DagBag
from airflow.models.dataset import DatasetModel
from airflow.models.log import Log
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import process_params
@@ -2911,6 +2912,10 @@ def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Se
if TYPE_CHECKING:
assert self.task

# A task triggers only one dataset event for each dataset with the same extra.
# This mapping from (dataset uri, extra) to a set of alias names is used to find
# datasets that share a uri but have different extras, for which more than one dataset event must be emitted.
dataset_tuple_to_aliases_mapping: dict[tuple[str, frozenset], set[str]] = defaultdict(set)
for obj in self.task.outlets or []:
self.log.debug("outlet obj %s", obj)
# Lineage can have other types of objects besides datasets
@@ -2921,6 +2926,42 @@
extra=events[obj].extra,
session=session,
)
elif isinstance(obj, DatasetAlias):
if dataset_alias_event := events[obj].dataset_alias_event:
dataset_uri = dataset_alias_event["dest_dataset_uri"]
extra = events[obj].extra
frozen_extra = frozenset(extra.items())
dataset_alias_name = dataset_alias_event["source_alias_name"]

dataset_tuple_to_aliases_mapping[(dataset_uri, frozen_extra)].add(dataset_alias_name)

dataset_objs_cache: dict[str, DatasetModel] = {}
for (uri, extra_items), aliases in dataset_tuple_to_aliases_mapping.items():
if uri not in dataset_objs_cache:
dataset_obj = session.scalar(select(DatasetModel).where(DatasetModel.uri == uri).limit(1))
dataset_objs_cache[uri] = dataset_obj
else:
dataset_obj = dataset_objs_cache[uri]

if not dataset_obj:
dataset_obj = DatasetModel(uri=uri)
dataset_manager.create_datasets(dataset_models=[dataset_obj], session=session)
self.log.warning('Created a new Dataset(uri="%s") as it did not exist.', uri)
dataset_objs_cache[uri] = dataset_obj

extra = {k: v for k, v in extra_items}
self.log.info(
'Create dataset event Dataset(uri="%s", extra="%s") through dataset aliases "%s"',
uri,
extra,
", ".join(aliases),
)
dataset_manager.register_dataset_change(
task_instance=self,
dataset=dataset_obj,
extra=extra,
session=session,
)

def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False, *, session: Session):
"""Prepare Task for Execution."""
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
@@ -55,6 +55,7 @@ class DagAttributeTypes(str, Enum):
PARAM = "param"
XCOM_REF = "xcomref"
DATASET = "dataset"
DATASET_ALIAS = "dataset_alias"
DATASET_ANY = "dataset_any"
DATASET_ALL = "dataset_all"
SIMPLE_TASK_INSTANCE = "simple_task_instance"