Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,7 @@ def register_asset_changes_in_db(
session=session,
)

def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, str], set[str]]:
def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, str, str], set[str]]:
d = defaultdict(set)
for event in outlet_events:
try:
Expand All @@ -1391,31 +1391,38 @@ def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, str], set[s
if alias_name not in outlet_alias_names:
continue
asset_key = AssetUniqueKey(**event["dest_asset_key"])
extra_json = json.dumps(event["extra"], sort_keys=True)
d[asset_key, extra_json].add(alias_name)
# fallback for backward compatibility
asset_extra_json = json.dumps(event.get("dest_asset_extra", {}), sort_keys=True)
asset_event_extra_json = json.dumps(event["extra"], sort_keys=True)
d[asset_key, asset_extra_json, asset_event_extra_json].add(alias_name)
return d

outlet_alias_names = {o.name for o in task_outlets if o.type == AssetAlias.__name__ and o.name}
if outlet_alias_names and (event_extras_from_aliases := _asset_event_extras_from_aliases()):
for (asset_key, extra_json), event_aliase_names in event_extras_from_aliases.items():
extra = json.loads(extra_json)
for (
asset_key,
asset_extra_json,
asset_event_extras_json,
), event_aliase_names in event_extras_from_aliases.items():
asset_event_extra = json.loads(asset_event_extras_json)
asset = Asset(name=asset_key.name, uri=asset_key.uri, extra=json.loads(asset_extra_json))
ti.log.debug("register event for asset %s with aliases %s", asset_key, event_aliase_names)
event = asset_manager.register_asset_change(
task_instance=ti,
asset=asset_key,
asset=asset,
source_alias_names=event_aliase_names,
extra=extra,
extra=asset_event_extra,
session=session,
)
if event is None:
ti.log.info("Dynamically creating AssetModel %s", asset_key)
session.add(AssetModel(name=asset_key.name, uri=asset_key.uri))
session.add(AssetModel.from_public(asset))
session.flush() # So event can set up its asset fk.
asset_manager.register_asset_change(
task_instance=ti,
asset=asset_key,
asset=asset,
source_alias_names=event_aliase_names,
extra=extra,
extra=asset_event_extra,
session=session,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,8 @@ def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor:
dest_asset_key=AssetUniqueKey(
name=e["dest_asset_key"]["name"], uri=e["dest_asset_key"]["uri"]
),
# fallback for backward compatibility
dest_asset_extra=e.get("dest_asset_extra", {}),
extra=e["extra"],
)
for e in asset_alias_events
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,8 @@ def __len__(self) -> int:
AssetAliasEvent(
source_alias_name="test_alias",
dest_asset_key=AssetUniqueKey(name="test_name", uri="test://asset-uri"),
extra={},
dest_asset_extra={"extra": "from asset itself"},
extra={"extra": "from event"},
)
],
),
Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,4 +689,5 @@ class AssetAliasEvent(attrs.AttrsInstance):

source_alias_name: str
dest_asset_key: AssetUniqueKey
dest_asset_extra: dict[str, JsonValue]
extra: dict[str, JsonValue]
16 changes: 9 additions & 7 deletions task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,9 +425,9 @@ def __hash__(self):


class _AssetRefResolutionMixin:
_asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {}
_asset_ref_cache: dict[AssetRef, tuple[AssetUniqueKey, dict[str, JsonValue]]] = {}

def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey:
def _resolve_asset_ref(self, ref: AssetRef) -> tuple[AssetUniqueKey, dict[str, JsonValue]]:
with contextlib.suppress(KeyError):
return self._asset_ref_cache[ref]

Expand All @@ -442,8 +442,8 @@ def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey:
raise TypeError(f"Unimplemented asset ref: {type(ref)}")
unique_key = AssetUniqueKey.from_asset(asset)
for ref in refs_to_cache:
self._asset_ref_cache[ref] = unique_key
return unique_key
self._asset_ref_cache[ref] = (unique_key, asset.extra)
return (unique_key, asset.extra)

# TODO: This is temporary to avoid code duplication between here & airflow/models/taskinstance.py
@staticmethod
Expand Down Expand Up @@ -488,14 +488,16 @@ def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None = None
return

if isinstance(asset, AssetRef):
asset_key = self._resolve_asset_ref(asset)
asset_key, asset_extra = self._resolve_asset_ref(asset)
else:
asset_key = AssetUniqueKey.from_asset(asset)
asset_extra = asset.extra

asset_alias_name = self.key.name
event = AssetAliasEvent(
source_alias_name=asset_alias_name,
dest_asset_key=asset_key,
dest_asset_extra=asset_extra,
extra=extra or {},
)
self.asset_alias_events.append(event)
Expand Down Expand Up @@ -556,7 +558,7 @@ def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> OutletEventAccessor
elif isinstance(key, AssetAlias):
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
elif isinstance(key, AssetRef):
hashable_key = self._resolve_asset_ref(key)
hashable_key, _ = self._resolve_asset_ref(key)
else:
raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}")

Expand Down Expand Up @@ -684,7 +686,7 @@ def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> Sequence[AssetEvent
if isinstance(key, Asset):
hashable_key = AssetUniqueKey.from_asset(key)
elif isinstance(key, AssetRef):
hashable_key = self._resolve_asset_ref(key)
hashable_key, _ = self._resolve_asset_ref(key)
elif isinstance(key, AssetAlias):
hashable_key = AssetAliasUniqueKey.from_asset_alias(key)
else:
Expand Down
44 changes: 27 additions & 17 deletions task-sdk/tests/task_sdk/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,13 @@ def test_nested_context(self):

class TestOutletEventAccessor:
@pytest.mark.parametrize(
"add_arg",
"add_args",
[
Asset("name", "uri"),
Asset.ref(name="name"),
Asset.ref(uri="uri"),
(Asset("name", "uri", extra={"extra": "from asset itself"}), {"extra": "from event"}),
(Asset.ref(name="name"), {"extra": "from event"}),
(Asset.ref(uri="uri"), {"extra": "from event"}),
],
ids=["asset", "asset name ref", "asset uri ref"],
)
@pytest.mark.parametrize(
"key, asset_alias_events",
Expand All @@ -363,26 +364,31 @@ class TestOutletEventAccessor:
AssetAliasEvent(
source_alias_name="test_alias",
dest_asset_key=AssetUniqueKey(name="name", uri="uri"),
extra={},
dest_asset_extra={"extra": "from asset itself"},
extra={"extra": "from event"},
)
],
),
),
ids=["inactive asset", "active asset"],
)
def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms):
mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="")
def test_add(self, add_args, key, asset_alias_events, mock_supervisor_comms):
mock_supervisor_comms.send.return_value = AssetResponse(
name="name", uri="uri", group="", extra={"extra": "from asset itself"}
)

outlet_event_accessor = OutletEventAccessor(key=key, extra={})
outlet_event_accessor.add(add_arg)
outlet_event_accessor.add(*add_args)
assert outlet_event_accessor.asset_alias_events == asset_alias_events

@pytest.mark.parametrize(
"add_arg",
"add_args",
[
Asset("name", "uri"),
Asset.ref(name="name"),
Asset.ref(uri="uri"),
(Asset(name="name", uri="uri", extra={"extra": "from asset itself"}), {"extra": "from event"}),
(Asset.ref(name="name"), {"extra": "from event"}),
(Asset.ref(uri="uri"), {"extra": "from event"}),
],
ids=["asset", "asset name ref", "asset uri ref"],
)
@pytest.mark.parametrize(
"key, asset_alias_events",
Expand All @@ -394,17 +400,21 @@ def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms):
AssetAliasEvent(
source_alias_name="test_alias",
dest_asset_key=AssetUniqueKey(name="name", uri="uri"),
extra={},
dest_asset_extra={"extra": "from asset itself"},
extra={"extra": "from event"},
)
],
),
),
ids=["inactive asset", "active asset"],
)
def test_add_with_db(self, add_arg, key, asset_alias_events, mock_supervisor_comms):
mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="")
def test_add_with_db(self, add_args, key, asset_alias_events, mock_supervisor_comms):
mock_supervisor_comms.send.return_value = AssetResponse(
name="name", uri="uri", group="", extra={"extra": "from asset itself"}
)

outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""})
outlet_event_accessor.add(add_arg, extra={})
outlet_event_accessor = OutletEventAccessor(key=key)
outlet_event_accessor.add(*add_args)
assert outlet_event_accessor.asset_alias_events == asset_alias_events


Expand Down