Skip to content

Commit

Permalink
Add DatabricksHook ClusterState (#34643)
Browse files Browse the repository at this point in the history
* implement ClusterState and get_cluster_state() method
  • Loading branch information
Seokyun-Ha authored Oct 12, 2023
1 parent 6ba2c44 commit 946b539
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 17 deletions.
109 changes: 93 additions & 16 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

# Databricks REST API endpoints, expressed as (HTTP method, relative path) pairs
# consumed by DatabricksHook._do_api_call.
GET_CLUSTER_ENDPOINT = ("GET", "api/2.0/clusters/get")
RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart")
START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")
Expand All @@ -57,38 +58,39 @@

WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")

# Life cycle states a Databricks run may report; RunState validates incoming
# states against this list.
RUN_LIFE_CYCLE_STATES = [
    "PENDING",
    "RUNNING",
    "TERMINATING",
    "TERMINATED",
    "SKIPPED",
    "INTERNAL_ERROR",
    "QUEUED",
]

SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")


class RunState:
    """Utility class for the run state concept of Databricks runs."""

    # Known run life cycle states; __init__ rejects any value outside this list.
    RUN_LIFE_CYCLE_STATES = [
        "PENDING",
        "RUNNING",
        "TERMINATING",
        "TERMINATED",
        "SKIPPED",
        "INTERNAL_ERROR",
        "QUEUED",
    ]

def __init__(
    self, life_cycle_state: str, result_state: str = "", state_message: str = "", *args, **kwargs
) -> None:
    """
    Initialize and validate the run state.

    :param life_cycle_state: life cycle state reported by Databricks; must be
        one of ``RUN_LIFE_CYCLE_STATES``.
    :param state_message: human-readable message accompanying the state.
    :raises AirflowException: if ``life_cycle_state`` is not a known state.
    """
    # Fail fast on unknown states so downstream checks can trust the value.
    # Extra *args/**kwargs are deliberately ignored so unexpected fields in an
    # API response do not break deserialization.
    if life_cycle_state not in self.RUN_LIFE_CYCLE_STATES:
        raise AirflowException(
            f"Unexpected life cycle state: {life_cycle_state}: If the state has "
            "been introduced recently, please check the Databricks user "
            "guide for troubleshooting information"
        )

    self.life_cycle_state = life_cycle_state
    self.result_state = result_state
    self.state_message = state_message

@property
def is_terminal(self) -> bool:
    """
    True if the current state is a terminal state.

    :raises AirflowException: if ``life_cycle_state`` is not a recognized state.
    """
    # Validate against the class-level list (the same one __init__ checks)
    # rather than the module-level constant, so the class stays self-contained
    # and the two checks cannot drift apart.
    if self.life_cycle_state not in self.RUN_LIFE_CYCLE_STATES:
        raise AirflowException(
            f"Unexpected life cycle state: {self.life_cycle_state}: If the state has "
            "been introduced recently, please check the Databricks user "
            "guide for troubleshooting information"
        )
    return self.life_cycle_state in ("TERMINATED", "SKIPPED", "INTERNAL_ERROR")

@property
Expand Down Expand Up @@ -116,6 +118,55 @@ def from_json(cls, data: str) -> RunState:
return RunState(**json.loads(data))


class ClusterState:
    """Utility class for the cluster state concept of Databricks cluster.

    :param state: life cycle state of the cluster; must be one of
        ``CLUSTER_LIFE_CYCLE_STATES``.
    :param state_message: human-readable message accompanying the state.
    :raises AirflowException: if ``state`` is not a known life cycle state.
    """

    # Cluster life cycle states accepted by this class; __init__ rejects any
    # value outside this list.
    CLUSTER_LIFE_CYCLE_STATES = [
        "PENDING",
        "RUNNING",
        "RESTARTING",
        "RESIZING",
        "TERMINATING",
        "TERMINATED",
        "ERROR",
        "UNKNOWN",
    ]

    def __init__(self, state: str = "", state_message: str = "", *args, **kwargs) -> None:
        # Fail fast on unknown states. Extra *args/**kwargs are deliberately
        # ignored so unexpected fields in an API response do not break
        # deserialization via from_json.
        if state not in self.CLUSTER_LIFE_CYCLE_STATES:
            raise AirflowException(
                f"Unexpected cluster life cycle state: {state}: If the state has "
                "been introduced recently, please check the Databricks user "
                "guide for troubleshooting information"
            )

        self.state = state
        self.state_message = state_message

    @property
    def is_terminal(self) -> bool:
        """True if the current state is a terminal state."""
        return self.state in ("TERMINATING", "TERMINATED", "ERROR", "UNKNOWN")

    @property
    def is_running(self) -> bool:
        """True if the current state is running."""
        return self.state in ("RUNNING", "RESIZING")

    def __eq__(self, other: object) -> bool:
        # Return NotImplemented (instead of raising AttributeError on
        # other.state) so comparisons against unrelated types fall back to
        # Python's default identity check and simply evaluate to False.
        if not isinstance(other, ClusterState):
            return NotImplemented
        return self.state == other.state and self.state_message == other.state_message

    def __repr__(self) -> str:
        return str(self.__dict__)

    def to_json(self) -> str:
        """Serialize this state to a JSON string (inverse of :meth:`from_json`)."""
        return json.dumps(self.__dict__)

    @classmethod
    def from_json(cls, data: str) -> ClusterState:
        """Deserialize a :class:`ClusterState` from JSON produced by :meth:`to_json`."""
        return ClusterState(**json.loads(data))


class DatabricksHook(BaseDatabricksHook):
"""
Interact with Databricks.
Expand Down Expand Up @@ -474,6 +525,32 @@ def repair_run(self, json: dict) -> None:
"""
self._do_api_call(REPAIR_RUN_ENDPOINT, json)

def get_cluster_state(self, cluster_id: str) -> ClusterState:
    """
    Retrieve the state of the cluster.

    :param cluster_id: id of the cluster
    :return: state of the cluster
    """
    # Name the request payload distinctly so the stdlib ``json`` module
    # (used elsewhere in this file) is not shadowed inside this method.
    payload = {"cluster_id": cluster_id}
    response = self._do_api_call(GET_CLUSTER_ENDPOINT, payload)
    return ClusterState(response["state"], response["state_message"])

async def a_get_cluster_state(self, cluster_id: str) -> ClusterState:
    """
    Async version of `get_cluster_state`.

    :param cluster_id: id of the cluster
    :return: state of the cluster
    """
    # Name the request payload distinctly so the stdlib ``json`` module
    # (used elsewhere in this file) is not shadowed inside this method.
    payload = {"cluster_id": cluster_id}
    response = await self._a_do_api_call(GET_CLUSTER_ENDPOINT, payload)
    return ClusterState(response["state"], response["state_message"])

def restart_cluster(self, json: dict) -> None:
"""
Restarts the cluster.
Expand Down
85 changes: 84 additions & 1 deletion tests/providers/databricks/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from airflow.providers.databricks.hooks.databricks import (
GET_RUN_ENDPOINT,
SUBMIT_RUN_ENDPOINT,
ClusterState,
DatabricksHook,
RunState,
)
Expand Down Expand Up @@ -78,6 +79,9 @@
"state": {"life_cycle_state": LIFE_CYCLE_STATE, "state_message": STATE_MESSAGE},
}
GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, "notebook_output": {}}
# Fixture values mirroring the shape of a Databricks Clusters "get" API response
# (see https://docs.databricks.com/api/workspace/clusters/get).
CLUSTER_STATE = "TERMINATED"
CLUSTER_STATE_MESSAGE = "Inactive cluster terminated (inactive for 120 minutes)."
GET_CLUSTER_RESPONSE = {"state": CLUSTER_STATE, "state_message": CLUSTER_STATE_MESSAGE}
NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"}
JAR_PARAMS = ["param1", "param2"]
RESULT_STATE = ""
Expand Down Expand Up @@ -159,6 +163,13 @@ def repair_run_endpoint(host):
return f"https://{host}/api/2.1/jobs/runs/repair"


def get_cluster_endpoint(host):
    """
    Utility function to generate the cluster get endpoint given the host.
    """
    # Docstring previously said "get run endpoint" — copy-paste from the run
    # endpoint helper; this helper builds the clusters/get URL.
    return f"https://{host}/api/2.0/clusters/get"


def start_cluster_endpoint(host):
"""
Utility function to generate the get run endpoint given the host.
Expand Down Expand Up @@ -598,6 +609,26 @@ def test_repair_run(self, mock_requests):
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_get_cluster_state(self, requests_mock):
    """
    Response example from https://docs.databricks.com/api/workspace/clusters/get
    """
    requests_mock.codes.ok = 200
    requests_mock.get.return_value.json.return_value = GET_CLUSTER_RESPONSE

    result = self.hook.get_cluster_state(CLUSTER_ID)

    expected_state = ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
    assert result == expected_state

    # The hook must issue exactly one GET with the cluster id as a query param.
    expected_kwargs = {
        "json": None,
        "params": {"cluster_id": CLUSTER_ID},
        "auth": HTTPBasicAuth(LOGIN, PASSWORD),
        "headers": self.hook.user_agent_header,
        "timeout": self.hook.timeout_seconds,
    }
    requests_mock.get.assert_called_once_with(get_cluster_endpoint(HOST), **expected_kwargs)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_start_cluster(self, mock_requests):
mock_requests.codes.ok = 200
Expand Down Expand Up @@ -952,8 +983,8 @@ def test_is_terminal_false(self):
assert not run_state.is_terminal

def test_is_terminal_with_nonexistent_life_cycle_state(self):
    # A stale duplicate construction of RunState("blah", ...) previously sat
    # before the raises-context; it would raise AirflowException outside the
    # context manager and fail the test. The construction belongs inside the
    # context; the trailing assert is unreachable (the constructor raises) and
    # only documents the intended access.
    with pytest.raises(AirflowException):
        run_state = RunState("blah", "", "")
        assert run_state.is_terminal

def test_is_successful(self):
Expand All @@ -973,6 +1004,41 @@ def test_from_json(self):
assert expected == RunState.from_json(json.dumps(state))


class TestClusterState:
    """Unit tests for the ClusterState utility class."""

    def test_is_terminal_true(self):
        # Every terminal state must report is_terminal == True.
        for state in ("TERMINATING", "TERMINATED", "ERROR", "UNKNOWN"):
            assert ClusterState(state, "").is_terminal

    def test_is_terminal_false(self):
        # Active / transitional states are not terminal.
        for state in ("PENDING", "RUNNING", "RESTARTING", "RESIZING"):
            assert not ClusterState(state, "").is_terminal

    def test_is_terminal_with_nonexistent_life_cycle_state(self):
        # Constructing with an unknown state raises before is_terminal runs.
        with pytest.raises(AirflowException):
            cluster_state = ClusterState("blah", "")
            assert cluster_state.is_terminal

    def test_is_running(self):
        for state in ("RUNNING", "RESIZING"):
            assert ClusterState(state, "").is_running

    def test_to_json(self):
        serialized = ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE).to_json()
        assert serialized == json.dumps(GET_CLUSTER_RESPONSE)

    def test_from_json(self):
        deserialized = ClusterState.from_json(json.dumps(GET_CLUSTER_RESPONSE))
        assert deserialized == ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)


def create_aad_token_for_resource(resource: str) -> dict:
return {
"token_type": "Bearer",
Expand Down Expand Up @@ -1284,6 +1350,23 @@ async def test_get_run_state(self, mock_get):
timeout=self.hook.timeout_seconds,
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_get_cluster_state(self, mock_get):
    mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_CLUSTER_RESPONSE)

    async with self.hook:
        result = await self.hook.a_get_cluster_state(CLUSTER_ID)

        assert result == ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
        # The async hook must issue exactly one GET carrying the cluster id
        # in the JSON body.
        expected_kwargs = {
            "json": {"cluster_id": CLUSTER_ID},
            "auth": aiohttp.BasicAuth(LOGIN, PASSWORD),
            "headers": self.hook.user_agent_header,
            "timeout": self.hook.timeout_seconds,
        }
        mock_get.assert_called_once_with(get_cluster_endpoint(HOST), **expected_kwargs)


class TestDatabricksHookAsyncAadToken:
"""
Expand Down

0 comments on commit 946b539

Please sign in to comment.