Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion devel-common/src/tests_common/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1950,7 +1950,7 @@ def set_dag(what: StartupDetails, dag_id: str, task: TaskSDKBaseOperator) -> Run
)
if hasattr(parse, "spy"):
spy_agency.unspy(parse)
spy_agency.spy_on(parse, call_fake=lambda _: ti)
spy_agency.spy_on(parse, call_fake=lambda _, log: ti)
return ti

return set_dag
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTas
)
if hasattr(parse, "spy"):
spy_agency.unspy(parse)
spy_agency.spy_on(parse, call_fake=lambda _: ti)
spy_agency.spy_on(parse, call_fake=lambda _, log: ti)
return ti

return set_dag
Expand Down
26 changes: 21 additions & 5 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None:
)


def parse(what: StartupDetails) -> RuntimeTaskInstance:
def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
# TODO: Task-SDK:
# Using DagBag here is about 98% wrong, but it'll do for now

Expand All @@ -558,12 +558,28 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance:
if TYPE_CHECKING:
assert what.ti.dag_id

dag = bag.dags[what.ti.dag_id]
try:
dag = bag.dags[what.ti.dag_id]
except KeyError:
log.error(
"DAG not found during start up", dag_id=what.ti.dag_id, bundle=bundle_info, path=what.dag_rel_path
)
exit(1)

# install_loader()

# TODO: Handle task not found
task = dag.task_dict[what.ti.task_id]
try:
task = dag.task_dict[what.ti.task_id]
except KeyError:
log.error(
"Task not found in DAG during start up",
dag_id=dag.dag_id,
task_id=what.ti.task_id,
bundle=bundle_info,
path=what.dag_rel_path,
)
exit(1)

if not isinstance(task, (BaseOperator, MappedOperator)):
raise TypeError(
f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}"
Expand Down Expand Up @@ -674,7 +690,7 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
setproctitle(f"airflow worker -- {msg.ti.id}")

with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
ti = parse(msg)
ti = parse(msg, log)
log.debug("DAG file parsed", file=msg.dag_rel_path)
else:
raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")
Expand Down
61 changes: 60 additions & 1 deletion task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,73 @@ def test_parse(test_dags_dir: Path, make_ti_context):
),
},
):
ti = parse(what)
ti = parse(what, mock.Mock())

assert ti.task
assert ti.task.dag
assert isinstance(ti.task, BaseOperator)
assert isinstance(ti.task.dag, DAG)


@pytest.mark.parametrize(
("dag_id", "task_id", "expected_error"),
(
pytest.param(
"madeup_dag_id",
"a",
mock.call(mock.ANY, dag_id="madeup_dag_id", path="super_basic.py"),
id="dag-not-found",
),
pytest.param(
"super_basic",
"no-such-task",
mock.call(mock.ANY, task_id="no-such-task", dag_id="super_basic", path="super_basic.py"),
id="task-not-found",
),
),
)
def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id, expected_error):
"""Check for nice error messages on dag not found."""
what = StartupDetails(
ti=TaskInstance(
id=uuid7(),
task_id=task_id,
dag_id=dag_id,
run_id="c",
try_number=1,
),
dag_rel_path="super_basic.py",
bundle_info=BundleInfo(name="my-bundle", version=None),
requests_fd=0,
ti_context=make_ti_context(),
start_date=timezone.utcnow(),
)

log = mock.Mock()

with (
patch.dict(
os.environ,
{
"AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST": json.dumps(
[
{
"name": "my-bundle",
"classpath": "airflow.dag_processing.bundles.local.LocalDagBundle",
"kwargs": {"path": str(test_dags_dir), "refresh_interval": 1},
}
]
),
},
),
pytest.raises(SystemExit),
):
parse(what, log)

expected_error.kwargs["bundle"] = what.bundle_info
log.error.assert_has_calls([expected_error])


def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_comms):
"""Test that a task can transition to a deferred state."""
from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync
Expand Down
Loading