diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py
index 56bf1da68fd11..7f67d06ca3b6b 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks.py
@@ -37,33 +37,36 @@
 from airflow.exceptions import AirflowException
 from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook
 
-GET_CLUSTER_ENDPOINT = ("GET", "api/2.0/clusters/get")
-RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart")
-START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
-TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")
-
-CREATE_ENDPOINT = ("POST", "api/2.1/jobs/create")
-RESET_ENDPOINT = ("POST", "api/2.1/jobs/reset")
-UPDATE_ENDPOINT = ("POST", "api/2.1/jobs/update")
-RUN_NOW_ENDPOINT = ("POST", "api/2.1/jobs/run-now")
-SUBMIT_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/submit")
-GET_RUN_ENDPOINT = ("GET", "api/2.1/jobs/runs/get")
-CANCEL_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/cancel")
-DELETE_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/delete")
-REPAIR_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/repair")
-OUTPUT_RUNS_JOB_ENDPOINT = ("GET", "api/2.1/jobs/runs/get-output")
-CANCEL_ALL_RUNS_ENDPOINT = ("POST", "api/2.1/jobs/runs/cancel-all")
-
-INSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/install")
-UNINSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/uninstall")
-
-LIST_JOBS_ENDPOINT = ("GET", "api/2.1/jobs/list")
-LIST_PIPELINES_ENDPOINT = ("GET", "api/2.0/pipelines")
-
-WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")
-
-SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")
-SQL_STATEMENTS_ENDPOINT = "api/2.0/sql/statements"
+GET_CLUSTER_ENDPOINT = ("GET", "2.0/clusters/get")
+RESTART_CLUSTER_ENDPOINT = ("POST", "2.0/clusters/restart")
+START_CLUSTER_ENDPOINT = ("POST", "2.0/clusters/start")
+TERMINATE_CLUSTER_ENDPOINT = ("POST", "2.0/clusters/delete")
+
+CREATE_ENDPOINT = ("POST", "2.1/jobs/create")
+RESET_ENDPOINT = ("POST", "2.1/jobs/reset")
+UPDATE_ENDPOINT = ("POST", "2.1/jobs/update")
+RUN_NOW_ENDPOINT = ("POST", "2.1/jobs/run-now")
+SUBMIT_RUN_ENDPOINT = ("POST", "2.1/jobs/runs/submit")
+GET_RUN_ENDPOINT = ("GET", "2.1/jobs/runs/get")
+CANCEL_RUN_ENDPOINT = ("POST", "2.1/jobs/runs/cancel")
+DELETE_RUN_ENDPOINT = ("POST", "2.1/jobs/runs/delete")
+REPAIR_RUN_ENDPOINT = ("POST", "2.1/jobs/runs/repair")
+OUTPUT_RUNS_JOB_ENDPOINT = ("GET", "2.1/jobs/runs/get-output")
+CANCEL_ALL_RUNS_ENDPOINT = ("POST", "2.1/jobs/runs/cancel-all")
+
+INSTALL_LIBS_ENDPOINT = ("POST", "2.0/libraries/install")
+UNINSTALL_LIBS_ENDPOINT = ("POST", "2.0/libraries/uninstall")
+UPDATE_REPO_ENDPOINT = ("PATCH", "2.0/repos")
+DELETE_REPO_ENDPOINT = ("DELETE", "2.0/repos")
+CREATE_REPO_ENDPOINT = ("POST", "2.0/repos")
+
+LIST_JOBS_ENDPOINT = ("GET", "2.1/jobs/list")
+LIST_PIPELINES_ENDPOINT = ("GET", "2.0/pipelines")
+
+WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "2.0/workspace/get-status")
+
+SPARK_VERSIONS_ENDPOINT = ("GET", "2.0/clusters/spark-versions")
+SQL_STATEMENTS_ENDPOINT = "2.0/sql/statements"
 
 
 class RunLifeCycleState(Enum):
@@ -718,7 +721,8 @@ def update_repo(self, repo_id: str, json: dict[str, Any]) -> dict:
         :param json: payload
         :return: metadata from update
         """
-        repos_endpoint = ("PATCH", f"api/2.0/repos/{repo_id}")
+        method, base_path = UPDATE_REPO_ENDPOINT
+        repos_endpoint = (method, f"{base_path}/{repo_id}")
         return self._do_api_call(repos_endpoint, json)
 
     def delete_repo(self, repo_id: str):
@@ -728,7 +732,8 @@ def delete_repo(self, repo_id: str):
         :param repo_id: ID of Databricks Repos
         :return:
         """
-        repos_endpoint = ("DELETE", f"api/2.0/repos/{repo_id}")
+        method, base_path = DELETE_REPO_ENDPOINT
+        repos_endpoint = (method, f"{base_path}/{repo_id}")
         self._do_api_call(repos_endpoint)
 
     def create_repo(self, json: dict[str, Any]) -> dict:
@@ -738,8 +743,7 @@ def create_repo(self, json: dict[str, Any]) -> dict:
         :param json: payload
         :return:
         """
-        repos_endpoint = ("POST", "api/2.0/repos")
-        return self._do_api_call(repos_endpoint, json)
+        return self._do_api_call(CREATE_REPO_ENDPOINT, json)
 
     def get_repo_by_path(self, path: str) -> str | None:
         """
diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
index cba505f9988a8..236ec3a80724a 100644
--- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
+++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py
@@ -637,8 +637,9 @@ def _do_api_call(
         """
         method, endpoint = endpoint_info
 
-        # TODO: get rid of explicit 'api/' in the endpoint specification
-        url = self._endpoint_url(endpoint)
+        # Automatically prepend the 'api/' prefix to every endpoint path
+        full_endpoint = f"api/{endpoint}"
+        url = self._endpoint_url(full_endpoint)
 
         aad_headers = self._get_aad_headers()
         headers = {**self.user_agent_header, **aad_headers}
@@ -704,7 +705,8 @@ async def _a_do_api_call(self, endpoint_info: tuple[str, str], json: dict[str, A
         """
         method, endpoint = endpoint_info
 
-        url = self._endpoint_url(endpoint)
+        full_endpoint = f"api/{endpoint}"
+        url = self._endpoint_url(full_endpoint)
 
         aad_headers = await self._a_get_aad_headers()
         headers = {**self.user_agent_header, **aad_headers}
diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py
index c4b0b427d7dd6..039f33aeaf918 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py
@@ -436,7 +436,7 @@ def test_do_api_call_custom_retry(self):
     def test_do_api_call_patch(self, mock_requests):
         mock_requests.patch.return_value.json.return_value = {"cluster_name": "new_name"}
         data = {"cluster_name": "new_name"}
-        patched_cluster_name = self.hook._do_api_call(("PATCH", "api/2.1/jobs/runs/submit"), data)
+        patched_cluster_name = self.hook._do_api_call(("PATCH", "2.1/jobs/runs/submit"), data)
 
         assert patched_cluster_name["cluster_name"] == "new_name"
         mock_requests.patch.assert_called_once_with(
@@ -1365,7 +1365,7 @@ def setup_connections(self, create_connection_without_db):
     @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
     def test_do_api_call_respects_schema(self, mock_requests):
         mock_requests.get.return_value.json.return_value = {"foo": "bar"}
-        ret_val = self.hook._do_api_call(("GET", "api/2.1/foo/bar"))
+        ret_val = self.hook._do_api_call(("GET", "2.1/foo/bar"))
 
         assert ret_val == {"foo": "bar"}
         mock_requests.get.assert_called_once()
@@ -1376,7 +1376,7 @@ def test_do_api_call_respects_schema(self, mock_requests):
     async def test_async_do_api_call_respects_schema(self, mock_get):
         mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value={"bar": "baz"})
         async with self.hook:
-            run_page_url = await self.hook._a_do_api_call(("GET", "api/2.1/foo/bar"))
+            run_page_url = await self.hook._a_do_api_call(("GET", "2.1/foo/bar"))
 
         assert run_page_url == {"bar": "baz"}
         mock_get.assert_called_once()
@@ -1390,7 +1390,7 @@ async def test_async_do_api_call_only_existing_response_properties_are_read(self
         response.mock_add_spec(aiohttp.ClientResponse, spec_set=True)
         response.json = AsyncMock(return_value={"bar": "baz"})
         async with self.hook:
-            run_page_url = await self.hook._a_do_api_call(("GET", "api/2.1/foo/bar"))
+            run_page_url = await self.hook._a_do_api_call(("GET", "2.1/foo/bar"))
 
         assert run_page_url == {"bar": "baz"}
         mock_get.assert_called_once()
@@ -1779,7 +1779,7 @@ async def test_do_api_call_patch(self, mock_patch):
         )
         data = {"cluster_name": "new_name"}
         async with self.hook:
-            patched_cluster_name = await self.hook._a_do_api_call(("PATCH", "api/2.1/jobs/runs/submit"), data)
+            patched_cluster_name = await self.hook._a_do_api_call(("PATCH", "2.1/jobs/runs/submit"), data)
 
         assert patched_cluster_name["cluster_name"] == "new_name"
         mock_patch.assert_called_once_with(
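
A minimal standalone sketch of the URL convention this patch settles on: endpoint constants carry only the versioned path (e.g. "2.1/jobs/runs/get"), `_do_api_call`/`_a_do_api_call` prepend `api/` exactly once, and the repo helpers append the resource ID at call time. `WORKSPACE_HOST` and `build_url` below are invented stand-ins for the hook's host handling and `_do_api_call`, not part of the provider; only the endpoint tuples mirror constants from this patch.

```python
from __future__ import annotations

# Illustration only: WORKSPACE_HOST and build_url are hypothetical stand-ins;
# the endpoint tuples mirror the constants defined in databricks.py.
WORKSPACE_HOST = "https://example.cloud.databricks.com"  # placeholder host

GET_RUN_ENDPOINT = ("GET", "2.1/jobs/runs/get")  # no 'api/' prefix any more
UPDATE_REPO_ENDPOINT = ("PATCH", "2.0/repos")    # repo ID appended per call


def build_url(endpoint_info: tuple[str, str], resource_id: str | None = None) -> str:
    """Append an optional resource ID, then prepend 'api/' exactly once."""
    _method, endpoint = endpoint_info
    if resource_id is not None:
        endpoint = f"{endpoint}/{resource_id}"  # as update_repo/delete_repo do
    return f"{WORKSPACE_HOST}/api/{endpoint}"


assert build_url(GET_RUN_ENDPOINT) == f"{WORKSPACE_HOST}/api/2.1/jobs/runs/get"
assert build_url(UPDATE_REPO_ENDPOINT, "123") == f"{WORKSPACE_HOST}/api/2.0/repos/123"
```

Note that the repo base paths are written without a trailing slash here; with a trailing slash (`"2.0/repos/"`), the `f"{base_path}/{repo_id}"` call sites would produce a double slash (`2.0/repos//123`) that the pre-patch literals never did.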