From 30aefeaa725c85abf580b1cc5fd37bcf05027e40 Mon Sep 17 00:00:00 2001 From: Alex Ott Date: Sat, 20 Nov 2021 12:51:20 +0100 Subject: [PATCH 1/5] Databricks: add more methods to represent run state information this fixes #19357 --- .../providers/databricks/hooks/databricks.py | 48 +++++++++++++++++++ .../databricks/hooks/test_databricks.py | 24 ++++++++++ 2 files changed, 72 insertions(+) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index e26f51879067e..c230eac2d53d9 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -407,6 +407,14 @@ def get_run_state(self, run_id: str) -> RunState: """ Retrieves run state of the run. + Please note that any Airflow tasks that call the ``get_run_state`` method will result in + failure unless you have enabled xcom pickling. This can be done using the following + environment variable: ``AIRLFOW_CORE_ENABLE_XCOM_PICKLING=TRUE`` + + If you do not want to enable xcom pickling then use the ``get_run_state_str`` method to get + string describing state, or ``get_run_state_lifecycle``, ``get_run_state_result``, or + ``get_run_state_message`` to get individual components of the run state. + :param run_id: id of the run :return: state of the run """ @@ -419,6 +427,46 @@ def get_run_state(self, run_id: str) -> RunState: state_message = state['state_message'] return RunState(life_cycle_state, result_state, state_message) + def get_run_state_str(self, run_id: str) -> str: + """ + Returns string representation of RunState + + :param run_id: id of the run + :return: string describing run state + """ + state = self.get_run_state(run_id) + run_state_str = ( + f"State: {state.life_cycle_state}. Result: {state.result_state}. {state.state_message}" + ) + return run_state_str + + def get_run_state_lifecycle(self, run_id: str) -> str: + """ + Returns lifecycle state of the run + + :param run_id: id of the run + :return: string with lifecycle state + """ + return self.get_run_state(run_id).life_cycle_state + + def get_run_state_result(self, run_id: str) -> str: + """ + Returns resulting state of the run + + :param run_id: id of the run + :return: string with resulting state + """ + return self.get_run_state(run_id).result_state + + def get_run_state_message(self, run_id: str) -> str: + """ + Returns state message for the run + + :param run_id: id of the run + :return: string with state message + """ + return self.get_run_state(run_id).state_message + def cancel_run(self, run_id: str) -> None: """ Cancels the run. diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index 13430c9438997..a5f0fb467fe64 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -389,6 +389,30 @@ def test_get_run_state(self, mock_requests): timeout=self.hook.timeout_seconds, ) + @mock.patch('airflow.providers.databricks.hooks.databricks.requests') + def test_get_run_state_str(self, mock_requests): + mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE + run_state_str = self.hook.get_run_state_str(RUN_ID) + assert run_state_str == f"State: {LIFE_CYCLE_STATE}. Result: {RESULT_STATE}. {STATE_MESSAGE}" + + @mock.patch('airflow.providers.databricks.hooks.databricks.requests') + def test_get_run_state_lifecycle(self, mock_requests): + mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE + lifecycle_state = self.hook.get_run_state_lifecycle(RUN_ID) + assert lifecycle_state == LIFE_CYCLE_STATE + + @mock.patch('airflow.providers.databricks.hooks.databricks.requests') + def test_get_run_state_result(self, mock_requests): + mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE + result_state = self.hook.get_run_state_result(RUN_ID) + assert result_state == RESULT_STATE + + @mock.patch('airflow.providers.databricks.hooks.databricks.requests') + def test_get_run_state_cycle(self, mock_requests): + mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE + state_message = self.hook.get_run_state_message(RUN_ID) + assert state_message == STATE_MESSAGE + @mock.patch('airflow.providers.databricks.hooks.databricks.requests') def test_cancel_run(self, mock_requests): mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE From 36d888b0d0794ff54dd8e2ccdfac260194a19708 Mon Sep 17 00:00:00 2001 From: Alex Ott Date: Tue, 23 Nov 2021 12:23:15 +0100 Subject: [PATCH 2/5] Update airflow/providers/databricks/hooks/databricks.py Co-authored-by: Tzu-ping Chung --- airflow/providers/databricks/hooks/databricks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index c230eac2d53d9..c81e9033d9262 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -411,8 +411,8 @@ def get_run_state(self, run_id: str) -> RunState: failure unless you have enabled xcom pickling. This can be done using the following environment variable: ``AIRLFOW_CORE_ENABLE_XCOM_PICKLING=TRUE`` - If you do not want to enable xcom pickling then use the ``get_run_state_str`` method to get - string describing state, or ``get_run_state_lifecycle``, ``get_run_state_result``, or + If you do not want to enable xcom pickling, use the ``get_run_state_str`` method to get + a string describing state, or ``get_run_state_lifecycle``, ``get_run_state_result``, or ``get_run_state_message`` to get individual components of the run state. :param run_id: id of the run From 46ce97a3efb1f5a0290df378099323728444ad32 Mon Sep 17 00:00:00 2001 From: Alex Ott Date: Tue, 23 Nov 2021 12:23:19 +0100 Subject: [PATCH 3/5] Update airflow/providers/databricks/hooks/databricks.py Co-authored-by: Tzu-ping Chung --- airflow/providers/databricks/hooks/databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index c81e9033d9262..6f101748ebb94 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -429,7 +429,7 @@ def get_run_state(self, run_id: str) -> RunState: def get_run_state_str(self, run_id: str) -> str: """ - Returns string representation of RunState + Return the string representation of RunState. :param run_id: id of the run :return: string describing run state From 582db5fa7176800f38ec223fc2a0afdcb05c3886 Mon Sep 17 00:00:00 2001 From: Alex Ott Date: Tue, 23 Nov 2021 12:28:29 +0100 Subject: [PATCH 4/5] fix environment variable --- airflow/providers/databricks/hooks/databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index c230eac2d53d9..06bcd52b07f51 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -409,7 +409,7 @@ def get_run_state(self, run_id: str) -> RunState: Please note that any Airflow tasks that call the ``get_run_state`` method will result in failure unless you have enabled xcom pickling. This can be done using the following - environment variable: ``AIRLFOW_CORE_ENABLE_XCOM_PICKLING=TRUE`` + environment variable: ``AIRFLOW__CORE__ENABLE_XCOM_PICKLING`` If you do not want to enable xcom pickling then use the ``get_run_state_str`` method to get string describing state, or ``get_run_state_lifecycle``, ``get_run_state_result``, or From 070834cc115f3ea43662db8f908a30e8dd381203 Mon Sep 17 00:00:00 2001 From: Alex Ott Date: Tue, 23 Nov 2021 12:33:07 +0100 Subject: [PATCH 5/5] fix grammar --- airflow/providers/databricks/hooks/databricks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 943adb88ca27e..a9da65e831510 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -442,7 +442,7 @@ def get_run_state_str(self, run_id: str) -> str: def get_run_state_lifecycle(self, run_id: str) -> str: """ - Returns lifecycle state of the run + Returns the lifecycle state of the run :param run_id: id of the run :return: string with lifecycle state @@ -451,7 +451,7 @@ def get_run_state_lifecycle(self, run_id: str) -> str: def get_run_state_result(self, run_id: str) -> str: """ - Returns resulting state of the run + Returns the resulting state of the run :param run_id: id of the run :return: string with resulting state @@ -460,7 +460,7 @@ def get_run_state_result(self, run_id: str) -> str: def get_run_state_message(self, run_id: str) -> str: """ - Returns state message for the run + Returns the state message for the run :param run_id: id of the run :return: string with state message