diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 7b29184110302..216bc9b1bcad6 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -1381,7 +1381,7 @@ def register_asset_changes_in_db( session=session, ) - def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, str], set[str]]: + def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, str, str], set[str]]: d = defaultdict(set) for event in outlet_events: try: @@ -1391,31 +1391,38 @@ def _asset_event_extras_from_aliases() -> dict[tuple[AssetUniqueKey, str], set[s if alias_name not in outlet_alias_names: continue asset_key = AssetUniqueKey(**event["dest_asset_key"]) - extra_json = json.dumps(event["extra"], sort_keys=True) - d[asset_key, extra_json].add(alias_name) + # fallback for backward compatibility + asset_extra_json = json.dumps(event.get("dest_asset_extra", {}), sort_keys=True) + asset_event_extra_json = json.dumps(event["extra"], sort_keys=True) + d[asset_key, asset_extra_json, asset_event_extra_json].add(alias_name) return d outlet_alias_names = {o.name for o in task_outlets if o.type == AssetAlias.__name__ and o.name} if outlet_alias_names and (event_extras_from_aliases := _asset_event_extras_from_aliases()): - for (asset_key, extra_json), event_aliase_names in event_extras_from_aliases.items(): - extra = json.loads(extra_json) + for ( + asset_key, + asset_extra_json, + asset_event_extras_json, + ), event_aliase_names in event_extras_from_aliases.items(): + asset_event_extra = json.loads(asset_event_extras_json) + asset = Asset(name=asset_key.name, uri=asset_key.uri, extra=json.loads(asset_extra_json)) ti.log.debug("register event for asset %s with aliases %s", asset_key, event_aliase_names) event = asset_manager.register_asset_change( task_instance=ti, - asset=asset_key, + asset=asset, source_alias_names=event_aliase_names, - extra=extra, + extra=asset_event_extra, session=session, ) if event is None: ti.log.info("Dynamically creating AssetModel %s", asset_key) - session.add(AssetModel(name=asset_key.name, uri=asset_key.uri)) + session.add(AssetModel.from_public(asset)) session.flush() # So event can set up its asset fk. asset_manager.register_asset_change( task_instance=ti, - asset=asset_key, + asset=asset, source_alias_names=event_aliase_names, - extra=extra, + extra=asset_event_extra, session=session, ) diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index 6a13ae1cdf72c..fcc923ba5929d 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -398,6 +398,8 @@ def decode_outlet_event_accessor(var: dict[str, Any]) -> OutletEventAccessor: dest_asset_key=AssetUniqueKey( name=e["dest_asset_key"]["name"], uri=e["dest_asset_key"]["uri"] ), + # fallback for backward compatibility + dest_asset_extra=e.get("dest_asset_extra", {}), extra=e["extra"], ) for e in asset_alias_events diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py index 14ca6a5a2cc9e..bf9c4f34e4460 100644 --- a/airflow-core/tests/unit/serialization/test_serialized_objects.py +++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py @@ -409,7 +409,8 @@ def __len__(self) -> int: AssetAliasEvent( source_alias_name="test_alias", dest_asset_key=AssetUniqueKey(name="test_name", uri="test://asset-uri"), - extra={}, + dest_asset_extra={"extra": "from asset itself"}, + extra={"extra": "from event"}, ) ], ), diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py index 970e0385fd834..4e7b624051c9a 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -689,4 +689,5 @@ class AssetAliasEvent(attrs.AttrsInstance): source_alias_name: str dest_asset_key: AssetUniqueKey + dest_asset_extra: dict[str, JsonValue] extra: dict[str, JsonValue] diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 02b9e90d04a11..379a13a930e55 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -425,9 +425,9 @@ def __hash__(self): class _AssetRefResolutionMixin: - _asset_ref_cache: dict[AssetRef, AssetUniqueKey] = {} + _asset_ref_cache: dict[AssetRef, tuple[AssetUniqueKey, dict[str, JsonValue]]] = {} - def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey: + def _resolve_asset_ref(self, ref: AssetRef) -> tuple[AssetUniqueKey, dict[str, JsonValue]]: with contextlib.suppress(KeyError): return self._asset_ref_cache[ref] @@ -442,8 +442,8 @@ def _resolve_asset_ref(self, ref: AssetRef) -> AssetUniqueKey: raise TypeError(f"Unimplemented asset ref: {type(ref)}") unique_key = AssetUniqueKey.from_asset(asset) for ref in refs_to_cache: - self._asset_ref_cache[ref] = unique_key - return unique_key + self._asset_ref_cache[ref] = (unique_key, asset.extra) + return (unique_key, asset.extra) # TODO: This is temporary to avoid code duplication between here & airflow/models/taskinstance.py @staticmethod @@ -488,14 +488,16 @@ def add(self, asset: Asset | AssetRef, extra: dict[str, JsonValue] | None = None return if isinstance(asset, AssetRef): - asset_key = self._resolve_asset_ref(asset) + asset_key, asset_extra = self._resolve_asset_ref(asset) else: asset_key = AssetUniqueKey.from_asset(asset) + asset_extra = asset.extra asset_alias_name = self.key.name event = AssetAliasEvent( source_alias_name=asset_alias_name, dest_asset_key=asset_key, + dest_asset_extra=asset_extra, extra=extra or {}, ) self.asset_alias_events.append(event) @@ -556,7 +558,7 @@ def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> OutletEventAccessor elif isinstance(key, AssetAlias): hashable_key = AssetAliasUniqueKey.from_asset_alias(key) elif isinstance(key, AssetRef): - hashable_key = self._resolve_asset_ref(key) + hashable_key, _ = self._resolve_asset_ref(key) else: raise TypeError(f"Key should be either an asset or an asset alias, not {type(key)}") @@ -684,7 +686,7 @@ def __getitem__(self, key: Asset | AssetAlias | AssetRef) -> Sequence[AssetEvent if isinstance(key, Asset): hashable_key = AssetUniqueKey.from_asset(key) elif isinstance(key, AssetRef): - hashable_key = self._resolve_asset_ref(key) + hashable_key, _ = self._resolve_asset_ref(key) elif isinstance(key, AssetAlias): hashable_key = AssetAliasUniqueKey.from_asset_alias(key) else: diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index df048d1bf4844..566c4ee815177 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -346,12 +346,13 @@ def test_nested_context(self): class TestOutletEventAccessor: @pytest.mark.parametrize( - "add_arg", + "add_args", [ - Asset("name", "uri"), - Asset.ref(name="name"), - Asset.ref(uri="uri"), + (Asset("name", "uri", extra={"extra": "from asset itself"}), {"extra": "from event"}), + (Asset.ref(name="name"), {"extra": "from event"}), + (Asset.ref(uri="uri"), {"extra": "from event"}), ], + ids=["asset", "asset name ref", "asset uri ref"], ) @pytest.mark.parametrize( "key, asset_alias_events", @@ -363,26 +364,31 @@ class TestOutletEventAccessor: AssetAliasEvent( source_alias_name="test_alias", dest_asset_key=AssetUniqueKey(name="name", uri="uri"), - extra={}, + dest_asset_extra={"extra": "from asset itself"}, + extra={"extra": "from event"}, ) ], ), ), + ids=["inactive asset", "active asset"], ) - def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms): - mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="") + def test_add(self, add_args, key, asset_alias_events, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = AssetResponse( + name="name", uri="uri", group="", extra={"extra": "from asset itself"} + ) outlet_event_accessor = OutletEventAccessor(key=key, extra={}) - outlet_event_accessor.add(add_arg) + outlet_event_accessor.add(*add_args) assert outlet_event_accessor.asset_alias_events == asset_alias_events @pytest.mark.parametrize( - "add_arg", + "add_args", [ - Asset("name", "uri"), - Asset.ref(name="name"), - Asset.ref(uri="uri"), + (Asset(name="name", uri="uri", extra={"extra": "from asset itself"}), {"extra": "from event"}), + (Asset.ref(name="name"), {"extra": "from event"}), + (Asset.ref(uri="uri"), {"extra": "from event"}), ], + ids=["asset", "asset name ref", "asset uri ref"], ) @pytest.mark.parametrize( "key, asset_alias_events", @@ -394,17 +400,21 @@ def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms): AssetAliasEvent( source_alias_name="test_alias", dest_asset_key=AssetUniqueKey(name="name", uri="uri"), - extra={}, + dest_asset_extra={"extra": "from asset itself"}, + extra={"extra": "from event"}, ) ], ), ), + ids=["inactive asset", "active asset"], ) - def test_add_with_db(self, add_arg, key, asset_alias_events, mock_supervisor_comms): - mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="") + def test_add_with_db(self, add_args, key, asset_alias_events, mock_supervisor_comms): + mock_supervisor_comms.send.return_value = AssetResponse( + name="name", uri="uri", group="", extra={"extra": "from asset itself"} + ) - outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""}) - outlet_event_accessor.add(add_arg, extra={}) + outlet_event_accessor = OutletEventAccessor(key=key) + outlet_event_accessor.add(*add_args) assert outlet_event_accessor.asset_alias_events == asset_alias_events