diff --git a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py index d453862598743..19573d7a96a12 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py @@ -64,7 +64,7 @@ @providers_configuration_loaded def force_use_internal_api_on_edge_worker(): """ - Ensure that the environment is configured for the internal API without needing to declare it outside. + Ensure the environment is configured for the internal API without explicit declaration. This is only required for an Edge worker and must to be done before the Click CLI wrapper is initiated. That is because the CLI wrapper will attempt to establish a DB connection, which will fail before the diff --git a/providers/edge3/src/airflow/providers/edge3/cli/worker.py b/providers/edge3/src/airflow/providers/edge3/cli/worker.py index b8674b5e53d3a..d33462c3b526d 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/worker.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/worker.py @@ -26,6 +26,7 @@ from subprocess import Popen from time import sleep from typing import TYPE_CHECKING +from urllib.parse import urlparse from lockfile.pidlockfile import remove_existing_pidfile from requests import HTTPError @@ -186,11 +187,13 @@ def _run_job_via_supervisor( setproctitle(f"airflow edge worker: {workload.ti.key}") try: - base_url = conf.get("api", "base_url", fallback="/") - # If it's a relative URL, use localhost:8080 as the default - if base_url.startswith("/"): - base_url = f"http://localhost:8080{base_url}" - default_execution_api_server = f"{base_url.rstrip('/')}/execution/" + api_url = conf.get("edge", "api_url") + execution_api_server_url = conf.get("core", "execution_api_server_url", fallback=...) + if execution_api_server_url is ...: + parsed = urlparse(api_url) + execution_api_server_url = f"{parsed.scheme}://{parsed.netloc}/execution/" + + logger.info("Worker starting up server=execution_api_server_url=%s", execution_api_server_url) supervise( # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. @@ -199,9 +202,7 @@ def _run_job_via_supervisor( dag_rel_path=workload.dag_rel_path, bundle_info=workload.bundle_info, token=workload.token, - server=conf.get( - "core", "execution_api_server_url", fallback=default_execution_api_server - ), + server=execution_api_server_url, log_path=workload.log_path, ) return 0 diff --git a/providers/edge3/tests/unit/edge3/cli/test_worker.py b/providers/edge3/tests/unit/edge3/cli/test_worker.py index 3561979a4b776..5133d92ed37b9 100644 --- a/providers/edge3/tests/unit/edge3/cli/test_worker.py +++ b/providers/edge3/tests/unit/edge3/cli/test_worker.py @@ -161,6 +161,52 @@ def test_launch_job(self, mock_popen, mock_logfile_path, mock_process, worker_wi assert len(EdgeWorker.jobs) == 1 assert EdgeWorker.jobs[0].edge_job == edge_job + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3+") + @pytest.mark.parametrize( + "configs, expected_url", + [ + ( + {("edge", "api_url"): "https://api-endpoint"}, + "https://api-endpoint/execution/", + ), + ( + {("edge", "api_url"): "https://api:1234/endpoint"}, + "https://api:1234/execution/", + ), + ( + { + ("edge", "api_url"): "https://api-endpoint", + ("core", "execution_api_server_url"): "https://other-endpoint", + }, + "https://other-endpoint", + ), + ], + ) + @patch("airflow.sdk.execution_time.supervisor.supervise") + @patch("airflow.providers.edge3.cli.worker.Process") + @patch("airflow.providers.edge3.cli.worker.Popen") + def test_use_execution_api_server_url( + self, + mock_popen, + mock_process, + mock_supervise, + configs, + expected_url, + worker_with_job: EdgeWorker, + ): + mock_popen.side_effect = [MagicMock()] + mock_process_instance = MagicMock() + mock_process.side_effect = [mock_process_instance] + + edge_job = EdgeWorker.jobs.pop().edge_job + with conf_vars(configs): + worker_with_job._launch_job(edge_job) + + mock_process_callback = mock_process.call_args.kwargs["target"] + mock_process_callback(workload=MagicMock()) + + assert mock_supervise.call_args.kwargs["server"] == expected_url + @pytest.mark.parametrize( "reserve_result, fetch_result, expected_calls", [