DatabricksHook ClusterState #34643

Merged · 7 commits · Oct 12, 2023
109 changes: 93 additions & 16 deletions airflow/providers/databricks/hooks/databricks.py
@@ -36,6 +36,7 @@
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

GET_CLUSTER_ENDPOINT = ("GET", "api/2.0/clusters/get")
RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart")
START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")
@@ -57,38 +58,39 @@

WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")

RUN_LIFE_CYCLE_STATES = [
"PENDING",
"RUNNING",
"TERMINATING",
"TERMINATED",
"SKIPPED",
"INTERNAL_ERROR",
"QUEUED",
]

SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")


class RunState:
"""Utility class for the run state concept of Databricks runs."""

RUN_LIFE_CYCLE_STATES = [
"PENDING",
"RUNNING",
"TERMINATING",
"TERMINATED",
"SKIPPED",
"INTERNAL_ERROR",
"QUEUED",
]

def __init__(
self, life_cycle_state: str, result_state: str = "", state_message: str = "", *args, **kwargs
) -> None:
if life_cycle_state not in self.RUN_LIFE_CYCLE_STATES:
raise AirflowException(
f"Unexpected life cycle state: {life_cycle_state}: If the state has "
"been introduced recently, please check the Databricks user "
"guide for troubleshooting information"
)

self.life_cycle_state = life_cycle_state
self.result_state = result_state
self.state_message = state_message

@property
def is_terminal(self) -> bool:
"""True if the current state is a terminal state."""
if self.life_cycle_state not in RUN_LIFE_CYCLE_STATES:
raise AirflowException(
f"Unexpected life cycle state: {self.life_cycle_state}: If the state has "
"been introduced recently, please check the Databricks user "
"guide for troubleshooting information"
)
return self.life_cycle_state in ("TERMINATED", "SKIPPED", "INTERNAL_ERROR")

@property
@@ -116,6 +118,55 @@ def from_json(cls, data: str) -> RunState:
return RunState(**json.loads(data))


class ClusterState:
"""Utility class for the cluster state concept of Databricks cluster."""

CLUSTER_LIFE_CYCLE_STATES = [
"PENDING",
"RUNNING",
"RESTARTING",
"RESIZING",
"TERMINATING",
"TERMINATED",
"ERROR",
"UNKNOWN",
]

def __init__(self, state: str = "", state_message: str = "", *args, **kwargs) -> None:
if state not in self.CLUSTER_LIFE_CYCLE_STATES:
raise AirflowException(
f"Unexpected cluster life cycle state: {state}: If the state has "
"been introduced recently, please check the Databricks user "
"guide for troubleshooting information"
)

self.state = state
self.state_message = state_message

@property
def is_terminal(self) -> bool:
"""True if the current state is a terminal state."""
return self.state in ("TERMINATING", "TERMINATED", "ERROR", "UNKNOWN")

@property
def is_running(self) -> bool:
"""True if the current state is running."""
return self.state in ("RUNNING", "RESIZING")

def __eq__(self, other) -> bool:
return self.state == other.state and self.state_message == other.state_message

def __repr__(self) -> str:
return str(self.__dict__)

def to_json(self) -> str:
return json.dumps(self.__dict__)

@classmethod
def from_json(cls, data: str) -> ClusterState:
return ClusterState(**json.loads(data))


class DatabricksHook(BaseDatabricksHook):
"""
Interact with Databricks.
@@ -474,6 +525,32 @@ def repair_run(self, json: dict) -> None:
"""
self._do_api_call(REPAIR_RUN_ENDPOINT, json)

def get_cluster_state(self, cluster_id: str) -> ClusterState:
"""
Retrieves the state of the cluster.

:param cluster_id: id of the cluster
:return: state of the cluster
"""
json = {"cluster_id": cluster_id}
response = self._do_api_call(GET_CLUSTER_ENDPOINT, json)
state = response["state"]
state_message = response["state_message"]
return ClusterState(state, state_message)

async def a_get_cluster_state(self, cluster_id: str) -> ClusterState:
"""
Async version of `get_cluster_state`.

:param cluster_id: id of the cluster
:return: state of the cluster
"""
json = {"cluster_id": cluster_id}
response = await self._a_do_api_call(GET_CLUSTER_ENDPOINT, json)
state = response["state"]
state_message = response["state_message"]
return ClusterState(state, state_message)

def restart_cluster(self, json: dict) -> None:
"""
Restarts the cluster.
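For reviewers, a minimal usage sketch of the `ClusterState` class and the new `get_cluster_state` / `a_get_cluster_state` hook methods added above. The connection id, cluster id, and polling interval are illustrative assumptions, not values taken from this PR:

import time

from airflow.providers.databricks.hooks.databricks import ClusterState, DatabricksHook

# Assumed values for illustration only.
hook = DatabricksHook(databricks_conn_id="databricks_default")
cluster_id = "1234-567890-abcde123"

# Poll the cluster until it is running or has reached a terminal state.
cluster_state = hook.get_cluster_state(cluster_id)
while not (cluster_state.is_running or cluster_state.is_terminal):
    time.sleep(30)
    cluster_state = hook.get_cluster_state(cluster_id)

if cluster_state.is_terminal:
    raise RuntimeError(
        f"Cluster ended in state {cluster_state.state}: {cluster_state.state_message}"
    )

# ClusterState round-trips through JSON, which is handy when the state has to be
# serialized, for example between a deferrable operator and its trigger.
assert ClusterState.from_json(cluster_state.to_json()) == cluster_state

From an async context the same check is `cluster_state = await hook.a_get_cluster_state(cluster_id)`, as exercised by the async test below.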
85 changes: 84 additions & 1 deletion tests/providers/databricks/hooks/test_databricks.py
@@ -34,6 +34,7 @@
from airflow.providers.databricks.hooks.databricks import (
GET_RUN_ENDPOINT,
SUBMIT_RUN_ENDPOINT,
ClusterState,
DatabricksHook,
RunState,
)
@@ -78,6 +79,9 @@
"state": {"life_cycle_state": LIFE_CYCLE_STATE, "state_message": STATE_MESSAGE},
}
GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, "notebook_output": {}}
CLUSTER_STATE = "TERMINATED"
CLUSTER_STATE_MESSAGE = "Inactive cluster terminated (inactive for 120 minutes)."
GET_CLUSTER_RESPONSE = {"state": CLUSTER_STATE, "state_message": CLUSTER_STATE_MESSAGE}
NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"}
JAR_PARAMS = ["param1", "param2"]
RESULT_STATE = ""
@@ -159,6 +163,13 @@ def repair_run_endpoint(host):
return f"https://{host}/api/2.1/jobs/runs/repair"


def get_cluster_endpoint(host):
"""
Utility function to generate the get cluster endpoint given the host.
"""
return f"https://{host}/api/2.0/clusters/get"


def start_cluster_endpoint(host):
"""
Utility function to generate the start cluster endpoint given the host.
@@ -598,6 +609,26 @@ def test_repair_run(self, mock_requests):
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_get_cluster_state(self, mock_requests):
"""
Response example from https://docs.databricks.com/api/workspace/clusters/get
"""
mock_requests.codes.ok = 200
mock_requests.get.return_value.json.return_value = GET_CLUSTER_RESPONSE

cluster_state = self.hook.get_cluster_state(CLUSTER_ID)

assert cluster_state == ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
mock_requests.get.assert_called_once_with(
get_cluster_endpoint(HOST),
json=None,
params={"cluster_id": CLUSTER_ID},
auth=HTTPBasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_start_cluster(self, mock_requests):
mock_requests.codes.ok = 200
@@ -952,8 +983,8 @@ def test_is_terminal_false(self):
assert not run_state.is_terminal

def test_is_terminal_with_nonexistent_life_cycle_state(self):
run_state = RunState("blah", "", "")
with pytest.raises(AirflowException):
run_state = RunState("blah", "", "")
assert run_state.is_terminal

def test_is_successful(self):
@@ -973,6 +1004,41 @@ def test_from_json(self):
assert expected == RunState.from_json(json.dumps(state))


class TestClusterState:
def test_is_terminal_true(self):
terminal_states = ["TERMINATING", "TERMINATED", "ERROR", "UNKNOWN"]
for state in terminal_states:
cluster_state = ClusterState(state, "")
assert cluster_state.is_terminal

def test_is_terminal_false(self):
non_terminal_states = ["PENDING", "RUNNING", "RESTARTING", "RESIZING"]
for state in non_terminal_states:
cluster_state = ClusterState(state, "")
assert not cluster_state.is_terminal

def test_is_terminal_with_nonexistent_life_cycle_state(self):
with pytest.raises(AirflowException):
cluster_state = ClusterState("blah", "")
assert cluster_state.is_terminal

def test_is_running(self):
running_states = ["RUNNING", "RESIZING"]
for state in running_states:
cluster_state = ClusterState(state, "")
assert cluster_state.is_running

def test_to_json(self):
cluster_state = ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
expected = json.dumps(GET_CLUSTER_RESPONSE)
assert expected == cluster_state.to_json()

def test_from_json(self):
state = GET_CLUSTER_RESPONSE
expected = ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
assert expected == ClusterState.from_json(json.dumps(state))


def create_aad_token_for_resource(resource: str) -> dict:
return {
"token_type": "Bearer",
@@ -1284,6 +1350,23 @@ async def test_get_run_state(self, mock_get):
timeout=self.hook.timeout_seconds,
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_get_cluster_state(self, mock_get):
mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_CLUSTER_RESPONSE)

async with self.hook:
cluster_state = await self.hook.a_get_cluster_state(CLUSTER_ID)

assert cluster_state == ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
mock_get.assert_called_once_with(
get_cluster_endpoint(HOST),
json={"cluster_id": CLUSTER_ID},
auth=aiohttp.BasicAuth(LOGIN, PASSWORD),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
)


class TestDatabricksHookAsyncAadToken:
"""