Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Databricks: add more methods to represent run state information #19723

Merged
merged 6 commits into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,14 @@ def get_run_state(self, run_id: str) -> RunState:
"""
Retrieves run state of the run.

Please note that any Airflow tasks that call the ``get_run_state`` method will result in
failure unless you have enabled xcom pickling. This can be done using the following
environment variable: ``AIRFLOW__CORE__ENABLE_XCOM_PICKLING``

If you do not want to enable xcom pickling, use the ``get_run_state_str`` method to get
a string describing state, or ``get_run_state_lifecycle``, ``get_run_state_result``, or
``get_run_state_message`` to get individual components of the run state.

:param run_id: id of the run
:return: state of the run
"""
Expand All @@ -419,6 +427,46 @@ def get_run_state(self, run_id: str) -> RunState:
state_message = state['state_message']
return RunState(life_cycle_state, result_state, state_message)

def get_run_state_str(self, run_id: str) -> str:
"""
Return a human-readable, single-string description of a run's state.

The state is looked up with ``get_run_state`` and rendered as
``State: <life_cycle_state>. Result: <result_state>. <state_message>``.

:param run_id: id of the run
:return: string describing the run state, in the format shown above
"""
state = self.get_run_state(run_id)
run_state_str = (
f"State: {state.life_cycle_state}. Result: {state.result_state}. {state.state_message}"
)
Comment on lines +438 to +440
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about putting this in RunState’s __str__ and just call str(state) here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That’s primarily because I’m not sure if it may break something inside airflow. I’m not expert in it (yet)

return run_state_str

def get_run_state_lifecycle(self, run_id: str) -> str:
    """
    Fetch only the lifecycle component of a run's state.

    The run is looked up through ``get_run_state`` and the
    ``life_cycle_state`` attribute of the result is handed back.

    :param run_id: id of the run
    :return: lifecycle state of the run
    """
    run_state = self.get_run_state(run_id)
    return run_state.life_cycle_state

def get_run_state_result(self, run_id: str) -> str:
    """
    Fetch only the result component of a run's state.

    The run is looked up through ``get_run_state`` and the
    ``result_state`` attribute of the result is handed back.

    :param run_id: id of the run
    :return: result state of the run
    """
    run_state = self.get_run_state(run_id)
    return run_state.result_state

def get_run_state_message(self, run_id: str) -> str:
    """
    Fetch only the message component of a run's state.

    The run is looked up through ``get_run_state`` and the
    ``state_message`` attribute of the result is handed back.

    :param run_id: id of the run
    :return: state message of the run
    """
    run_state = self.get_run_state(run_id)
    return run_state.state_message

def cancel_run(self, run_id: str) -> None:
"""
Cancels the run.
Expand Down
24 changes: 24 additions & 0 deletions tests/providers/databricks/hooks/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,30 @@ def test_get_run_state(self, mock_requests):
timeout=self.hook.timeout_seconds,
)

@mock.patch('airflow.providers.databricks.hooks.databricks.requests')
def test_get_run_state_str(self, mock_requests):
    """``get_run_state_str`` renders the canned run response in its documented format."""
    mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
    expected = f"State: {LIFE_CYCLE_STATE}. Result: {RESULT_STATE}. {STATE_MESSAGE}"
    assert self.hook.get_run_state_str(RUN_ID) == expected

@mock.patch('airflow.providers.databricks.hooks.databricks.requests')
def test_get_run_state_lifecycle(self, mock_requests):
    """``get_run_state_lifecycle`` returns the lifecycle state from the canned run response."""
    mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
    assert self.hook.get_run_state_lifecycle(RUN_ID) == LIFE_CYCLE_STATE

@mock.patch('airflow.providers.databricks.hooks.databricks.requests')
def test_get_run_state_result(self, mock_requests):
    """``get_run_state_result`` returns the result state from the canned run response."""
    mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
    assert self.hook.get_run_state_result(RUN_ID) == RESULT_STATE

@mock.patch('airflow.providers.databricks.hooks.databricks.requests')
def test_get_run_state_cycle(self, mock_requests):
    """``get_run_state_message`` returns the state message from the canned run response."""
    # NOTE(review): despite the "_cycle" suffix in its name, this test covers
    # ``get_run_state_message``; consider renaming it to test_get_run_state_message.
    mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
    assert self.hook.get_run_state_message(RUN_ID) == STATE_MESSAGE

@mock.patch('airflow.providers.databricks.hooks.databricks.requests')
def test_cancel_run(self, mock_requests):
mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
Expand Down