diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 6401dedc29340..e2fd72495c10c 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -1950,7 +1950,7 @@ def set_dag(what: StartupDetails, dag_id: str, task: TaskSDKBaseOperator) -> Run ) if hasattr(parse, "spy"): spy_agency.unspy(parse) - spy_agency.spy_on(parse, call_fake=lambda _: ti) + spy_agency.spy_on(parse, call_fake=lambda _, log: ti) return ti return set_dag diff --git a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py index ab3cdd56008cd..d96e82dc1ad9b 100644 --- a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py +++ b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py @@ -423,7 +423,7 @@ def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTas ) if hasattr(parse, "spy"): spy_agency.unspy(parse) - spy_agency.spy_on(parse, call_fake=lambda _: ti) + spy_agency.spy_on(parse, call_fake=lambda _, log: ti) return ti return set_dag diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 20bd5d271eb9d..a34a4c78ffa1a 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -535,7 +535,7 @@ def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None: ) -def parse(what: StartupDetails) -> RuntimeTaskInstance: +def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: # TODO: Task-SDK: # Using DagBag here is about 98% wrong, but it'll do for now @@ -558,12 +558,28 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance: if TYPE_CHECKING: assert what.ti.dag_id - dag = bag.dags[what.ti.dag_id] + try: + dag = bag.dags[what.ti.dag_id] + except KeyError: + log.error( + "DAG not found during start up", dag_id=what.ti.dag_id, bundle=bundle_info, path=what.dag_rel_path + ) + exit(1) # install_loader() - # TODO: Handle task not found - task = dag.task_dict[what.ti.task_id] + try: + task = dag.task_dict[what.ti.task_id] + except KeyError: + log.error( + "Task not found in DAG during start up", + dag_id=dag.dag_id, + task_id=what.ti.task_id, + bundle=bundle_info, + path=what.dag_rel_path, + ) + exit(1) + if not isinstance(task, (BaseOperator, MappedOperator)): raise TypeError( f"task is of the wrong type, got {type(task)}, wanted {BaseOperator} or {MappedOperator}" @@ -674,7 +690,7 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: setproctitle(f"airflow worker -- {msg.ti.id}") with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id): - ti = parse(msg) + ti = parse(msg, log) log.debug("DAG file parsed", file=msg.dag_rel_path) else: raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 1b795180f94bf..29a1738c0a990 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -209,7 +209,7 @@ def test_parse(test_dags_dir: Path, make_ti_context): ), }, ): - ti = parse(what) + ti = parse(what, mock.Mock()) assert ti.task assert ti.task.dag @@ -217,6 +217,65 @@ def test_parse(test_dags_dir: Path, make_ti_context): assert isinstance(ti.task.dag, DAG) +@pytest.mark.parametrize( + ("dag_id", "task_id", "expected_error"), + ( + pytest.param( + "madeup_dag_id", + "a", + mock.call(mock.ANY, dag_id="madeup_dag_id", path="super_basic.py"), + id="dag-not-found", + ), + pytest.param( + "super_basic", + "no-such-task", + mock.call(mock.ANY, task_id="no-such-task", dag_id="super_basic", path="super_basic.py"), + id="task-not-found", + ), + ), +) +def test_parse_not_found(test_dags_dir: Path, make_ti_context, dag_id, task_id, expected_error): + """Check for nice error messages on dag not found.""" + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id=task_id, + dag_id=dag_id, + run_id="c", + try_number=1, + ), + dag_rel_path="super_basic.py", + bundle_info=BundleInfo(name="my-bundle", version=None), + requests_fd=0, + ti_context=make_ti_context(), + start_date=timezone.utcnow(), + ) + + log = mock.Mock() + + with ( + patch.dict( + os.environ, + { + "AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST": json.dumps( + [ + { + "name": "my-bundle", + "classpath": "airflow.dag_processing.bundles.local.LocalDagBundle", + "kwargs": {"path": str(test_dags_dir), "refresh_interval": 1}, + } + ] + ), + }, + ), + pytest.raises(SystemExit), + ): + parse(what, log) + + expected_error.kwargs["bundle"] = what.bundle_info + log.error.assert_has_calls([expected_error]) + + def test_run_deferred_basic(time_machine, create_runtime_ti, mock_supervisor_comms): """Test that a task can transition to a deferred state.""" from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync