diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 8d2bdf99795e0..512f8a4f0e6aa 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -875,8 +875,16 @@ async def create_triggers(self): await asyncio.sleep(0) try: - kwargs = Trigger._decrypt_kwargs(workload.encrypted_kwargs) - trigger_instance = trigger_class(**kwargs) + from airflow.serialization.serialized_objects import smart_decode_trigger_kwargs + + # Decrypt and clean trigger kwargs before for execution + # Note: We only clean up serialization artifacts (__var, __type keys) here, + # not in `_decrypt_kwargs` because it is used during hash comparison in + # add_asset_trigger_references and could lead to adverse effects like hash mismatches + # that could cause None values in collections. + kw = Trigger._decrypt_kwargs(workload.encrypted_kwargs) + deserialised_kwargs = {k: smart_decode_trigger_kwargs(v) for k, v in kw.items()} + trigger_instance = trigger_class(**deserialised_kwargs) except TypeError as err: self.log.error("Trigger failed to inflate", error=err) self.failed_triggers.append((trigger_id, err)) diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index f3e49c4fabd34..e3cf0f019d5d1 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -337,19 +337,20 @@ def decode_asset_condition(var: dict[str, Any]) -> BaseAsset: raise ValueError(f"deserialization not implemented for DAT {dat!r}") -def decode_asset(var: dict[str, Any]): - def _smart_decode_trigger_kwargs(d): - """ - Slightly clean up kwargs for display. +def smart_decode_trigger_kwargs(d): + """ + Slightly clean up kwargs for display or execution. - This detects one level of BaseSerialization and tries to deserialize the - content, removing some __type __var ugliness when the value is displayed - in UI to the user. - """ - if not isinstance(d, dict) or Encoding.TYPE not in d: - return d - return BaseSerialization.deserialize(d) + This detects one level of BaseSerialization and tries to deserialize the + content, removing some __type __var ugliness when the value is displayed + in UI to the user and/or while execution. + """ + if not isinstance(d, dict) or Encoding.TYPE not in d: + return d + return BaseSerialization.deserialize(d) + +def decode_asset(var: dict[str, Any]): watchers = var.get("watchers", []) return Asset( name=var["name"], @@ -361,7 +362,7 @@ def _smart_decode_trigger_kwargs(d): name=watcher["name"], trigger={ "classpath": watcher["trigger"]["classpath"], - "kwargs": _smart_decode_trigger_kwargs(watcher["trigger"]["kwargs"]), + "kwargs": smart_decode_trigger_kwargs(watcher["trigger"]["kwargs"]), }, ) for watcher in watchers diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 08e08f08c4fcb..dd1f2b174c838 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -326,6 +326,51 @@ async def test_invalid_trigger(self, supervisor_builder): assert trigger_id == 1 assert traceback[-1] == "ModuleNotFoundError: No module named 'fake'\n" + @pytest.mark.asyncio + async def test_trigger_kwargs_serialization_cleanup(self, session): + """ + Test that trigger kwargs are properly cleaned of serialization artifacts + (__var, __type keys). + """ + from airflow.serialization.serialized_objects import BaseSerialization + + kw = {"simple": "test", "tuple": (), "dict": {}, "list": []} + + serialized_kwargs = BaseSerialization.serialize(kw) + + trigger_orm = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs=serialized_kwargs) + session.add(trigger_orm) + session.commit() + + stored_kwargs = trigger_orm.kwargs + assert stored_kwargs == { + "Encoding.TYPE": "dict", + "Encoding.VAR": { + "dict": {"Encoding.TYPE": "dict", "Encoding.VAR": {}}, + "list": [], + "simple": "test", + "tuple": {"Encoding.TYPE": "tuple", "Encoding.VAR": []}, + }, + } + + runner = TriggerRunner() + runner.to_create.append( + workloads.RunTrigger.model_construct( + id=trigger_orm.id, + ti=None, + classpath=trigger_orm.classpath, + encrypted_kwargs=trigger_orm.encrypted_kwargs, + ) + ) + + await runner.create_triggers() + assert trigger_orm.id in runner.triggers + trigger_instance = runner.triggers[trigger_orm.id]["task"] + + # The test passes if no exceptions were raised during trigger creation + trigger_instance.cancel() + await runner.cleanup_finished_triggers() + @pytest.mark.asyncio async def test_trigger_create_race_condition_38599(session, supervisor_builder):