diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index efc74506b62d4..a7787f1875a87 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -43,6 +43,7 @@
     TextIO,
     cast,
 )
+from urllib.parse import urlparse
 from uuid import UUID
 
 import attrs
@@ -1639,12 +1640,44 @@ def supervise(
     :param subprocess_logs_to_stdout: Should task logs also be sent to stdout via the main logger.
     :param client: Optional preconfigured client for communication with the server (Mostly for tests).
     :return: Exit code of the process.
+    :raises ValueError: If server URL is empty or invalid.
     """
     # One or the other
     from airflow.sdk.execution_time.secrets_masker import reset_secrets_masker
 
-    if not client and ((not server) ^ dry_run):
-        raise ValueError(f"Can only specify one of {server=} or {dry_run=}")
+    if not client:
+        # A dry run never talks to a server, so passing both is a caller mistake.
+        if dry_run and server:
+            raise ValueError(f"Can only specify one of {server=} or {dry_run=}")
+
+        if not dry_run:
+            if not server:
+                raise ValueError(
+                    "Invalid execution API server URL. Please ensure that a valid URL is configured."
+                )
+
+            # urlparse only raises ValueError (e.g. unbalanced IPv6 brackets); let anything else propagate.
+            try:
+                parsed_url = urlparse(server)
+            except ValueError as e:
+                raise ValueError(
+                    f"Invalid execution API server URL '{server}': {e}. "
+                    "Please ensure that a valid URL is configured."
+                ) from e
+
+            if parsed_url.scheme not in ("http", "https"):
+                raise ValueError(
+                    f"Invalid execution API server URL '{server}': "
+                    "URL must use http:// or https:// scheme. "
+                    "Please ensure that a valid URL is configured."
+                )
+
+            if not parsed_url.netloc:
+                raise ValueError(
+                    f"Invalid execution API server URL '{server}': "
+                    "URL must include a valid host. "
+                    "Please ensure that a valid URL is configured."
+                )
 
     if not dag_rel_path:
         raise ValueError("dag_path is required")
diff --git a/task-sdk/tests/conftest.py b/task-sdk/tests/conftest.py
index 80c71f41a8fe8..d1866f523919c 100644
--- a/task-sdk/tests/conftest.py
+++ b/task-sdk/tests/conftest.py
@@ -20,6 +20,7 @@
 import os
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, NoReturn, Protocol
+from unittest.mock import patch
 
 import pytest
 
@@ -271,3 +272,13 @@ def _make_context_dict(
         return context.model_dump(exclude_unset=True, mode="json")
 
     return _make_context_dict
+
+
+@pytest.fixture
+def patched_secrets_masker():
+    """Replace ``_secrets_masker`` with a mock that returns a fresh ``SecretsMasker``; yield the masker."""
+    from airflow.sdk.execution_time.secrets_masker import SecretsMasker
+
+    secrets_masker = SecretsMasker()
+    with patch("airflow.sdk.execution_time.secrets_masker._secrets_masker", return_value=secrets_masker):
+        yield secrets_masker
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 697581326706e..27b655a323db2 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -27,6 +27,7 @@
 import socket
 import sys
 import time
+from contextlib import nullcontext
 from operator import attrgetter
 from random import randint
 from time import sleep
@@ -147,6 +148,58 @@ def client_with_ti_start(make_ti_context):
     return client
 
 
+@pytest.mark.usefixtures("disable_capturing")
+class TestSupervisor:
+    @pytest.mark.parametrize(
+        "server, dry_run, expectation",
+        [
+            ("/execution/", False, pytest.raises(ValueError, match="Invalid execution API server URL")),
+            ("", False, pytest.raises(ValueError, match="Invalid execution API server URL")),
+            ("http://localhost:8080", True, pytest.raises(ValueError, match="Can only specify one of")),
+            (None, True, nullcontext()),
+            ("http://localhost:8080/execution/", False, nullcontext()),
+            ("https://localhost:8080/execution/", False, nullcontext()),
+        ],
+    )
+    def test_supervise(
+        self,
+        patched_secrets_masker,
+        server,
+        dry_run,
+        expectation,
+        test_dags_dir,
+        client_with_ti_start,
+    ):
+        """
+        Test that the supervisor validates server URL and dry_run parameter combinations correctly.
+        """
+        ti = TaskInstance(
+            id=uuid7(),
+            task_id="async",
+            dag_id="super_basic_deferred_run",
+            run_id="d",
+            try_number=1,
+            dag_version_id=uuid7(),
+        )
+
+        bundle_info = BundleInfo(name="my-bundle", version=None)
+
+        kw = {
+            "ti": ti,
+            "dag_rel_path": "super_basic_deferred_run.py",
+            "token": "",
+            "bundle_info": bundle_info,
+            "dry_run": dry_run,
+            "server": server,
+        }
+        if isinstance(expectation, nullcontext):
+            kw["client"] = client_with_ti_start
+
+        with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)):
+            with expectation:
+                supervise(**kw)
+
+
 @pytest.mark.usefixtures("disable_capturing")
 class TestWatchedSubprocess:
     @pytest.fixture(autouse=True)