diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 4e26fce37e779..b589767731bcb 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -31,6 +31,7 @@ from inspect import signature from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, cast, overload from urllib.parse import urlsplit +from uuid import UUID import attrs import jinja2 @@ -1389,13 +1390,13 @@ def _run_task( # it is run. ti.set_state(TaskInstanceState.QUEUED) task_sdk_ti = TaskInstanceSDK( - id=ti.id, + id=UUID(str(ti.id)), task_id=ti.task_id, dag_id=ti.dag_id, run_id=ti.run_id, try_number=ti.try_number, map_index=ti.map_index, - dag_version_id=ti.dag_version_id, + dag_version_id=UUID(str(ti.dag_version_id)), ) taskrun_result = run_task_in_process(ti=task_sdk_ti, task=task) @@ -1414,7 +1415,15 @@ def _run_task( trigger = import_string(msg.classpath)(**msg.trigger_kwargs) event = _run_inline_trigger(trigger, task_sdk_ti) ti.next_method = msg.next_method - ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs + + # Deserialize next_kwargs if it's a string (encrypted dict), similar to what the API server does + next_kwargs_value = {"event": event.payload} if event else msg.next_kwargs + if isinstance(next_kwargs_value, str): + from airflow.serialization.serialized_objects import BaseSerialization + + ti.next_kwargs = BaseSerialization.deserialize(next_kwargs_value) + else: + ti.next_kwargs = next_kwargs_value log.info("[DAG TEST] Trigger completed") # Set the state to SCHEDULED so that the task can be resumed.