diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 9eeaffdf64010..06ba762327178 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -36,6 +36,7 @@
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook
 
+GET_CLUSTER_ENDPOINT = ("GET", "api/2.0/clusters/get")
 RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart")
 START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
 TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")
@@ -57,25 +58,32 @@
 WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")
 
-RUN_LIFE_CYCLE_STATES = [
-    "PENDING",
-    "RUNNING",
-    "TERMINATING",
-    "TERMINATED",
-    "SKIPPED",
-    "INTERNAL_ERROR",
-    "QUEUED",
-]
-
 SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")
 
 
 class RunState:
     """Utility class for the run state concept of Databricks runs."""
 
+    RUN_LIFE_CYCLE_STATES = [
+        "PENDING",
+        "RUNNING",
+        "TERMINATING",
+        "TERMINATED",
+        "SKIPPED",
+        "INTERNAL_ERROR",
+        "QUEUED",
+    ]
+
     def __init__(
         self, life_cycle_state: str, result_state: str = "", state_message: str = "", *args, **kwargs
     ) -> None:
+        if life_cycle_state not in self.RUN_LIFE_CYCLE_STATES:
+            raise AirflowException(
+                f"Unexpected life cycle state: {life_cycle_state}: If the state has "
+                "been introduced recently, please check the Databricks user "
+                "guide for troubleshooting information"
+            )
+
         self.life_cycle_state = life_cycle_state
         self.result_state = result_state
         self.state_message = state_message
 
@@ -83,12 +91,6 @@ def __init__(
     @property
     def is_terminal(self) -> bool:
         """True if the current state is a terminal state."""
-        if self.life_cycle_state not in RUN_LIFE_CYCLE_STATES:
-            raise AirflowException(
-                f"Unexpected life cycle state: {self.life_cycle_state}: If the state has "
-                "been introduced recently, please check the Databricks user "
-                "guide for troubleshooting information"
-            )
         return self.life_cycle_state in ("TERMINATED", "SKIPPED", "INTERNAL_ERROR")
 
     @property
@@ -116,6 +118,55 @@ def from_json(cls, data: str) -> RunState:
         return RunState(**json.loads(data))
 
 
+class ClusterState:
+    """Utility class for the cluster state concept of Databricks cluster."""
+
+    CLUSTER_LIFE_CYCLE_STATES = [
+        "PENDING",
+        "RUNNING",
+        "RESTARTING",
+        "RESIZING",
+        "TERMINATING",
+        "TERMINATED",
+        "ERROR",
+        "UNKNOWN",
+    ]
+
+    def __init__(self, state: str = "", state_message: str = "", *args, **kwargs) -> None:
+        if state not in self.CLUSTER_LIFE_CYCLE_STATES:
+            raise AirflowException(
+                f"Unexpected cluster life cycle state: {state}: If the state has "
+                "been introduced recently, please check the Databricks user "
+                "guide for troubleshooting information"
+            )
+
+        self.state = state
+        self.state_message = state_message
+
+    @property
+    def is_terminal(self) -> bool:
+        """True if the current state is a terminal state."""
+        return self.state in ("TERMINATING", "TERMINATED", "ERROR", "UNKNOWN")
+
+    @property
+    def is_running(self) -> bool:
+        """True if the current state is running."""
+        return self.state in ("RUNNING", "RESIZING")
+
+    def __eq__(self, other) -> bool:
+        return self.state == other.state and self.state_message == other.state_message
+
+    def __repr__(self) -> str:
+        return str(self.__dict__)
+
+    def to_json(self) -> str:
+        return json.dumps(self.__dict__)
+
+    @classmethod
+    def from_json(cls, data: str) -> ClusterState:
+        return ClusterState(**json.loads(data))
+
+
 class DatabricksHook(BaseDatabricksHook):
     """
     Interact with Databricks.
@@ -474,6 +525,32 @@ def repair_run(self, json: dict) -> None:
         """
         self._do_api_call(REPAIR_RUN_ENDPOINT, json)
 
+    def get_cluster_state(self, cluster_id: str) -> ClusterState:
+        """
+        Retrieves the state of the cluster.
+
+        :param cluster_id: id of the cluster
+        :return: state of the cluster
+        """
+        json = {"cluster_id": cluster_id}
+        response = self._do_api_call(GET_CLUSTER_ENDPOINT, json)
+        state = response["state"]
+        state_message = response["state_message"]
+        return ClusterState(state, state_message)
+
+    async def a_get_cluster_state(self, cluster_id: str) -> ClusterState:
+        """
+        Async version of `get_cluster_state`.
+
+        :param cluster_id: id of the cluster
+        :return: state of the cluster
+        """
+        json = {"cluster_id": cluster_id}
+        response = await self._a_do_api_call(GET_CLUSTER_ENDPOINT, json)
+        state = response["state"]
+        state_message = response["state_message"]
+        return ClusterState(state, state_message)
+
     def restart_cluster(self, json: dict) -> None:
         """
         Restarts the cluster.
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 2b49fedebb2e7..2566e9a394769 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -34,6 +34,7 @@
 from airflow.providers.databricks.hooks.databricks import (
     GET_RUN_ENDPOINT,
     SUBMIT_RUN_ENDPOINT,
+    ClusterState,
     DatabricksHook,
     RunState,
 )
@@ -78,6 +79,9 @@
     "state": {"life_cycle_state": LIFE_CYCLE_STATE, "state_message": STATE_MESSAGE},
 }
 GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, "notebook_output": {}}
+CLUSTER_STATE = "TERMINATED"
+CLUSTER_STATE_MESSAGE = "Inactive cluster terminated (inactive for 120 minutes)."
+GET_CLUSTER_RESPONSE = {"state": CLUSTER_STATE, "state_message": CLUSTER_STATE_MESSAGE}
 NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider": "1457570074236"}
 JAR_PARAMS = ["param1", "param2"]
 RESULT_STATE = ""
@@ -159,6 +163,13 @@ def repair_run_endpoint(host):
     return f"https://{host}/api/2.1/jobs/runs/repair"
 
 
+def get_cluster_endpoint(host):
+    """
+    Utility function to generate the get cluster endpoint given the host.
+    """
+    return f"https://{host}/api/2.0/clusters/get"
+
+
 def start_cluster_endpoint(host):
     """
     Utility function to generate the get run endpoint given the host.
@@ -598,6 +609,26 @@ def test_repair_run(self, mock_requests):
             timeout=self.hook.timeout_seconds,
         )
 
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_get_cluster_state(self, mock_requests):
+        """
+        Response example from https://docs.databricks.com/api/workspace/clusters/get
+        """
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = GET_CLUSTER_RESPONSE
+
+        cluster_state = self.hook.get_cluster_state(CLUSTER_ID)
+
+        assert cluster_state == ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
+        mock_requests.get.assert_called_once_with(
+            get_cluster_endpoint(HOST),
+            json=None,
+            params={"cluster_id": CLUSTER_ID},
+            auth=HTTPBasicAuth(LOGIN, PASSWORD),
+            headers=self.hook.user_agent_header,
+            timeout=self.hook.timeout_seconds,
+        )
+
     @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
     def test_start_cluster(self, mock_requests):
         mock_requests.codes.ok = 200
@@ -952,8 +983,8 @@ def test_is_terminal_false(self):
         assert not run_state.is_terminal
 
     def test_is_terminal_with_nonexistent_life_cycle_state(self):
-        run_state = RunState("blah", "", "")
         with pytest.raises(AirflowException):
+            run_state = RunState("blah", "", "")
             assert run_state.is_terminal
 
     def test_is_successful(self):
@@ -973,6 +1004,41 @@ def test_from_json(self):
         assert expected == RunState.from_json(json.dumps(state))
 
 
+class TestClusterState:
+    def test_is_terminal_true(self):
+        terminal_states = ["TERMINATING", "TERMINATED", "ERROR", "UNKNOWN"]
+        for state in terminal_states:
+            cluster_state = ClusterState(state, "")
+            assert cluster_state.is_terminal
+
+    def test_is_terminal_false(self):
+        non_terminal_states = ["PENDING", "RUNNING", "RESTARTING", "RESIZING"]
+        for state in non_terminal_states:
+            cluster_state = ClusterState(state, "")
+            assert not cluster_state.is_terminal
+
+    def test_is_terminal_with_nonexistent_life_cycle_state(self):
+        with pytest.raises(AirflowException):
+            cluster_state = ClusterState("blah", "")
+            assert cluster_state.is_terminal
+
+    def test_is_running(self):
+        running_states = ["RUNNING", "RESIZING"]
+        for state in running_states:
+            cluster_state = ClusterState(state, "")
+            assert cluster_state.is_running
+
+    def test_to_json(self):
+        cluster_state = ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
+        expected = json.dumps(GET_CLUSTER_RESPONSE)
+        assert expected == cluster_state.to_json()
+
+    def test_from_json(self):
+        state = GET_CLUSTER_RESPONSE
+        expected = ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
+        assert expected == ClusterState.from_json(json.dumps(state))
+
+
 def create_aad_token_for_resource(resource: str) -> dict:
     return {
         "token_type": "Bearer",
@@ -1284,6 +1350,23 @@ async def test_get_run_state(self, mock_get):
             timeout=self.hook.timeout_seconds,
         )
 
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
+    async def test_get_cluster_state(self, mock_get):
+        mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_CLUSTER_RESPONSE)
+
+        async with self.hook:
+            cluster_state = await self.hook.a_get_cluster_state(CLUSTER_ID)
+
+        assert cluster_state == ClusterState(CLUSTER_STATE, CLUSTER_STATE_MESSAGE)
+        mock_get.assert_called_once_with(
+            get_cluster_endpoint(HOST),
+            json={"cluster_id": CLUSTER_ID},
+            auth=aiohttp.BasicAuth(LOGIN, PASSWORD),
+            headers=self.hook.user_agent_header,
+            timeout=self.hook.timeout_seconds,
+        )
+
 
 class TestDatabricksHookAsyncAadToken:
     """