@@ -37,33 +37,36 @@
from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

GET_CLUSTER_ENDPOINT = ("GET", "api/2.0/clusters/get")
RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart")
START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")

CREATE_ENDPOINT = ("POST", "api/2.1/jobs/create")
RESET_ENDPOINT = ("POST", "api/2.1/jobs/reset")
UPDATE_ENDPOINT = ("POST", "api/2.1/jobs/update")
RUN_NOW_ENDPOINT = ("POST", "api/2.1/jobs/run-now")
SUBMIT_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/submit")
GET_RUN_ENDPOINT = ("GET", "api/2.1/jobs/runs/get")
CANCEL_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/cancel")
DELETE_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/delete")
REPAIR_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/repair")
OUTPUT_RUNS_JOB_ENDPOINT = ("GET", "api/2.1/jobs/runs/get-output")
CANCEL_ALL_RUNS_ENDPOINT = ("POST", "api/2.1/jobs/runs/cancel-all")

INSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/install")
UNINSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/uninstall")

LIST_JOBS_ENDPOINT = ("GET", "api/2.1/jobs/list")
LIST_PIPELINES_ENDPOINT = ("GET", "api/2.0/pipelines")

WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")

SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")
SQL_STATEMENTS_ENDPOINT = "api/2.0/sql/statements"
GET_CLUSTER_ENDPOINT = ("GET", "2.0/clusters/get")
RESTART_CLUSTER_ENDPOINT = ("POST", "2.0/clusters/restart")
START_CLUSTER_ENDPOINT = ("POST", "2.0/clusters/start")
TERMINATE_CLUSTER_ENDPOINT = ("POST", "2.0/clusters/delete")

CREATE_ENDPOINT = ("POST", "2.1/jobs/create")
RESET_ENDPOINT = ("POST", "2.1/jobs/reset")
UPDATE_ENDPOINT = ("POST", "2.1/jobs/update")
RUN_NOW_ENDPOINT = ("POST", "2.1/jobs/run-now")
SUBMIT_RUN_ENDPOINT = ("POST", "2.1/jobs/runs/submit")
GET_RUN_ENDPOINT = ("GET", "2.1/jobs/runs/get")
CANCEL_RUN_ENDPOINT = ("POST", "2.1/jobs/runs/cancel")
DELETE_RUN_ENDPOINT = ("POST", "2.1/jobs/runs/delete")
REPAIR_RUN_ENDPOINT = ("POST", "2.1/jobs/runs/repair")
OUTPUT_RUNS_JOB_ENDPOINT = ("GET", "2.1/jobs/runs/get-output")
CANCEL_ALL_RUNS_ENDPOINT = ("POST", "2.1/jobs/runs/cancel-all")

INSTALL_LIBS_ENDPOINT = ("POST", "2.0/libraries/install")
UNINSTALL_LIBS_ENDPOINT = ("POST", "2.0/libraries/uninstall")
+UPDATE_REPO_ENDPOINT = ("PATCH", "2.0/repos")
+DELETE_REPO_ENDPOINT = ("DELETE", "2.0/repos")
+CREATE_REPO_ENDPOINT = ("POST", "2.0/repos")
+
+LIST_JOBS_ENDPOINT = ("GET", "2.1/jobs/list")
+LIST_PIPELINES_ENDPOINT = ("GET", "2.0/pipelines")
+
+WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "2.0/workspace/get-status")
+
+SPARK_VERSIONS_ENDPOINT = ("GET", "2.0/clusters/spark-versions")
+SQL_STATEMENTS_ENDPOINT = "2.0/sql/statements"


class RunLifeCycleState(Enum):
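
Every constant above is now a (HTTP method, version-relative path) pair, with the shared "api/" prefix added in one place by _do_api_call (second file below). A minimal sketch of the resulting URL composition; the hostname is a made-up placeholder, not part of this change:

method, endpoint = LIST_JOBS_ENDPOINT    # ("GET", "2.1/jobs/list")
full_endpoint = f"api/{endpoint}"        # "api/2.1/jobs/list" -- prefix added once, centrally
host = "adb-1234567890123456.7.azuredatabricks.net"  # hypothetical workspace host
url = f"https://{host}/{full_endpoint}"  # https://adb-.../api/2.1/jobs/list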
@@ -718,7 +721,8 @@ def update_repo(self, repo_id: str, json: dict[str, Any]) -> dict:
        :param json: payload
        :return: metadata from update
        """
-        repos_endpoint = ("PATCH", f"api/2.0/repos/{repo_id}")
+        method, base_path = UPDATE_REPO_ENDPOINT
+        repos_endpoint = (method, f"{base_path}/{repo_id}")
        return self._do_api_call(repos_endpoint, json)

    def delete_repo(self, repo_id: str):
@@ -728,7 +732,8 @@ def delete_repo(self, repo_id: str):
        :param repo_id: ID of Databricks Repos
        :return:
        """
-        repos_endpoint = ("DELETE", f"api/2.0/repos/{repo_id}")
+        method, base_path = DELETE_REPO_ENDPOINT
+        repos_endpoint = (method, f"{base_path}/{repo_id}")
        self._do_api_call(repos_endpoint)

    def create_repo(self, json: dict[str, Any]) -> dict:
@@ -738,8 +743,7 @@ def create_repo(self, json: dict[str, Any]) -> dict:
        :param json: payload
        :return:
        """
-        repos_endpoint = ("POST", "api/2.0/repos")
-        return self._do_api_call(repos_endpoint, json)
+        return self._do_api_call(CREATE_REPO_ENDPOINT, json)

    def get_repo_by_path(self, path: str) -> str | None:
        """
@@ -637,8 +637,9 @@ def _do_api_call(
"""
method, endpoint = endpoint_info

# TODO: get rid of explicit 'api/' in the endpoint specification
url = self._endpoint_url(endpoint)
# Automatically prepend 'api/' prefix to all endpoint paths
full_endpoint = f"api/{endpoint}"
url = self._endpoint_url(full_endpoint)

aad_headers = self._get_aad_headers()
headers = {**self.user_agent_header, **aad_headers}
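
A consequence of the lines above: paths handed to _do_api_call must now be version-relative, since the prefix is always prepended. Illustration only, with hook construction elided:

hook._do_api_call(("GET", "2.1/jobs/list"))      # -> https://<host>/api/2.1/jobs/list
hook._do_api_call(("GET", "api/2.1/jobs/list"))  # -> .../api/api/2.1/jobs/list (double prefix)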
@@ -704,7 +705,8 @@ async def _a_do_api_call(self, endpoint_info: tuple[str, str], json: dict[str, A
"""
method, endpoint = endpoint_info

url = self._endpoint_url(endpoint)
full_endpoint = f"api/{endpoint}"
url = self._endpoint_url(full_endpoint)

aad_headers = await self._a_get_aad_headers()
headers = {**self.user_agent_header, **aad_headers}
@@ -436,7 +436,7 @@ def test_do_api_call_custom_retry(self):
    def test_do_api_call_patch(self, mock_requests):
        mock_requests.patch.return_value.json.return_value = {"cluster_name": "new_name"}
        data = {"cluster_name": "new_name"}
-        patched_cluster_name = self.hook._do_api_call(("PATCH", "api/2.1/jobs/runs/submit"), data)
+        patched_cluster_name = self.hook._do_api_call(("PATCH", "2.1/jobs/runs/submit"), data)

        assert patched_cluster_name["cluster_name"] == "new_name"
        mock_requests.patch.assert_called_once_with(
@@ -1365,7 +1365,7 @@ def setup_connections(self, create_connection_without_db):
    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
    def test_do_api_call_respects_schema(self, mock_requests):
        mock_requests.get.return_value.json.return_value = {"foo": "bar"}
-        ret_val = self.hook._do_api_call(("GET", "api/2.1/foo/bar"))
+        ret_val = self.hook._do_api_call(("GET", "2.1/foo/bar"))

        assert ret_val == {"foo": "bar"}
        mock_requests.get.assert_called_once()
@@ -1376,7 +1376,7 @@ def test_do_api_call_respects_schema(self, mock_requests):
    async def test_async_do_api_call_respects_schema(self, mock_get):
        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value={"bar": "baz"})
        async with self.hook:
-            run_page_url = await self.hook._a_do_api_call(("GET", "api/2.1/foo/bar"))
+            run_page_url = await self.hook._a_do_api_call(("GET", "2.1/foo/bar"))

        assert run_page_url == {"bar": "baz"}
        mock_get.assert_called_once()
@@ -1390,7 +1390,7 @@ async def test_async_do_api_call_only_existing_response_properties_are_read(self
        response.mock_add_spec(aiohttp.ClientResponse, spec_set=True)
        response.json = AsyncMock(return_value={"bar": "baz"})
        async with self.hook:
-            run_page_url = await self.hook._a_do_api_call(("GET", "api/2.1/foo/bar"))
+            run_page_url = await self.hook._a_do_api_call(("GET", "2.1/foo/bar"))

        assert run_page_url == {"bar": "baz"}
        mock_get.assert_called_once()
@@ -1779,7 +1779,7 @@ async def test_do_api_call_patch(self, mock_patch):
        )
        data = {"cluster_name": "new_name"}
        async with self.hook:
-            patched_cluster_name = await self.hook._a_do_api_call(("PATCH", "api/2.1/jobs/runs/submit"), data)
+            patched_cluster_name = await self.hook._a_do_api_call(("PATCH", "2.1/jobs/runs/submit"), data)

        assert patched_cluster_name["cluster_name"] == "new_name"
        mock_patch.assert_called_once_with(
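
A hedged sketch of one more check that would fit alongside the tests above, asserting the centrally added prefix; it is hypothetical, reuses the surrounding fixtures, and assumes the URL is passed positionally to requests.get:

    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
    def test_do_api_call_prepends_api_prefix(self, mock_requests):
        # Hypothetical test, not part of this change.
        mock_requests.get.return_value.json.return_value = {}
        self.hook._do_api_call(("GET", "2.1/foo/bar"))
        url = mock_requests.get.call_args[0][0]
        assert "/api/2.1/foo/bar" in url  # 'api/' prepended exactly once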